# Source code for traccuracy.track_errors._ctc

from __future__ import annotations

import logging
import warnings
from typing import TYPE_CHECKING

from tqdm import tqdm

from traccuracy._tracking_graph import EdgeFlag, NodeFlag

if TYPE_CHECKING:
    from traccuracy.matchers._matched import Matched

logger = logging.getLogger(__name__)


def evaluate_ctc_events(matched_data: Matched) -> None:
    """Annotate the ground truth and predicted graphs with node and edge error types.

    Node errors are annotated first, then edge errors. Annotations are made
    in place on the graphs held by ``matched_data``; nothing is returned.

    Parameters
    ----------
    matched_data : traccuracy.matchers.Matched
        Matched data object containing gt and pred graphs with their
        associated mapping

    Raises
    ------
    ValueError
        Propagated from ``get_vertex_errors`` / ``get_edge_errors`` when only
        one of the two graphs already carries the corresponding annotations.
    """
    get_vertex_errors(matched_data)
    get_edge_errors(matched_data)


def get_vertex_errors(matched_data: Matched) -> None:
    """Assign a CTC error class to each comp/gt node.

    Every predicted node starts out flagged as a false positive and every
    ground truth node as a false negative. The flags are then corrected while
    walking the pred -> gt mapping:

    - a predicted node mapped to exactly one gt node: both nodes become true
      positives
    - a predicted node mapped to several gt nodes: the predicted node is
      flagged ``NON_SPLIT`` and the gt nodes lose their false negative flag

    Annotations are made in place and ``ctc_node_errors`` is set to ``True``
    on both graphs once annotation is complete.

    Parameters
    ----------
    matched_data : traccuracy.matchers.Matched
        Matched data object containing gt and pred graphs with their
        associated mapping

    Raises
    ------
    ValueError
        If exactly one of the two graphs already has node errors annotated.

    Warns
    -----
    UserWarning
        If both graphs already have node errors annotated; the graphs are
        left untouched in that case.
    """
    comp_graph = matched_data.pred_graph
    gt_graph = matched_data.gt_graph
    dict_mapping = matched_data.pred_gt_map

    comp_annotated = bool(comp_graph.ctc_node_errors)
    gt_annotated = bool(gt_graph.ctc_node_errors)

    # Annotations on only one of the graphs means an inconsistent state
    if comp_annotated != gt_annotated:
        graph_with_errors = "pred graph" if comp_annotated else "GT graph"
        raise ValueError(
            f"Only {graph_with_errors} has node errors annotated. "
            "Please ensure either both or neither "
            "of the graphs have traccuracy annotations before running metrics."
        )

    if comp_annotated and gt_annotated:
        warnings.warn("Node errors already calculated. Skipping graph annotation", stacklevel=2)
        return

    # will flip this when we come across the vertex in the mapping
    comp_graph.set_flag_on_all_nodes(NodeFlag.CTC_FALSE_POS, True)
    gt_graph.set_flag_on_all_nodes(NodeFlag.CTC_FALSE_NEG, True)

    for pred_id, gt_ids in tqdm(dict_mapping.items(), desc="Evaluating nodes"):
        if len(gt_ids) == 1:
            # one-to-one match: true positive on both sides
            gid = gt_ids[0]
            comp_graph.set_flag_on_node(pred_id, NodeFlag.CTC_TRUE_POS, True)
            comp_graph.remove_flag_from_node(pred_id, NodeFlag.CTC_FALSE_POS)
            gt_graph.remove_flag_from_node(gid, NodeFlag.CTC_FALSE_NEG)
            gt_graph.set_flag_on_node(gid, NodeFlag.CTC_TRUE_POS, True)
        elif len(gt_ids) > 1:
            # one predicted node covers several gt nodes: non-split vertex
            comp_graph.set_flag_on_node(pred_id, NodeFlag.NON_SPLIT, True)
            comp_graph.remove_flag_from_node(pred_id, NodeFlag.CTC_FALSE_POS)
            for gt_id in gt_ids:
                gt_graph.remove_flag_from_node(gt_id, NodeFlag.CTC_FALSE_NEG)

    # Record presence of annotations on the TrackingGraph
    comp_graph.ctc_node_errors = True
    gt_graph.ctc_node_errors = True


