# Source code for traccuracy.matchers._point_seg

from __future__ import annotations

from typing import TYPE_CHECKING, cast

from ._base import Matcher

if TYPE_CHECKING:
    from collections.abc import Hashable
    from typing import Any

    import numpy as np

    from traccuracy._tracking_graph import TrackingGraph


class PointSegMatcher(Matcher):
    """Match point detections to segmentation labels.

    A point is matched to a segmentation node when its location falls inside
    that node's label in the segmentation array. Exactly one of the two graphs
    (ground truth or prediction) must carry a segmentation array; the other
    one supplies the points. Several points may map to the same segmentation
    label.
    """

    def _compute_mapping(
        self, gt_graph: TrackingGraph, pred_graph: TrackingGraph
    ) -> list[tuple[Any, Any]]:
        """Build the list of matched ``(gt_node, pred_node)`` pairs.

        Args:
            gt_graph (TrackingGraph): Ground truth graph.
            pred_graph (TrackingGraph): Predicted graph.

        Returns:
            list[tuple[Any, Any]]: Matched node pairs, always ordered
            ``(gt_node, pred_node)`` regardless of which graph holds the
            segmentation. Empty if the segmentation graph has no start or
            end frame.

        Raises:
            ValueError: If both graphs have a segmentation, or neither does.
        """
        gt_has_seg = gt_graph.segmentation is not None
        pred_has_seg = pred_graph.segmentation is not None

        if gt_has_seg and pred_has_seg:
            raise ValueError(
                "Both datasets have segmentations. "
                "Please provide only one dataset with segmentations."
            )
        if not (gt_has_seg or pred_has_seg):
            raise ValueError("Data provided does not contain segmentations.")

        # s_graph carries the segmentation, p_graph carries the points.
        if gt_has_seg:
            s_graph, p_graph = gt_graph, pred_graph
        else:
            s_graph, p_graph = pred_graph, gt_graph

        # Narrow the optional types for the type checker. The results are kept
        # in locals: cast() does nothing at runtime, so there is no reason to
        # write them back onto a graph object owned by the caller.
        segmentation = cast("np.ndarray", s_graph.segmentation)
        label_key = cast("str", s_graph.label_key)

        mapping: list[tuple[Any, Any]] = []
        if s_graph.start_frame is None or s_graph.end_frame is None:
            return mapping

        # range() treats end_frame as exclusive.
        for frame in range(s_graph.start_frame, s_graph.end_frame):
            # Points of this frame and the label value each one lands on.
            p_nodes = list(p_graph.nodes_by_frame.get(frame, []))
            p_locations = [p_graph.get_location(node) for node in p_nodes]
            frame_map = match_point_to_seg(p_nodes, p_locations, segmentation[frame])

            # Lookup from label value to segmentation node id. ``.get`` mirrors
            # the point-graph lookup above so that a frame without any
            # segmentation nodes neither raises nor adds an entry to the graph.
            seg_to_snode = {
                s_graph.nodes[node][label_key]: node
                for node in s_graph.nodes_by_frame.get(frame, ())
            }

            # NOTE(review): a point landing on a label that has no node in
            # s_graph still raises KeyError here -- confirm that is intended.
            for p_node, seg_id in frame_map.items():
                s_node = seg_to_snode[seg_id]
                # Emit pairs in gt -> pred order.
                if gt_has_seg:
                    mapping.append((s_node, p_node))
                else:
                    mapping.append((p_node, s_node))

        return mapping
[docs] def match_point_to_seg( node_ids: list[Hashable], locs: list[list[float] | np.ndarray | tuple[float]], seg: np.ndarray ) -> dict[Hashable, int]: """For a single timepoint, identify the segmentation ids which a set of points index into Args: node_ids (list[Hashable]): A list of node ids locs (list[list[float]]): A list of locations corresponding to the list of node ids seg (np.ndarray): A 2D segmentation array Returns: dict[Hashable, int]: A dictionary mapping from node_id to segmentation value """ node_to_segid = {} for node, loc in zip(node_ids, locs, strict=False): # Check if loc is inside of segmentation after casting to int for indexing seg_val = seg[tuple([int(x) for x in loc])] if seg_val != 0: node_to_segid[node] = seg_val return node_to_segid