Source code for traccuracy.metrics._track_overlap

"""This submodule implements routines for Track Purity (TP) and Target Effectiveness (TE) scores.

Definitions (Bise et al., 2011; Chen, 2021; Fukai et al., 2022):

- TE for a single ground truth track T^g_j is calculated by finding the predicted track T^p_k
  that overlaps with T^g_j in the largest number of frames and then dividing
  that overlapping frame count by the total frame count of T^g_j.
  The TE for the whole dataset is the mean of the TEs of all ground truth tracks,
  weighted by track length. (In this module, overlaps and lengths are counted in
  tracklet edges rather than in frames.)

- TP is defined analogously, with the roles of the ground truth track T^g_j and the
  predicted track T^p_k swapped.
"""

from __future__ import annotations

import warnings
from collections import defaultdict
from itertools import pairwise, product
from typing import TYPE_CHECKING, Any

import numpy as np

from traccuracy.matchers._base import Matched
from traccuracy.utils import get_equivalent_skip_edge

from ._base import Metric

if TYPE_CHECKING:
    import networkx as nx

    from traccuracy._tracking_graph import TrackingGraph
    from traccuracy.matchers._matched import Matched


class TrackOverlapMetrics(Metric):
    """Calculate metrics for longest track overlaps.

    - Target Effectiveness: fraction of longest overlapping prediction tracklets on each
      GT tracklet
    - Track Purity: fraction of longest overlapping GT tracklets on each prediction tracklet

    Args:
        include_division_edges (bool, optional): If True, include edges at division.
            Defaults to True.
    """

    def __init__(self, include_division_edges: bool = True):
        valid_match_types = ["many-to-one", "one-to-one", "one-to-many"]
        super().__init__(valid_match_types)
        self.include_division_edges = include_division_edges

    def _compute(
        self, matched: Matched, relax_skips_gt: bool = False, relax_skips_pred: bool = False
    ) -> dict[str, float | np.floating[Any]]:
        """Compute track purity and target effectiveness for matched GT and predicted graphs.

        Args:
            matched (traccuracy.matchers.Matched): Matched object for set of GT and Pred data
            relax_skips_gt (bool, optional): If True, skip edges may be matched to
                equivalent paths in the other graph. Defaults to False.
            relax_skips_pred (bool, optional): Same as ``relax_skips_gt``. Setting either
                flag relaxes skip edges of BOTH graphs; a warning is emitted when only
                one of them is set. Defaults to False.

        Returns:
            dict[str, float | np.floating[Any]]: ``"track_purity"`` and
            ``"target_effectiveness"`` are the edge-count-weighted overlap fractions over
            predicted and GT tracklets respectively; ``"track_fractions"`` is the
            unweighted mean of the per-GT-tracklet overlap fractions. Values are NaN
            when there are no tracklet edges to average over.
        """
        if relax_skips_gt + relax_skips_pred == 1:
            warnings.warn(
                "Relaxing skips for either predicted or ground truth graphs"
                + " will still affect all overlap metrics.",
                stacklevel=2,
            )
        relaxed = relax_skips_gt or relax_skips_pred
        include_div = self.include_division_edges

        gt_graph = matched.gt_graph
        pred_graph = matched.pred_graph
        gt_tracklets = gt_graph.get_tracklets(include_division_edges=include_div)
        pred_tracklets = pred_graph.get_tracklets(include_division_edges=include_div)

        # if skips are not relaxed, we pass through an empty set of "relevant skips"
        # this means for all downstream compute, skip edges will only be matched
        # if the exact same skip edge exists in the other graph
        gt_skips = _get_relevant_skip_edges(gt_graph, include_div) if relaxed else set()
        pred_skips = _get_relevant_skip_edges(pred_graph, include_div) if relaxed else set()

        # GT skips are looked up as paths in the predicted graph and vice versa
        gt_skip_to_path_length, pred_path_to_gt_skip_map = _get_skip_path_maps(
            matched, gt_skips, matched.gt_pred_map
        )
        pred_skip_to_path_length, gt_path_to_pred_skip_map = _get_skip_path_maps(
            matched, pred_skips, matched.pred_gt_map
        )

        # track purity: predicted tracklets are the reference, GT tracklets the overlap
        track_purity, _ = _calc_overlap_score(
            pred_tracklets,
            gt_tracklets,
            matched.pred_gt_map,
            pred_skip_to_path_length,  # ref skips
            gt_path_to_pred_skip_map,  # overlap_path_to_reference_skip_map
            pred_path_to_gt_skip_map,  # reference_path_to_overlap_skip_map
        )
        # target effectiveness: GT tracklets are the reference, predicted the overlap
        target_effectiveness, track_fractions = _calc_overlap_score(
            gt_tracklets,
            pred_tracklets,
            matched.gt_pred_map,
            gt_skip_to_path_length,  # ref skips
            pred_path_to_gt_skip_map,  # overlap_path_to_reference_skip_map
            gt_path_to_pred_skip_map,  # reference_path_to_overlap_skip_map
        )
        return {
            "track_purity": track_purity,
            "target_effectiveness": target_effectiveness,
            "track_fractions": track_fractions,
        }
