# Source code for traccuracy.matchers._ctc

from __future__ import annotations

from typing import TYPE_CHECKING

from tqdm import tqdm

if TYPE_CHECKING:
    from collections.abc import Hashable

    from traccuracy._tracking_graph import TrackingGraph

from ._base import Matcher
from ._compute_overlap import get_labels_with_overlap, graph_bbox_and_labels


class CTCMatcher(Matcher):
    """Match graph nodes using the measure from Cell Tracking Challenge benchmarking.

    A computed marker (segmentation) is matched to a reference marker if the
    computed marker covers a majority of the reference marker.

    Each reference marker can therefore only be matched to one computed marker,
    but multiple reference markers can be assigned to a single computed marker.

    See https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0144959
    for complete details.
    """

    # CTC can return many-to-one or one-to-one
    _matching_type = None

    def _compute_mapping(
        self, gt_graph: TrackingGraph, pred_graph: TrackingGraph
    ) -> list[tuple[Hashable, Hashable]]:
        """Run CTC matching frame by frame.

        For each frame in ``gt_graph``'s frame range, segment overlaps are
        computed with the intersection-over-gt ("iogt") measure. A
        ``(gt node, pred node)`` pair is recorded whenever the pred segment
        covers more than half of the gt segment and both segment labels map to
        a node in their respective graphs.

        Args:
            gt_graph (TrackingGraph): Tracking graph object for the gt
            pred_graph (TrackingGraph): Tracking graph object for the pred

        Returns:
            list[tuple[Hashable, Hashable]]: ``(gt_node, pred_node)`` pairs for
                every matched segment. Empty if ``gt_graph`` has no start or
                end frame.

        Raises:
            ValueError: if either the GT or pred segmentation is None, or if
                the two segmentations do not have the same shape
        """
        gt_label_key = gt_graph.label_key
        pred_label_key = pred_graph.label_key
        mask_gt = gt_graph.segmentation
        mask_pred = pred_graph.segmentation

        if mask_gt is None or mask_pred is None:
            raise ValueError("Segmentation is None, cannot perform matching")
        if mask_gt.shape != mask_pred.shape:
            raise ValueError("Segmentation shapes must match between gt and pred")

        mapping: list[tuple[Hashable, Hashable]] = []

        # Without a frame range on the gt graph there is nothing to iterate over
        if gt_graph.start_frame is None or gt_graph.end_frame is None:
            return mapping

        # Node-attribute views are loop-invariant; resolve them once
        gt_nodes = gt_graph.graph.nodes
        pred_nodes = pred_graph.graph.nodes

        # Get overlaps for each frame
        for i, t in enumerate(
            tqdm(
                range(gt_graph.start_frame, gt_graph.end_frame),
                desc="Matching frames",
            )
        ):
            # NOTE(review): segmentation arrays are indexed by the offset ``i``
            # from ``gt_graph.start_frame`` while graph nodes are looked up by
            # the absolute frame ``t``; this assumes segmentation index 0
            # corresponds to start_frame — confirm.
            gt_frame = mask_gt[i]
            pred_frame = mask_pred[i]
            gt_frame_nodes = gt_graph.nodes_by_frame[t]
            pred_frame_nodes = pred_graph.nodes_by_frame[t]

            # Segmentation label -> node id, only for nodes that carry a label
            gt_label_to_id = {
                gt_nodes[node][gt_label_key]: node
                for node in gt_frame_nodes
                if gt_label_key in gt_nodes[node]
            }
            pred_label_to_id = {
                pred_nodes[node][pred_label_key]: node
                for node in pred_frame_nodes
                if pred_label_key in pred_nodes[node]
            }

            gt_boxes, gt_labels = graph_bbox_and_labels(
                gt_graph.graph, gt_frame_nodes, gt_label_key
            )
            pred_boxes, pred_labels = graph_bbox_and_labels(
                pred_graph.graph, pred_frame_nodes, pred_label_key
            )

            overlaps = get_labels_with_overlap(
                gt_frame,
                pred_frame,
                gt_boxes=gt_boxes,
                res_boxes=pred_boxes,
                gt_labels=gt_labels,
                res_labels=pred_labels,
                overlap="iogt",
            )

            # Switch from segmentation ids to node ids
            for gt_label, pred_label, iogt in overlaps:
                # CTC criterion: pred must cover a strict majority of the gt segment
                if iogt > 0.5:
                    # Skip if either seg label has no corresponding graph node
                    # (e.g. node removed by border_margin filtering)
                    if gt_label not in gt_label_to_id or pred_label not in pred_label_to_id:
                        continue
                    mapping.append((gt_label_to_id[gt_label], pred_label_to_id[pred_label]))

        return mapping