from __future__ import annotations
from typing import TYPE_CHECKING, cast
import numpy as np
import pylapy
from tqdm import tqdm
from traccuracy._tracking_graph import TrackingGraph
from ._base import Matcher
from ._compute_overlap import get_labels_with_overlap, graph_bbox_and_labels
if TYPE_CHECKING:
from collections.abc import Hashable
def _match_nodes(
gt: np.ndarray,
res: np.ndarray,
gt_boxes: np.ndarray | None = None,
res_boxes: np.ndarray | None = None,
gt_labels: np.ndarray | None = None,
res_labels: np.ndarray | None = None,
threshold: float = 0.5,
one_to_one: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
"""Identify overlapping objects according to IoU and a threshold for minimum overlap.
QUESTION: Does this rely on sequential segmentation labels
Args:
gt (np.ndarray): labeled frame
res (np.ndarray): labeled frame
gt_boxes (np.ndarray | None): bounding boxes for the gt frame
res_boxes (np.ndarray | None): bounding boxes for the res frame
gt_labels (np.ndarray | None): labels for the gt frame
res_labels (np.ndarray | None): labels for the res frame
threshold (optional, float): threshold value for IoU to count as same cell. Default 1.
If segmentations are identical, 1 works well.
For imperfect segmentations try 0.6-0.8 to get better matching
one_to_one (optional, bool): If True, forces the mapping to be one-to-one by running
linear assignment on the thresholded iou array. Default False.
Returns:
gtcells (np.ndarray): Array of overlapping ids in the gt frame.
rescells (np.ndarray): Array of overlapping ids in the res frame.
"""
if threshold == 0.0 and not one_to_one:
raise ValueError("Threshold of 0 is not valid unless one_to_one is True")
# casting to int to avoid issue #152 (result is float with numpy<2, dtype=uint64)
iou = np.zeros((int(np.max(gt) + 1), int(np.max(res) + 1)))
ious = get_labels_with_overlap(
gt,
res,
gt_boxes=gt_boxes,
res_boxes=res_boxes,
gt_labels=gt_labels,
res_labels=res_labels,
overlap="iou",
)
for gt_label, res_label, iou_val in ious:
if iou_val >= threshold:
iou[gt_label, res_label] = iou_val
if one_to_one:
pairs = _one_to_one_assignment(iou)
else:
# np.where returns tuple[ndarray[Any, dtype[signedinteger[Any]]], ...]
# this is functionally equivalent to a normal tuple of arrays so we need
# to cast to match the return type of _one_to_one_assignment
pairs = cast("tuple[np.ndarray, np.ndarray]", np.where(iou))
gtcells, rescells = pairs[0], pairs[1]
return gtcells, rescells
def _one_to_one_assignment(
iou: np.ndarray, unmapped_cost: int = 4
) -> tuple[np.ndarray, np.ndarray]:
"""Perform linear assignment on the iou matrix to create a one-to-one
mapping
Args:
iou (np.array): Array containing thresholded iou values
unmapped_cost (float, optional): Cost of an unassigned cell.
Lower values leads to more unassigned cells. Defaults to 4.
Returns:
tuple: Tuple of two arrays, one for indices of each axis
"""
# Lap solver using scipy
solver = pylapy.LapSolver(implementation="scipy", sparse_implementation="csgraph")
# Exclude the background which is currently included in iou matrix
cost = 1 - iou[1:, 1:]
cost[cost == 1] = np.inf
# Let's keep eta = unmapped_cost + 1 for compatibility. But one could probably do
# hard thresholding instead (using hard=True) which is indeed what we want to do
# Add 1 to all indices to correct for the removed background
rows, cols = (solver.sparse_solve(cost, eta=unmapped_cost + 1) + 1).T
return rows, cols
def _construct_time_to_seg_id_map(
graph: TrackingGraph,
) -> dict[int, dict[Hashable, Hashable]]:
"""For each time frame in the graph, create a mapping from segmentation ids
(the ids in the segmentation array, stored in graph.label_key) to the
node ids (the ids of the TrackingGraph nodes).
Args:
graph(TrackingGraph): a tracking graph with a label_key on each node
Returns:
dict[int, dict[Hashable, Hashable]]: a dictionary from {time: {segmentation_id: node_id}}
Raises:
AssertionError: If two nodes in a time frame have the same segmentation_id
"""
time_to_seg_id_map: dict[int, dict[Hashable, Hashable]] = {}
for node_id, data in graph.nodes(data=True):
time = data[graph.frame_key]
seg_id = data[graph.label_key]
seg_id_to_node_id_map = time_to_seg_id_map.get(time, {})
assert seg_id not in seg_id_to_node_id_map, (
f"Segmentation ID {seg_id} occurred twice in frame {time}."
)
seg_id_to_node_id_map[seg_id] = node_id
time_to_seg_id_map[time] = seg_id_to_node_id_map
return time_to_seg_id_map
[docs]
def match_iou(
gt: TrackingGraph, pred: TrackingGraph, threshold: float = 0.6, one_to_one: bool = False
) -> list[tuple[Hashable, Hashable]]:
"""Identifies pairs of cells between gt and pred that have iou > threshold
This can return more than one match for any node
Assumes that within a frame, each object has a unique segmentation label
and that the label is recorded on each node using label_key
Args:
gt (traccuracy.TrackingGraph): Tracking data object containing graph and segmentations
pred (traccuracy.TrackingGraph): Tracking data object containing graph and segmentations
threshold (float, optional): Minimum IoU for matching cells. Defaults to 0.6.
one_to_one (optional, bool): If True, forces the mapping to be one-to-one by running
linear assignment on the thresholded iou array. Default False.
Returns:
list[(gt_node, pred_node)]: list of tuples where each tuple contains a gt node and pred node
Raises:
ValueError: gt and pred must be a TrackingData object
ValueError: GT and pred segmentations must be the same shape
"""
if not isinstance(gt, TrackingGraph) or not isinstance(pred, TrackingGraph):
raise ValueError("Input data must be a TrackingData object with a graph and segmentations")
if gt.segmentation is None or pred.segmentation is None:
raise ValueError("TrackingGraph must contain a segmentation array for IoU matching")
if gt.segmentation.shape != pred.segmentation.shape:
raise ValueError("Segmentation shapes must match between gt and pred")
mapper = []
# Get overlaps for each frame
frame_range = range(gt.segmentation.shape[0])
total = len(list(frame_range))
gt_time_to_seg_id_map = _construct_time_to_seg_id_map(gt)
pred_time_to_seg_id_map = _construct_time_to_seg_id_map(pred)
for i, t in tqdm(enumerate(frame_range), desc="Matching frames", total=total):
gt_nodes = gt.nodes_by_frame[t]
pred_nodes = pred.nodes_by_frame[t]
gt_boxes, gt_labels = graph_bbox_and_labels(gt.graph, gt_nodes, gt.label_key)
pred_boxes, pred_labels = graph_bbox_and_labels(pred.graph, pred_nodes, pred.label_key)
matches = _match_nodes(
gt.segmentation[i],
pred.segmentation[i],
gt_boxes=gt_boxes,
res_boxes=pred_boxes,
gt_labels=gt_labels,
res_labels=pred_labels,
threshold=threshold,
one_to_one=one_to_one,
)
# Construct node id tuple for each match
for gt_seg_id, pred_seg_id in zip(*matches, strict=True):
# Find node id based on time and segmentation label
gt_node = gt_time_to_seg_id_map[t][gt_seg_id]
pred_node = pred_time_to_seg_id_map[t][pred_seg_id]
mapper.append((gt_node, pred_node))
return mapper
[docs]
class IOUMatcher(Matcher):
"""Constructs a mapping between gt and pred nodes using the IoU of the segmentations
Lower values for iou_threshold will be more permissive of imperfect matches
Args:
iou_threshold (float, optional): Minimum IoU value to assign a match. Defaults to 0.6.
one_to_one (optional, bool): If True, forces the mapping to be one-to-one by running
linear assignment on the thresholded iou array. Default False.
"""
def __init__(self, iou_threshold: float = 0.6, one_to_one: bool = False):
self.iou_threshold = iou_threshold
self.one_to_one = one_to_one
# If either condition is met, matching must be one to one
if one_to_one or iou_threshold > 0.5:
self._matching_type = "one-to-one"
def _compute_mapping(
self, gt_graph: TrackingGraph, pred_graph: TrackingGraph
) -> list[tuple[Hashable, Hashable]]:
"""Computes IOU mapping for a set of graphs
Args:
gt_graph (TrackingGraph): Tracking graph object for the gt with segmentation data
pred_graph (TrackingGraph): Tracking graph object for the pred with segmentation data
Raises:
ValueError: Segmentation data must be provided for both gt and pred data
Returns:
Matched: Matched data object containing IOU mapping
"""
# Check that segmentations exist in the data
if gt_graph.segmentation is None or pred_graph.segmentation is None:
raise ValueError("Segmentation data must be provided for both gt and pred data")
mapping = match_iou(
gt_graph,
pred_graph,
threshold=self.iou_threshold,
one_to_one=self.one_to_one,
)
return mapping