# Source code for traccuracy.metrics._cca

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

import networkx as nx
import numpy as np
from scipy.sparse import coo_array
from skan.csr import PathGraph, summarize

from traccuracy.metrics._base import Metric

if TYPE_CHECKING:
    from traccuracy._tracking_graph import TrackingGraph
    from traccuracy.matchers._base import Matched


class CellCycleAccuracy(Metric):
    """Cell Cycle Accuracy (CCA) metric.

    CCA captures how well a method reproduces the *distribution* of complete
    cell cycle lengths present in the ground truth. Only the two length
    distributions are compared, so no matching between the solution and the
    ground truth is needed. Scores lie in [0, 1], with higher values being
    better; the score is NaN (and a warning is emitted) when either graph
    contains no complete cell cycle.

    This metric is one of the biologically inspired metrics introduced by the
    CTC and defined in Ulman 2017.
    """

    def __init__(self) -> None:
        # The matching is never consulted, so every matching type is accepted.
        super().__init__(["one-to-one", "many-to-one", "one-to-many", "many-to-many"])

    def _compute(
        self, data: Matched, relax_skips_gt: bool = False, relax_skips_pred: bool = False
    ) -> dict[str, float]:
        """Compute CCA from the ground-truth and predicted graphs in ``data``.

        Args:
            data (Matched): container of the two tracking graphs; only
                ``gt_graph`` and ``pred_graph`` are read, the matching is unused.
            relax_skips_gt (bool): unused by this metric.
            relax_skips_pred (bool): unused by this metric.

        Returns:
            dict[str, float]: a single-entry mapping ``{"CCA": score}``.
        """
        # Ground-truth lengths are evaluated before predicted lengths
        # (left-to-right argument evaluation).
        return {
            "CCA": _get_cca(
                _get_lengths(data.gt_graph),
                _get_lengths(data.pred_graph),
            )
        }
def _get_lengths(track_graph: TrackingGraph) -> np.ndarray:
    """Identify the lengths of complete cell cycles in a tracking graph.

    A complete cell cycle is a skan branch of type 2 (junction to junction,
    i.e. division to division). Every edge is weighted by the absolute frame
    difference of its endpoints, so lengths are measured in frames and
    gap-closing edges contribute their full frame span.

    Args:
        track_graph (TrackingGraph): The graph to evaluate

    Returns:
        np.ndarray[np.uint32]: an array of complete cell cycle lengths. It is
        empty when the graph has no edges or no division-to-division branch.
    """
    # Can't create a sparse graph from disconnected nodes. Return an integer
    # array so the dtype matches the non-empty case (np.bincount rejects floats).
    if track_graph.graph.number_of_edges() == 0:
        return np.array([], dtype=np.uint32)

    # One row per node, (frame, *location), in networkx node-iteration order.
    # nx.to_scipy_sparse_array uses G.nodes() order when no nodelist is given,
    # so adjacency indices line up with the rows of coords_array.
    coords_array = np.asarray(
        [
            [node_info[track_graph.frame_key], *[node_info[k] for k in track_graph.location_keys]]  # type: ignore
            for _, node_info in track_graph.graph.nodes(data=True)
        ],
        dtype=np.float64,
    )
    sparse_graph = nx.to_scipy_sparse_array(track_graph.graph, dtype=np.float64, format="coo")  # type: ignore

    # Build a sparse array with the frame span of each edge as its weight.
    # This ensures gap-closing edges have the right "length". Then symmetrize
    # the directed tracking graph into an undirected adjacency for skan.
    i, j = sparse_graph.coords
    t = coords_array[:, 0]
    frame_span = np.abs(t[i] - t[j])
    weighted_sparse_graph = coo_array((frame_span, (i, j)), shape=sparse_graph.shape).tocsr()
    csr_graph = weighted_sparse_graph + weighted_sparse_graph.T

    skan_graph = PathGraph.from_graph(adj=csr_graph, node_coordinates=coords_array)
    summary = summarize(skan_graph, separator="_")

    # branch_type 2 is junction to junction, i.e. division to division.
    division_to_division = summary[summary.branch_type == 2]
    # Frame spans are integer-valued. Round before casting so that float
    # noise (e.g. 2.9999999) is not truncated down by astype.
    cycle_lengths = np.rint(division_to_division.branch_distance.values).astype(np.uint32)
    return cycle_lengths


def _get_cca(gt_lengths: np.ndarray, pred_lengths: np.ndarray) -> float:
    """Compute CCA given two arrays of cell cycle lengths.

    CCA is one minus the largest absolute difference between the cumulative
    normalized histograms of the two length distributions. The histograms use
    integer bins 0..max(length), shared by both inputs.

    Args:
        gt_lengths (np.ndarray[int]): cell cycle lengths from the ground truth data
        pred_lengths (np.ndarray[int]): cell cycle lengths from the predicted data

    Returns:
        float: the cell cycle accuracy in [0, 1], or NaN (after emitting a
        warning) when either input contains no complete cell cycle.
    """
    # GT and pred must both contain complete cell cycles to compute this metric.
    if np.sum(gt_lengths) == 0 or np.sum(pred_lengths) == 0:
        warnings.warn(
            "GT and pred data do not both contain complete cell cycles. Returning CCA = nan",
            stacklevel=2,
        )
        return np.nan

    # One bin per integer length, up to the largest length seen in either input.
    # Use a Python int so the +1 cannot overflow an unsigned array dtype.
    n_bins = int(np.max([np.max(gt_lengths), np.max(pred_lengths)])) + 1

    gt_cumsum = _get_cumsum(gt_lengths, n_bins)
    pred_cumsum = _get_cumsum(pred_lengths, n_bins)
    return float(1 - np.max(np.abs(gt_cumsum - pred_cumsum)))


def _get_cumsum(lengths: np.ndarray, n_bins: int) -> np.ndarray:
    """Compute the cumulative sum of a normalized histogram of cell cycle lengths.

    Args:
        lengths (np.ndarray[int]): a non-empty array of non-negative integer cell
            cycle lengths. An empty array normalizes by zero and yields NaNs.
        n_bins (int): minimum number of histogram bins. It must exceed the
            largest length for the result to have exactly ``n_bins`` entries.

    Returns:
        np.ndarray: the cumulative sum of the normalized histogram (its last
        entry is 1).
    """
    # Count occurrences of each integer length.
    hist = np.bincount(lengths, minlength=n_bins)
    # Normalize to a probability mass function.
    hist = hist / hist.sum()
    # The cumulative sum gives the empirical distribution function.
    return np.cumsum(hist)