def get_edge_errors(matched_data: Matched) -> None:
    """Assign CTC edge error flags to the edges of the comp/gt graphs.

    Works on the subgraph of the predicted graph induced by its true positive
    nodes and annotates, in place:

    - ``INTERTRACK_EDGE`` on both graphs, for edges leaving a division or
      entering a merge
    - ``CTC_FALSE_POS`` on predicted edges of the induced subgraph that have
      no counterpart in the gt graph
    - ``WRONG_SEMANTIC`` on predicted edges whose gt counterpart exists but
      disagrees on being an intertrack edge
    - ``CTC_FALSE_NEG`` on gt edges that touch a false negative node, have an
      endpoint without a match in the induced subgraph, or whose counterpart
      is missing from the induced subgraph

    ``ctc_edge_errors`` is set to ``True`` on both graphs once annotation is
    complete. Node errors are annotated first if they are not already present
    on both graphs.

    Parameters
    ----------
    matched_data : traccuracy.matchers.Matched
        Matched data object containing gt and pred graphs with their
        associated mapping

    Raises
    ------
    ValueError
        If exactly one of the two graphs already has edge errors annotated,
        or (raised by ``get_vertex_errors``) if exactly one of them has node
        errors annotated.

    Warns
    -----
    UserWarning
        If both graphs already have edge errors annotated; the graphs are
        left untouched in that case.
    """
    comp_graph = matched_data.pred_graph
    gt_graph = matched_data.gt_graph
    node_mapping = matched_data.mapping

    comp_annotated = bool(comp_graph.ctc_edge_errors)
    gt_annotated = bool(gt_graph.ctc_edge_errors)

    # Annotations on only one of the graphs means an inconsistent state
    if comp_annotated != gt_annotated:
        graph_with_errors = "pred graph" if comp_annotated else "GT graph"
        raise ValueError(
            f"Only {graph_with_errors} has edge errors annotated. "
            "Please ensure either both or neither "
            "of the graphs have traccuracy annotations before running metrics."
        )

    if comp_annotated and gt_annotated:
        warnings.warn("Edge errors already calculated. Skipping graph annotation", stacklevel=2)
        return

    # Node errors must already be annotated on BOTH graphs. If only one graph
    # is annotated, get_vertex_errors raises a ValueError instead of letting
    # edge evaluation continue on inconsistent node flags.
    if not (comp_graph.ctc_node_errors and gt_graph.ctc_node_errors):
        logger.info("Node errors have not been annotated. Running node annotation.", stacklevel=2)
        get_vertex_errors(matched_data)

    comp_tp_nodes = comp_graph.get_nodes_with_flag(NodeFlag.CTC_TRUE_POS)
    induced_graph = comp_graph.graph.subgraph(comp_tp_nodes)

    # Both lookup directions, restricted to matches whose predicted node is
    # part of the induced subgraph
    gt_comp_mapping = {}
    comp_gt_mapping = {}
    for gt, comp in node_mapping:
        if comp in induced_graph:
            gt_comp_mapping[gt] = comp
            comp_gt_mapping[comp] = gt

    # intertrack edges = connection between parent and daughter
    for graph in [comp_graph, gt_graph]:
        for parent in graph.get_divisions():
            for daughter in graph.graph.successors(parent):
                graph.set_flag_on_edge((parent, daughter), EdgeFlag.INTERTRACK_EDGE, True)

        for merge in graph.get_merges():
            for parent in graph.graph.predecessors(merge):
                graph.set_flag_on_edge((parent, merge), EdgeFlag.INTERTRACK_EDGE, True)

    # fp edges - edges in induced_graph that aren't in gt_graph
    for edge in tqdm(induced_graph.edges, desc="Evaluating FP edges"):
        source, target = edge[0], edge[1]
        expected_gt_edge = (comp_gt_mapping[source], comp_gt_mapping[target])

        if expected_gt_edge not in gt_graph.edges:
            comp_graph.set_flag_on_edge(edge, EdgeFlag.CTC_FALSE_POS, True)
        else:
            # check if semantics are correct
            is_parent_gt = gt_graph.edges[expected_gt_edge].get(EdgeFlag.INTERTRACK_EDGE, False)
            is_parent_comp = comp_graph.edges[edge].get(EdgeFlag.INTERTRACK_EDGE, False)
            if is_parent_gt != is_parent_comp:
                comp_graph.set_flag_on_edge(edge, EdgeFlag.WRONG_SEMANTIC, True)

    # fn edges - edges in gt_graph that aren't in induced graph
    for edge in tqdm(gt_graph.edges, desc="Evaluating FN edges"):
        source, target = edge[0], edge[1]

        # this edge is adjacent to a node we didn't detect, so it definitely is an fn
        if gt_graph.nodes[source].get(NodeFlag.CTC_FALSE_NEG, False) or gt_graph.nodes[
            target
        ].get(NodeFlag.CTC_FALSE_NEG, False):
            gt_graph.set_flag_on_edge(edge, EdgeFlag.CTC_FALSE_NEG, True)
            continue

        source_comp_id = gt_comp_mapping.get(source)
        target_comp_id = gt_comp_mapping.get(target)

        if source_comp_id is None or target_comp_id is None:
            # at least one endpoint has no match inside the induced subgraph
            gt_graph.set_flag_on_edge(edge, EdgeFlag.CTC_FALSE_NEG, True)
        elif (source_comp_id, target_comp_id) not in induced_graph.edges:
            gt_graph.set_flag_on_edge(edge, EdgeFlag.CTC_FALSE_NEG, True)

    # Record presence of annotations on the TrackingGraph
    gt_graph.ctc_edge_errors = True
    comp_graph.ctc_edge_errors = True