# Source code for traccuracy.metrics._complete_tracks

from __future__ import annotations

import itertools
import warnings
from typing import TYPE_CHECKING

import numpy as np

from traccuracy._tracking_graph import EdgeFlag, NodeFlag
from traccuracy.track_errors._basic import classify_basic_errors
from traccuracy.track_errors._ctc import evaluate_ctc_events
from traccuracy.track_errors._divisions import evaluate_division_events

from ._base import Metric

if TYPE_CHECKING:
    from collections.abc import Hashable

    from traccuracy.matchers import Matched


class CompleteTracks(Metric):
    """The fraction of tracklets and lineages that are completely correctly reconstructed.

    If the reconstruction continues beyond the ground truth track, this is NOT counted
    as incorrect, nor are false positive tracks penalized, making this suitable for
    evaluating with sparse ground truth annotations. If a False Positive Division occurs
    within the ground truth track (or, for the CTC errors, a wrong semantic edge), this
    IS counted as incorrect.

    Args:
        error_type (str, optional): Whether to use "basic" or "ctc" errors for computing
            if tracks are correct or not. Defaults to "basic".

    The compute function returns a results dictionary with the following entries:
        - `total_lineages` - the number of connected components in the ground truth graph
        - `correct_lineages` - the number of fully correct connected components
        - `complete_lineages` - `correct_lineages` / `total_lineages`, or np.nan if
          `total_lineages` is 0
        - `total_tracklets` - the number of tracklets in the ground truth graph, defined
          as the connected components of the graph after division edges are removed.
          Division edges are not included in the tracklets, or counted at all in the
          tracklet metrics.
        - `correct_tracklets` - the number of fully correct tracklets
        - `complete_tracklets` - `correct_tracklets` / `total_tracklets`, or np.nan if
          `total_tracklets` is 0
    """

    def __init__(self, error_type: str = "basic"):
        valid_matches = ["one-to-one", "many-to-one"]
        super().__init__(valid_matches)
        if error_type not in ["ctc", "basic"]:
            raise ValueError(f"Unrecognized error type {error_type}. Should be 'ctc' or 'basic'")
        self.error_type = error_type

    def _compute(
        self, matched: Matched, relax_skips_gt: bool = False, relax_skips_pred: bool = False
    ) -> dict:
        """Computes the fraction of fully correct tracklets and lineages in the matched object.

        If skip edges are relaxed in one graph, then skip_tp edges in the other graph are
        counted as correct, along with nodes between the skip_tp edges in that graph.
        Skip relaxation is only supported for "basic" errors: with "ctc" errors a warning
        is emitted and both relax flags are treated as False.

        Args:
            matched (traccuracy.matchers.Matched): Matched data object to compute metrics on
            relax_skips_gt (bool): If True, the metric will check if skips in the ground
                truth graph have an equivalent multi-edge path in predicted graph
            relax_skips_pred (bool): If True, the metric will check if skips in the
                predicted graph have an equivalent multi-edge path in ground truth graph

        Returns:
            dict: A results dictionary with the following entries:
                - `total_lineages` - the number of connected components in the ground
                  truth graph (counted as the number of gt nodes with no predecessors)
                - `correct_lineages` - the number of fully correct connected components
                - `complete_lineages` - `correct_lineages` / `total_lineages`, or np.nan
                  if `total_lineages` is 0
                - `total_tracklets` - the number of tracklets in the ground truth graph,
                  defined as the connected components of the graph after division edges
                  are removed. Division edges are not included in the tracklets, or
                  counted at all in the tracklet metrics.
                - `correct_tracklets` - the number of fully correct tracklets
                - `complete_tracklets` - `correct_tracklets` / `total_tracklets`, or
                  np.nan if `total_tracklets` is 0
        """
        if self.error_type == "basic":
            classify_basic_errors(
                matched, relax_skips_gt=relax_skips_gt, relax_skips_pred=relax_skips_pred
            )
            evaluate_division_events(
                matched, relax_skips_gt=relax_skips_gt, relax_skips_pred=relax_skips_pred
            )
        else:
            if relax_skips_gt or relax_skips_pred:
                warnings.warn(
                    "CTC metrics do not support relaxing skip edges. "
                    "Ignoring relax_skips_gt and relax_skips_pred.",
                    stacklevel=2,
                )
                # Actually ignore them, as the warning promises: otherwise the checks below
                # would still consult any SKIP_TRUE_POS flags that may be on the graphs.
                relax_skips_gt = False
                relax_skips_pred = False
            evaluate_ctc_events(matched)

        total_tracklets = 0
        total_lineages = 0
        correct_tracklets = 0
        correct_lineages = 0

        # Only directly considering gt graph
        # Entirely FP lineages are not penalized
        # Nor are lineages continuing beyond gt lineage
        gt_nxgraph = matched.gt_graph.graph
        lineage_starts = [node for node, in_degree in gt_nxgraph.in_degree() if in_degree == 0]  # type: ignore
        for lineage_start in lineage_starts:
            tracklet_starts, div_edges = self._split_lineage(gt_nxgraph, lineage_start)

            subtracklets_correct = [
                self._check_tracklet_correct(
                    tracklet_start,
                    matched,
                    relax_skips_gt=relax_skips_gt,
                    relax_skips_pred=relax_skips_pred,
                )
                for tracklet_start in tracklet_starts
            ]
            div_edges_correct = [
                self._check_gt_edge_correct(
                    div_edge,
                    matched,
                    relax_skips_gt=relax_skips_gt,
                    relax_skips_pred=relax_skips_pred,
                )
                for div_edge in div_edges
            ]
            # A lineage is correct only if all its tracklets and all its division edges are.
            lineage_correct = all(subtracklets_correct) and all(div_edges_correct)

            total_tracklets += len(tracklet_starts)
            correct_tracklets += sum(subtracklets_correct)
            total_lineages += 1
            correct_lineages += lineage_correct

        return {
            "total_lineages": total_lineages,
            "total_tracklets": total_tracklets,
            "correct_lineages": correct_lineages,
            "correct_tracklets": correct_tracklets,
            "complete_lineages": correct_lineages / total_lineages
            if total_lineages > 0
            else np.nan,
            "complete_tracklets": correct_tracklets / total_tracklets
            if total_tracklets > 0
            else np.nan,
        }

    @staticmethod
    def _split_lineage(
        gt_nxgraph, lineage_start: Hashable
    ) -> tuple[list[Hashable], list[tuple[Hashable, Hashable]]]:
        """Walk one lineage breadth-first and return its tracklet starts and division edges.

        Any node with two or more successors is treated as a division: each successor
        starts a new tracklet and each (parent, daughter) edge is a division edge. Using
        ">= 2" keeps this consistent with `_check_tracklet_correct`, which ends a tracklet
        at any node that does not have exactly one out-edge.

        Args:
            gt_nxgraph: the gt graph object (``matched.gt_graph.graph``); only its
                ``successors`` method is used here.
            lineage_start (Hashable): root node of the lineage.

        Returns:
            tuple: (tracklet start nodes, division edges) for this lineage.
        """
        tracklet_starts = [lineage_start]
        div_edges = []
        curr_nodes = [lineage_start]
        while curr_nodes:
            next_nodes = []
            for node in curr_nodes:
                daughters = list(gt_nxgraph.successors(node))
                next_nodes.extend(daughters)
                if len(daughters) >= 2:
                    tracklet_starts.extend(daughters)
                    div_edges.extend((node, daughter) for daughter in daughters)
            curr_nodes = next_nodes
        return tracklet_starts, div_edges

    def _check_tracklet_correct(
        self, start_node: Hashable, matched: Matched, relax_skips_gt: bool, relax_skips_pred: bool
    ) -> bool:
        """Return True if every node and edge of the tracklet starting at start_node is correct.

        The tracklet is followed along single out-edges and ends at the first gt node with
        zero or several successors; the division edges leaving that node are not checked here.
        """
        if not self._check_gt_node_correct(start_node, matched, relax_skips_pred=relax_skips_pred):
            return False
        out_edges = list(matched.gt_graph.graph.out_edges(start_node))
        while len(out_edges) == 1:
            out_edge = out_edges[0]
            if not self._check_gt_edge_correct(out_edge, matched, relax_skips_gt, relax_skips_pred):
                return False
            curr_node = out_edge[1]
            if not self._check_gt_node_correct(
                curr_node, matched, relax_skips_pred=relax_skips_pred
            ):
                return False
            out_edges = list(matched.gt_graph.graph.out_edges(curr_node))
        return True

    def _check_gt_node_correct(
        self, node: Hashable, matched: Matched, relax_skips_pred: bool
    ) -> bool:
        """Return True if the gt node counts as correctly reconstructed.

        The node is correct if it carries the true-positive flag for the configured error
        type and, for "basic" errors, none of its matched pred nodes is flagged FP_DIV.
        If it is not a true positive but relax_skips_pred is True, it is still correct when
        one of its incoming gt edges is flagged SKIP_TRUE_POS (it lies between skip TPs).
        """
        node_tp = NodeFlag.TRUE_POS if self.error_type == "basic" else NodeFlag.CTC_TRUE_POS
        gt_track = matched.gt_graph
        if node_tp in gt_track.nodes[node]:
            # check if it is not matched to a FP-DIV, if applicable
            if self.error_type == "basic":
                return not any(
                    NodeFlag.FP_DIV in matched.pred_graph.nodes[pred_node]
                    for pred_node in matched.get_gt_pred_matches(node)
                )
            return True
        # if skip edges are relaxed, check if the node is between skip tps
        # (enough to check that one prev edge is a skip TP)
        if relax_skips_pred:
            return any(
                EdgeFlag.SKIP_TRUE_POS in gt_track.edges[prev_edge]
                for prev_edge in gt_track.graph.in_edges(node)
            )
        # it's not a TP or between skip edges, so it's just wrong
        return False

    def _check_gt_edge_correct(
        self,
        edge: tuple[Hashable, Hashable],
        matched: Matched,
        relax_skips_gt: bool,
        relax_skips_pred: bool,
    ) -> bool:
        """Return True if the gt edge counts as correctly reconstructed.

        For "ctc" errors the edge is correct unless it is flagged CTC_FALSE_NEG or a
        matching pred edge is flagged WRONG_SEMANTIC. For "basic" errors it is correct if
        flagged TRUE_POS. Otherwise a SKIP_TRUE_POS edge is accepted when it is a gt skip
        edge and relax_skips_gt is set, or a non-skip edge and relax_skips_pred is set.
        """
        gt_track = matched.gt_graph
        pred_track = matched.pred_graph
        edge_data = gt_track.edges[edge]
        # check if it is a TP
        if self.error_type == "ctc":
            # the ctc errors don't annotate edge TPs, so instead we check for absence of
            # all the error types. Wrong semantic are only annotated on the pred graph,
            # so we need to find the matched edge and check it
            tp = EdgeFlag.CTC_FALSE_NEG not in edge_data
            if tp:
                matched_sources = matched.get_gt_pred_matches(edge[0])
                matched_targets = matched.get_gt_pred_matches(edge[1])
                tp = not any(
                    EdgeFlag.WRONG_SEMANTIC in pred_track.graph.edges[(source, target)]
                    for source, target in itertools.product(matched_sources, matched_targets)
                    if pred_track.graph.has_edge(source, target)
                )
            if tp:
                return True
        elif EdgeFlag.TRUE_POS in edge_data:
            return True

        is_skip_edge = gt_track.is_skip_edge(edge)
        if is_skip_edge and relax_skips_gt and EdgeFlag.SKIP_TRUE_POS in edge_data:
            return True
        if (not is_skip_edge) and relax_skips_pred and EdgeFlag.SKIP_TRUE_POS in edge_data:
            return True
        return False