Source code for traccuracy.metrics._ctc

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

import numpy as np

from traccuracy._tracking_graph import EdgeFlag, NodeFlag
from traccuracy.matchers._base import Matched
from traccuracy.track_errors._ctc import VALID_MATCHING_TYPES, evaluate_ctc_events

from ._base import Metric

if TYPE_CHECKING:
    from traccuracy.matchers._matched import Matched


[docs] class AOGMMetrics(Metric): """Computes the Acyclic Oriented Graph Measure (AOGM), along with the error counts The AOGM metric is a generalized graph measure that allows users to define their own error weights for each type of node and edge error. The AOGM is simply the weighted sum of all errors. These metrics are written assuming that the ground truth annotations are dense. If that is not the case, interpret the numbers carefully. Consider eliminating metrics that use the number of false positives. Args: vertex_ns_weight (float): Weight for vertex/node non-split errors. Defaults to 1 vertex_fp_weight (float): Weight for false positive vertex/node errors. Defaults to 1 vertex_fn_weight (float): Weight for false negative vertex/node errors. Defaults to 1 edge_fp_weight (float): Weight for false positive edge errors. Defaults to 1 edge_fn_weight (float): Weight for false negative edge errors. Defaults to 1 edge_ws_weight (float): Weight for wrong semantic edge errors. Defaults to 1 """ def __init__( self, vertex_ns_weight: float = 1, vertex_fp_weight: float = 1, vertex_fn_weight: float = 1, edge_fp_weight: float = 1, edge_fn_weight: float = 1, edge_ws_weight: float = 1, ) -> None: super().__init__(VALID_MATCHING_TYPES) self.v_weights = { "ns": vertex_ns_weight, "fp": vertex_fp_weight, "fn": vertex_fn_weight, } self.e_weights = { "fp": edge_fp_weight, "fn": edge_fn_weight, "ws": edge_ws_weight, } def _compute( self, data: Matched, relax_skips_gt: bool = False, relax_skips_pred: bool = False ) -> dict[str, float]: if relax_skips_gt or relax_skips_pred: warnings.warn( "CTC metrics do not support relaxing skip edges. " "Ignoring relax_skips_gt and relax_skips_pred.", stacklevel=2, ) evaluate_ctc_events(data) vertex_error_counts: dict[str, float] = { "ns": len(data.pred_graph.get_nodes_with_flag(NodeFlag.NON_SPLIT)), "fp": len(data.pred_graph.get_nodes_with_flag(NodeFlag.CTC_FALSE_POS)), "fn": len(data.gt_graph.get_nodes_with_flag(NodeFlag.CTC_FALSE_NEG)), } edge_error_counts: dict[str, float] = { "ws": len(data.pred_graph.get_edges_with_flag(EdgeFlag.WRONG_SEMANTIC)), "fp": len(data.pred_graph.get_edges_with_flag(EdgeFlag.CTC_FALSE_POS)), "fn": len(data.gt_graph.get_edges_with_flag(EdgeFlag.CTC_FALSE_NEG)), } error_sum = get_weighted_error_sum( vertex_error_counts, edge_error_counts, self.v_weights["ns"], self.v_weights["fp"], self.v_weights["fn"], self.e_weights["fp"], self.e_weights["fn"], self.e_weights["ws"], ) return { "AOGM": error_sum, "fp_nodes": vertex_error_counts["fp"], "fn_nodes": vertex_error_counts["fn"], "ns_nodes": vertex_error_counts["ns"], "fp_edges": edge_error_counts["fp"], "fn_edges": edge_error_counts["fn"], "ws_edges": edge_error_counts["ws"], }
[docs] class CTCMetrics(AOGMMetrics): """Computes the original Cell Tracking Challenging metrics: TRA, DET, LNK. These metrics are based on the more general AOGM metric. - DET: Assesses detection performance - LNK: Assesses linking performance by measuring only edge errors - TRA: Assesses both detection and tracking performance These metrics are written assuming that the ground truth annotations are dense. If that is not the case, interpret the numbers carefully. Consider eliminating metrics that use the number of false positives. """ def __init__(self) -> None: vertex_weight_ns = 5 vertex_weight_fn = 10 vertex_weight_fp = 1 edge_weight_fp = 1 edge_weight_fn = 1.5 edge_weight_ws = 1 super().__init__( vertex_ns_weight=vertex_weight_ns, vertex_fp_weight=vertex_weight_fp, vertex_fn_weight=vertex_weight_fn, edge_fp_weight=edge_weight_fp, edge_fn_weight=edge_weight_fn, edge_ws_weight=edge_weight_ws, ) def _compute( self, data: Matched, relax_skips_gt: bool = False, relax_skips_pred: bool = False ) -> dict[str, float]: if relax_skips_gt or relax_skips_pred: warnings.warn( "CTC metrics do not support relaxing skip edges. " "Ignoring relax_skips_gt and relax_skips_pred.", stacklevel=2, ) errors = super()._compute(data) gt_graph = data.gt_graph.graph n_nodes = gt_graph.number_of_nodes() n_edges = gt_graph.number_of_edges() tra = self._get_tra(errors, n_nodes, n_edges) errors["TRA"] = tra det = self._get_det(errors, n_nodes) errors["DET"] = det lnk = self._get_lnk(errors, n_edges) errors["LNK"] = lnk return errors def _get_tra(self, errors: dict[str, float], n_nodes: int, n_edges: int) -> float: """Get the TRA score from the error counts and total number of gt nodes and edges Args: errors (dict[str, int]): A dictionary containing the AOGM n_nodes (int): The number of nodes in the ground truth graph n_edges (int): The number of edges in the ground truth graph Returns: float: the TRA score, computed with the CTC metric weights, or np.nan if the AOGM_0 is 0 """ aogm_0 = n_nodes * self.v_weights["fn"] + n_edges * self.e_weights["fn"] if aogm_0 == 0: warnings.warn( UserWarning( f"AOGM0 is 0 - cannot compute TRA from GT graph with {n_nodes} nodes and" + f" {n_edges} edges with {self.v_weights['fn']} vertex FN weight and" + f" {self.e_weights['fn']} edge FN weight" ), stacklevel=1, ) return np.nan aogm = errors["AOGM"] tra = 1 - min(aogm, aogm_0) / aogm_0 return tra def _get_det(self, errors: dict[str, float], n_nodes: int) -> float: """Get the DET score from the error counts and total number of gt nodes Args: errors (dict[str, int]): A dictionary containing the counts of each type of node error (fp_nodes, fn_nodes, ns_nodes) n_nodes (int): The number of nodes in the ground truth graph Returns: float: the DET score, computed with the CTC metric weights, or np.nan if there are no nodes in the gt graph """ if n_nodes == 0: warnings.warn( UserWarning("No nodes in the GT graph, cannot compute DET."), stacklevel=1, ) return np.nan aogmd_0 = n_nodes * self.v_weights["fn"] aogmd = get_weighted_vertex_error_sum( { "ns": errors["ns_nodes"], "fp": errors["fp_nodes"], "fn": errors["fn_nodes"], }, self.v_weights["ns"], self.v_weights["fp"], self.v_weights["fn"], ) det = 1 - min(aogmd, aogmd_0) / aogmd_0 return det def _get_lnk(self, errors: dict[str, float], n_edges: int) -> float: """Get the DET score from the error counts and total number of gt edges Args: errors (dict[str, int]): A dictionary containing the counts of each type of edge error (fp_edges, fn_edges, ws_edges) n_edges (int): The number of edges in the ground truth graph Returns: float: the TRA score, computed with the CTC metric weights, or np.nan if there are no edges in the GT graph """ if n_edges == 0: warnings.warn( UserWarning("No edges in the GT graph, cannot compute LNK."), stacklevel=1, ) return np.nan aogma_0 = n_edges * self.e_weights["fn"] aogma = get_weighted_edge_error_sum( { "fp": errors["fp_edges"], "fn": errors["fn_edges"], "ws": errors["ws_edges"], }, self.e_weights["fp"], self.e_weights["fn"], self.e_weights["ws"], ) lnk = 1 - min(aogma, aogma_0) / aogma_0 return lnk
[docs] def get_weighted_vertex_error_sum( vertex_error_counts: dict[str, float], vertex_ns_weight: float = 1, vertex_fp_weight: float = 1, vertex_fn_weight: float = 1, ) -> float: vertex_ns_count = vertex_error_counts["ns"] vertex_fp_count = vertex_error_counts["fp"] vertex_fn_count = vertex_error_counts["fn"] vertex_error_sum = ( vertex_ns_weight * vertex_ns_count + vertex_fp_weight * vertex_fp_count + vertex_fn_weight * vertex_fn_count ) return vertex_error_sum
[docs] def get_weighted_edge_error_sum( edge_error_counts: dict[str, float], edge_fp_weight: float = 1, edge_fn_weight: float = 1, edge_ws_weight: float = 1, ) -> float: edge_fp_count = edge_error_counts["fp"] edge_fn_count = edge_error_counts["fn"] edge_ws_count = edge_error_counts["ws"] edge_error_sum = ( edge_fp_weight * edge_fp_count + edge_fn_weight * edge_fn_count + edge_ws_weight * edge_ws_count ) return edge_error_sum
[docs] def get_weighted_error_sum( vertex_error_counts: dict[str, float], edge_error_counts: dict[str, float], vertex_ns_weight: float = 1, vertex_fp_weight: float = 1, vertex_fn_weight: float = 1, edge_fp_weight: float = 1, edge_fn_weight: float = 1, edge_ws_weight: float = 1, ) -> float: vertex_error_sum = get_weighted_vertex_error_sum( vertex_error_counts, vertex_ns_weight, vertex_fp_weight, vertex_fn_weight ) edge_error_sum = get_weighted_edge_error_sum( edge_error_counts, edge_fp_weight, edge_fn_weight, edge_ws_weight ) return vertex_error_sum + edge_error_sum