from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import pylapy
from tqdm import tqdm

from traccuracy._tracking_graph import TrackingGraph

from ._base import Matcher
from ._compute_overlap import get_labels_with_overlap, graph_bbox_and_labels

if TYPE_CHECKING:
    from collections.abc import Hashable


def _match_nodes(
    gt: np.ndarray,
    res: np.ndarray,
    gt_boxes: np.ndarray | None = None,
    res_boxes: np.ndarray | None = None,
    gt_labels: np.ndarray | None = None,
    res_labels: np.ndarray | None = None,
    threshold: float = 0.5,
    one_to_one: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
    """Identify overlapping objects according to IoU and a threshold for minimum overlap.

    Args:
        gt (np.ndarray): labeled frame
        res (np.ndarray): labeled frame
        gt_boxes (np.ndarray | None): bounding boxes for the gt frame
        res_boxes (np.ndarray | None): bounding boxes for the res frame
        gt_labels (np.ndarray | None): labels for the gt frame
        res_labels (np.ndarray | None): labels for the res frame
        threshold (optional, float): minimum IoU (inclusive, ``iou >= threshold``) for a
            pair to count as the same cell. Default 0.5.
            If segmentations are identical, 1 works well.
            For imperfect segmentations try 0.6-0.8 to get better matching
        one_to_one (optional, bool): If True, forces the mapping to be one-to-one by running
            linear assignment on the thresholded iou array. Default False.

    Returns:
        gtcells (np.ndarray): Array of overlapping ids in the gt frame.
        rescells (np.ndarray): Array of overlapping ids in the res frame.
            Both arrays have dtype ``np.intp`` and equal length; entry ``k`` of each
            array together forms one matched pair.

    Raises:
        ValueError: If ``threshold`` is 0 and ``one_to_one`` is False.
    """
    if threshold == 0.0 and not one_to_one:
        raise ValueError("Threshold of 0 is not valid unless one_to_one is True")

    ious = get_labels_with_overlap(
        gt,
        res,
        gt_boxes=gt_boxes,
        res_boxes=res_boxes,
        gt_labels=gt_labels,
        res_labels=res_labels,
        overlap="iou",
    )
    # len() rather than truthiness: works for an empty list and for an ndarray,
    # whereas ``not ndarray`` is ambiguous for arrays with more than one element.
    if len(ious) == 0:
        return np.array([], dtype=np.intp), np.array([], dtype=np.intp)

    # One row per overlapping pair: (gt_label, res_label, iou).
    iou_arr = np.asarray(ious)
    filtered_ious = iou_arr[iou_arr[:, 2] >= threshold]
    if len(filtered_ious) == 0:
        return np.array([], dtype=np.intp), np.array([], dtype=np.intp)

    if one_to_one:
        return _one_to_one_assignment(filtered_ious)
    # The array may have been upcast to float by the IoU column; cast labels back to ids.
    return filtered_ious[:, 0].astype(np.intp), filtered_ious[:, 1].astype(np.intp)
def _one_to_one_assignment(
    ious: np.ndarray,
    unmapped_cost: float = 4,
) -> tuple[np.ndarray, np.ndarray]:
    """Perform linear assignment on IoU overlaps to create a one-to-one mapping.

    Builds a compact cost matrix using only the labels that appear in the overlaps,
    avoiding allocation of a large dense matrix indexed by max label value.

    Args:
        ious (np.ndarray): Array of shape (N, 3) whose rows are
            (gt_label, res_label, iou_value), built from the output of
            get_labels_with_overlap and already filtered to remove matches below
            threshold.
        unmapped_cost (float, optional): Cost of an unassigned cell.
            Lower values leads to more unassigned cells. Defaults to 4.

    Returns:
        tuple[np.ndarray, np.ndarray]: matched gt labels and matched res labels
            (dtype ``np.intp``); entry ``k`` of each array forms one matched pair.
    """
    # np.unique yields the sorted distinct labels and, for every row, the position of
    # that row's label in the sorted list -> compact matrix indices in one vectorized pass.
    gt_unique, gt_indices = np.unique(ious[:, 0], return_inverse=True)
    res_unique, res_indices = np.unique(ious[:, 1], return_inverse=True)

    cost = np.ones((len(gt_unique), len(res_unique)))
    cost[gt_indices, res_indices] = 1 - ious[:, 2]
    # A cost of exactly 1 means either no recorded overlap or IoU == 0; mark those
    # links as forbidden with an infinite cost.
    cost[cost == 1] = np.inf

    solver = pylapy.LapSolver(implementation="scipy", sparse_implementation="csgraph")
    assignments = solver.sparse_solve(cost, eta=unmapped_cost + 1)

    # Normalise to an (n_links, 2) integer array so an empty result still indexes cleanly,
    # then map compact matrix indices back to the original (float-stored) labels.
    links = np.asarray(assignments, dtype=np.intp).reshape(-1, 2)
    rows = gt_unique[links[:, 0]].astype(np.intp)
    cols = res_unique[links[:, 1]].astype(np.intp)
    return rows, cols
def _construct_time_to_seg_id_map(
    graph: TrackingGraph,
) -> dict[int, dict[Hashable, Hashable]]:
    """For each time frame in the graph, create a mapping from segmentation ids
    (the ids in the segmentation array, stored in graph.label_key) to the
    node ids (the ids of the TrackingGraph nodes).

    Args:
        graph (TrackingGraph): a tracking graph with a label_key on each node

    Returns:
        dict[int, dict[Hashable, Hashable]]: a dictionary from {time: {segmentation_id: node_id}}

    Raises:
        ValueError: If graph.label_key is None
        AssertionError: If two nodes in a time frame have the same segmentation_id
    """
    if graph.label_key is None:
        raise ValueError("No label_key provided for input TrackingGraph")

    time_to_seg_id_map: dict[int, dict[Hashable, Hashable]] = {}
    for node_id, data in graph.nodes(data=True):
        time = data[graph.frame_key]
        seg_id = data[graph.label_key]
        seg_id_to_node_id_map = time_to_seg_id_map.setdefault(time, {})
        # Raise explicitly rather than via ``assert`` so the check still runs under
        # ``python -O``; AssertionError is kept because it is the documented type.
        if seg_id in seg_id_to_node_id_map:
            raise AssertionError(f"Segmentation ID {seg_id} occurred twice in frame {time}.")
        seg_id_to_node_id_map[seg_id] = node_id
    return time_to_seg_id_map
def match_iou(
    gt: TrackingGraph, pred: TrackingGraph, threshold: float = 0.6, one_to_one: bool = False
) -> list[tuple[Hashable, Hashable]]:
    """Identifies pairs of cells between gt and pred that have iou >= threshold

    This can return more than one match for any node unless one_to_one is True.
    Assumes that within a frame, each object has a unique segmentation label
    and that the label is recorded on each node using label_key.
    Graph frame ``t`` is compared using index ``t - gt.start_frame`` of both
    segmentation arrays.

    Args:
        gt (traccuracy.TrackingGraph): Tracking data object containing graph and segmentations
        pred (traccuracy.TrackingGraph): Tracking data object containing graph and segmentations
        threshold (float, optional): Minimum IoU (inclusive) for matching cells.
            Defaults to 0.6.
        one_to_one (optional, bool): If True, forces the mapping to be one-to-one by running
            linear assignment on the thresholded iou array. Default False.

    Returns:
        list[(gt_node, pred_node)]: list of tuples where each tuple contains a gt node and pred node

    Raises:
        ValueError: gt and pred must be a TrackingData object
        ValueError: gt and pred must both contain a segmentation array
        ValueError: GT and pred segmentations must be the same shape
        ValueError: gt or pred has no label_key
        AssertionError: a segmentation label occurs on two nodes of the same frame
    """
    if not isinstance(gt, TrackingGraph) or not isinstance(pred, TrackingGraph):
        raise ValueError("Input data must be a TrackingData object with a graph and segmentations")
    if gt.segmentation is None or pred.segmentation is None:
        raise ValueError("TrackingGraph must contain a segmentation array for IoU matching")
    if gt.segmentation.shape != pred.segmentation.shape:
        raise ValueError("Segmentation shapes must match between gt and pred")

    gt_seg = gt.segmentation
    pred_seg = pred.segmentation
    mapper: list[tuple[Hashable, Hashable]] = []
    if gt.start_frame is None or gt.end_frame is None:
        return mapper

    gt_time_to_seg_id_map = _construct_time_to_seg_id_map(gt)
    pred_time_to_seg_id_map = _construct_time_to_seg_id_map(pred)

    frames = range(gt.start_frame, gt.end_frame)
    # ``i`` indexes the segmentation arrays; ``t`` is the graph frame number.
    for i, t in enumerate(tqdm(frames, desc="Matching frames")):
        gt_boxes, gt_labels = graph_bbox_and_labels(gt.graph, gt.nodes_by_frame[t], gt.label_key)
        pred_boxes, pred_labels = graph_bbox_and_labels(
            pred.graph, pred.nodes_by_frame[t], pred.label_key
        )
        gt_ids, pred_ids = _match_nodes(
            gt_seg[i],
            pred_seg[i],
            gt_boxes=gt_boxes,
            res_boxes=pred_boxes,
            gt_labels=gt_labels,
            res_labels=pred_labels,
            threshold=threshold,
            one_to_one=one_to_one,
        )

        # The seg-id -> node-id maps depend only on the frame, so fetch them once here
        # instead of once per match.
        gt_seg_map = gt_time_to_seg_id_map.get(t, {})
        pred_seg_map = pred_time_to_seg_id_map.get(t, {})
        for gt_seg_id, pred_seg_id in zip(gt_ids, pred_ids, strict=True):
            # Skip if either seg label has no corresponding graph node
            # (e.g. node removed by border_margin filtering)
            if gt_seg_id in gt_seg_map and pred_seg_id in pred_seg_map:
                mapper.append((gt_seg_map[gt_seg_id], pred_seg_map[pred_seg_id]))
    return mapper
class IOUMatcher(Matcher):
    """Constructs a mapping between gt and pred nodes using the IoU of the segmentations

    Lower values for iou_threshold will be more permissive of imperfect matches

    Args:
        iou_threshold (float, optional): Minimum IoU value (inclusive) to assign a match.
            Defaults to 0.6.
        one_to_one (optional, bool): If True, forces the mapping to be one-to-one by running
            linear assignment on the thresholded iou array. Default False.
    """

    def __init__(self, iou_threshold: float = 0.6, one_to_one: bool = False):
        self.iou_threshold = iou_threshold
        self.one_to_one = one_to_one

        # If either condition is met, matching must be one to one. Linear assignment
        # enforces it directly; with a threshold strictly above 0.5, two disjoint objects
        # cannot both exceed it against the same object, so no node gets two matches.
        # Otherwise ``_matching_type`` is left as provided by Matcher (not visible here).
        if one_to_one or iou_threshold > 0.5:
            self._matching_type = "one-to-one"

    def _compute_mapping(
        self, gt_graph: TrackingGraph, pred_graph: TrackingGraph
    ) -> list[tuple[Hashable, Hashable]]:
        """Computes IOU mapping for a set of graphs

        Args:
            gt_graph (TrackingGraph): Tracking graph object for the gt with segmentation data
            pred_graph (TrackingGraph): Tracking graph object for the pred with segmentation data

        Raises:
            ValueError: Segmentation data must be provided for both gt and pred data

        Returns:
            list[tuple[Hashable, Hashable]]: (gt_node, pred_node) pairs as returned by
                match_iou for this matcher's threshold and one_to_one setting
        """
        # Check that segmentations exist in the data
        if gt_graph.segmentation is None or pred_graph.segmentation is None:
            raise ValueError("Segmentation data must be provided for both gt and pred data")

        return match_iou(
            gt_graph,
            pred_graph,
            threshold=self.iou_threshold,
            one_to_one=self.one_to_one,
        )