Source code for traccuracy.loaders._geff

import os
from typing import TYPE_CHECKING, cast

import numpy as np
import zarr
from geff import GeffMetadata, read

from traccuracy._tracking_graph import TrackingGraph

if TYPE_CHECKING:
    from collections.abc import Hashable

    import networkx as nx


def _spatiotemporal_props(meta: GeffMetadata) -> tuple[str, list[str]]:
    """Extract the time and space axis names declared in geff metadata.

    Args:
        meta (GeffMetadata): Metadata read from the geff group.

    Returns:
        tuple[str, list[str]]: The name of the time axis and the names of the
            space axes, in the order they appear in ``meta.axes``.

    Raises:
        ValueError: If the metadata declares no axes, no axis of type ``"time"``,
            or no axis of type ``"space"``.
    """
    if meta.axes is None:
        raise ValueError("No spatial or temporal axes were found in the input geff")

    temporal_prop: str | None = None
    spatial_props: list[str] = []
    for ax in meta.axes:
        if ax.type == "time":
            # If several time axes are declared, the last one is used
            temporal_prop = ax.name
        elif ax.type == "space":
            spatial_props.append(ax.name)

    if temporal_prop is None:
        raise ValueError("A required time property was not found in the axes of the input geff")
    if len(spatial_props) == 0:
        raise ValueError("Required spatial axes were not found in the axes of the input geff")
    return temporal_prop, spatial_props


def _load_related_segmentation(meta: GeffMetadata, geff_path: str) -> tuple[np.ndarray, str]:
    """Load the segmentation referenced by a "labels" related object in the geff.

    Args:
        meta (GeffMetadata): Metadata read from the geff group.
        geff_path (str): Path to the geff group; the related object's ``path``
            is joined onto it.

    Returns:
        tuple[np.ndarray, str]: The full segmentation array (read into memory)
            and the name of the node property holding each node's label.

    Raises:
        ValueError: If the metadata has no related objects, none of type
            ``"labels"``, or the labels object does not declare a ``label_prop``.
    """
    if meta.related_objects is None:
        raise ValueError("Did not find related_objects in geff")

    labels = [obj for obj in meta.related_objects if obj.type == "labels"]
    if not labels:
        raise ValueError('Did not find related_object of type "labels" in geff related objects')
    # When several "labels" objects are declared, the last one is used
    rel_obj = labels[-1]

    # Without a label property there is no way to link segmentation labels to nodes
    if rel_obj.label_prop is None:
        raise ValueError(
            f'Related object of type "labels" at {rel_obj.path!r} does not specify a '
            "label_prop, so segmentations cannot be linked to nodes on the graph"
        )

    rel_obj_path = os.path.join(geff_path, rel_obj.path)
    # Open read-only so that loading can never create or modify data on disk
    segmentation = np.asarray(zarr.open_array(rel_obj_path, mode="r")[:])
    return segmentation, rel_obj.label_prop


def load_geff_data(
    geff_path: str,
    load_geff_seg: bool = False,
    seg_path: str | None = None,
    seg_property: str | None = None,
    name: str | None = None,
    load_all_props: bool = False,
    border_margin: float | None = None,
) -> TrackingGraph:
    """Load a graph into memory from a geff file.

    Segmentations can optionally be loaded in one of two ways. Either pass
    `load_geff_seg=True` to use a related object specified in the geff. Or
    pass a path to a zarr array with `seg_path` together with `seg_property`.
    Segmentation arrays are opened read-only and read fully into memory.

    If loading graphs with flags, e.g. for visualization, pass `load_all_props=True`.

    Args:
        geff_path (str): Path to a geff group inside of a zarr.
        load_geff_seg (bool, optional): Load segmentation based on the geff
            metadata of a related object of type "labels". Defaults to False.
        seg_path (str | None, optional): Path to a zarr array containing
            segmentation data. We assume that the axes order in your segmentation
            array matches the axes in your geff. If this is not true please load
            the segmentation yourself and add it to TrackingGraph.segmentation.
            Defaults to None.
        seg_property (str | None, optional): If seg_path is provided, this is the
            corresponding property on the geff graph that contains the
            segmentation key. Defaults to None.
        name (str | None, optional): Optional name to store on TrackingGraph for
            identification. Defaults to None.
        load_all_props (bool, optional): If True, load all node and edge
            properties on the graph. Defaults to False, in which case only
            spatiotemporal and segmentation node properties are loaded. Set to
            True to get already annotated error flags, e.g. for visualization.
        border_margin (float | None, optional): If set, nodes whose centroid is
            within this distance (in pixels) of the spatial border will be
            excluded from the graph. Requires segmentation to be loaded.
            Defaults to None (no filtering).

    Returns:
        TrackingGraph: The loaded graph, with segmentation and label key set if
            a segmentation was requested.

    Raises:
        ValueError: If both `load_geff_seg` and `seg_path` are given, if
            `seg_path` is given without `seg_property`, if the graph is
            undirected, if required time/space axes are missing, if a requested
            related segmentation cannot be found or has no label property, or
            if the segmentation dimensionality does not match the axes.
    """
    if load_geff_seg and seg_path is not None:
        raise ValueError('Please specify either load_geff_seg=True or seg_path="path/to/seg.zarr"')
    if seg_path is not None and seg_property is None:
        raise ValueError(
            "If seg_path is specified, a corresponding seg_property must be specified to link "
            "segmentations to a segmentation label property on the graph"
        )

    meta = GeffMetadata.read(geff_path)
    if meta.directed is False:
        raise ValueError(
            f"traccuracy only supports directed graphs. Found undirected graph at {geff_path}"
        )

    # Collect names of axes so that by default we only load spatiotemporal properties
    temporal_prop, spatial_props = _spatiotemporal_props(meta)
    load_props = [*spatial_props, temporal_prop]

    segmentation: np.ndarray | None = None
    label_key: str | None = None
    if load_geff_seg:
        # Segmentation from a related object declared in the geff metadata
        segmentation, label_key = _load_related_segmentation(meta, geff_path)
        load_props.append(label_key)
    elif seg_path is not None:
        # Segmentation from a stand-alone zarr array; open read-only so that
        # loading can never create or modify data on disk
        segmentation = np.asarray(zarr.open_array(seg_path, mode="r")[:])
        label_key = seg_property
        load_props.append(seg_property)  # type: ignore  # validated non-None above

    # Segmentation must have one time axis plus one axis per spatial property
    if segmentation is not None and len(segmentation.shape) != 1 + len(spatial_props):
        raise ValueError(
            f"Expected dimensionality of segmentation data {1 + len(spatial_props)}D "
            f"does not match shape {segmentation.shape}"
        )

    if load_all_props:
        G, _ = read(geff_path, backend="networkx")
    else:
        G, _ = read(geff_path, node_props=load_props, edge_props=[], backend="networkx")
    # We checked earlier that the graph is directed
    G = cast("nx.DiGraph[Hashable]", G)

    return TrackingGraph(
        graph=G,
        segmentation=segmentation,
        label_key=label_key,
        frame_key=temporal_prop,
        location_keys=tuple(spatial_props),
        name=name,
        border_margin=border_margin,
    )