def _calc_overlap_score( reference_tracklets: list[nx.DiGraph], overlap_tracklets: list[nx.DiGraph], overlap_reference_mapping: dict[Any, list[Any]], reference_skips: dict[tuple[int, int], int], overlap_path_to_reference_skip_map: dict[Any, dict[str, Any]], reference_path_to_overlap_skip_map: dict[Any, dict[str, Any]], ) -> tuple[float | np.floating[Any], float | np.floating[Any]]: """Get weighted/unweighted fraction of reference_tracklets overlapped by overlap_tracklets. The weighted average is calculated as the total number of maximally overlapping edges divided by the total number of edges in the reference tracklets. The unweighted average is calculated as the mean of the fraction of maximally overlapping edges for each reference tracklet. Args: reference_tracklets (List[TrackingGraph]): The reference tracklets overlap_tracklets (List[TrackingGraph]): The tracklets that overlap overlap_reference_mapping (Dict[Any, List[Any]]): Mapping as a dict from the overlap tracklet nodes to the reference tracklet nodes reference_skips (Dict[Tuple[int, int], int]): Mapping of skip edges in the reference tracklets to their lengths overlap_path_to_reference_skip_map (Dict[Any, Dict[str, Any]]): Mapping from nodes in overlap tracklet equivalent paths to the edge they are part of and the reference skip edge they cover reference_path_to_overlap_skip_map (Dict[Any, Dict[str, Any]]): Mapping from nodes in reference tracklet equivalent paths to the edge they are part of and the overlapping skip edge they cover Returns: tuple[float | np.floating[Any], float | np.floating[Any]]: A tuple containing the weighted and unweighted averages of the overlap fractions. 
""" max_overlap = 0 total_count = 0 track_fractions = [] # maps each edge to their tracklet index overlap_edge_to_tid = { edge: i for i in range(len(overlap_tracklets)) for edge in overlap_tracklets[i].edges() } for reference_tracklet in reference_tracklets: tracklet_length = len(reference_tracklet.edges()) # maps overlap track ID to the number of edges of the current reference tracklet # that overlap overlapping_id_to_count: dict[int, int] = defaultdict(lambda: 0) for ref_src, ref_tgt in reference_tracklet.edges(): if (ref_src, ref_tgt) in reference_skips: # if this is a skip edge, there is some equivalent path in the overlaps # let's find an edge on that path and update the count for node in overlap_path_to_reference_skip_map: path_info = overlap_path_to_reference_skip_map[node] found = False for i, skip_edge in enumerate(path_info["skip_edge"]): if skip_edge == (ref_src, ref_tgt): edge_in_path = path_info["edge_in_path"][i] overlapping_id_to_count[overlap_edge_to_tid[edge_in_path]] += 1 found = True break if found: break continue # this edge is part of an equivalent path for an overlap skip edge # we need to find that skip edge and update its count by 1 if ( ref_src in reference_path_to_overlap_skip_map and ref_tgt in reference_path_to_overlap_skip_map ): # both nodes are in the path, but one of them might be part of multiple skip # edges. 
We therefore find the specific edge that both ref_src and ref_tgt are # part of skip_info = reference_path_to_overlap_skip_map[ref_src] edge_in_path = skip_info["edge_in_path"] for i, edge in enumerate(edge_in_path): if edge[0] == ref_src and edge[1] == ref_tgt: equivalent_skip_edge = skip_info["skip_edge"][i] overlapping_id_to_count[overlap_edge_to_tid[equivalent_skip_edge]] += 1 break overlap_src = overlap_reference_mapping.get(ref_src, []) overlap_tgt = overlap_reference_mapping.get(ref_tgt, []) # any edge that has both nodes in an overlap tracklet # could be overlapping for src, tgt in product(overlap_src, overlap_tgt): if (src, tgt) in overlap_edge_to_tid: overlapping_id_to_count[overlap_edge_to_tid[(src, tgt)]] += 1 total_count += tracklet_length tracklet_overlap = max(overlapping_id_to_count.values(), default=0) max_overlap += tracklet_overlap if tracklet_length: track_fractions.append(tracklet_overlap / tracklet_length) weighted_average = max_overlap / total_count if total_count > 0 else np.nan unweighted_average = np.mean(track_fractions) if track_fractions else np.nan return weighted_average, unweighted_average def _get_relevant_skip_edges( graph: TrackingGraph, include_division_edges: bool ) -> set[tuple[Any, Any]]: """Get relevant skip edges from the graph, potentially including division edges. Args: graph (TrackingGraph): graph to extract skip edges from include_division_edges (bool): True if parent-daughter edges should be included, otherwise False. Returns: set[tuple[Any, Any]]: skip edges on graph with/without division edges. 
""" skips = graph.get_skip_edges() if include_division_edges: return skips # if division edges are not included, we only consider skips that are not division edges for skip_src, skip_tgt in skips.copy(): if graph.graph.out_degree(skip_src) > 1: # type: ignore skips.remove((skip_src, skip_tgt)) return skips def _get_skip_path_maps( matched: Matched, skips: set[tuple[Any, Any]], skip_to_other_map: dict[Any, list[Any]], ) -> tuple[dict[tuple[Any, Any], int], dict[Any, dict[str, list[tuple[Any, Any]]]]]: """Get information about equivalent paths for skip edges. For each skip edge, find the equivalent path in the matched graph and return a mapping of skip edges to their equivalent path lengths. Also returns a mapping from nodes along equivalent paths to the edge they are part of and the skip edge they cover. Args: matched (traccuracy.matchers.Matched): The matched object containing the graphs. skips (set[tuple[Any, Any]]): Set of skip edges to process. skip_to_other_map (dict[Any, list[Any]]): Mapping of nodes in graph with skips to nodes in the other graph. Returns: tuple[dict[tuple[Any, Any], int], dict[Any, dict[str, Any]]]: A tuple containing: - A dictionary mapping skip edges to their equivalent path lengths. - A dictionary mapping nodes in equivalent paths to the edge they are part of and the skip edge they cover. 
""" skip_to_equivalent_path_length = {} path_node_to_skip_map: dict[Any, dict[str, list[tuple[Any, Any]]]] = defaultdict( lambda: defaultdict(list) ) for skip_src, skip_tgt in skips: matched_src = skip_to_other_map.get(skip_src, []) matched_tgt = skip_to_other_map.get(skip_tgt, []) for possible_src, possible_tgt in product(matched_src, matched_tgt): equivalent_path = get_equivalent_skip_edge( matched, skip_src, skip_tgt, possible_src, possible_tgt ) if equivalent_path: for edge_src, edge_tgt in pairwise(equivalent_path): path_node_to_skip_map[edge_src]["edge_in_path"].append((edge_src, edge_tgt)) path_node_to_skip_map[edge_src]["skip_edge"].append((skip_src, skip_tgt)) path_node_to_skip_map[edge_tgt]["edge_in_path"].append((edge_src, edge_tgt)) path_node_to_skip_map[edge_tgt]["skip_edge"].append((skip_src, skip_tgt)) skip_to_equivalent_path_length[(skip_src, skip_tgt)] = len(equivalent_path) - 1 return skip_to_equivalent_path_length, path_node_to_skip_map