Source code for traccuracy.loaders._point

from typing import Optional, cast

import networkx as nx
import numpy as np
import pandas as pd

from traccuracy._tracking_graph import TrackingGraph


def load_point_data(
    path: str | None = None,
    df: Optional[pd.DataFrame] = None,
    parent_column: str = "parent",
    id_column: str = "node_id",
    pos_columns: tuple[str, ...] = ("z", "y", "x"),
    time_column: str = "t",
    seg_id_column: str | None = None,
    name: str | None = None,
    sep: str | None = None,
) -> TrackingGraph:
    """Load point-based tracking data into a TrackingGraph from a csv-like file

    Assumes each row contains:
        - time
        - position, e.g. three columns 'z', 'y', 'x'
        - parent, a reference to the node in the previous time frame. A node
          without a parent can be indicated by -1; a missing (NaN) parent value
          is also treated as "no parent".

    Args:
        path (str | None, optional): Path to the csv-like file to load. When a
            non-empty path is given it takes precedence over ``df``.
            Defaults to None.
        df (pd.DataFrame | None, optional): A dataframe that has already been
            loaded. Only used when ``path`` is not provided. Defaults to None.
        parent_column (str, optional): Column holding a reference to the parent
            node in the previous time frame. Defaults to "parent".
        id_column (str, optional): Column used to specify node ids. Node IDs
            must be unique non-negative integers. Defaults to 'node_id'.
        pos_columns (tuple[str, ...], optional): A tuple of columns to use for
            position. Defaults to ("z", "y", "x").
        time_column (str, optional): The column to use for time.
            Defaults to "t".
        seg_id_column (str | None, optional): Name of an optional column
            containing a segmentation label id. Defaults to None.
        name (str | None, optional): Optional string to name/describe the
            dataset. Defaults to None.
        sep (str | None, optional): Passed to pd.read_csv to set the sep kwarg.
            Only forwarded when not None. Defaults to None.

    Raises:
        ValueError: Must provide either a path or a dataframe
        ValueError: parent_column not present in data
        ValueError: id_column not present in data
        ValueError: id_column does not contain non-negative integers
        ValueError: id_column does not contain unique values
        ValueError: pos_columns not present in data
        ValueError: time_column not present in data
        ValueError: seg_id_column not present in data

    Returns:
        TrackingGraph: A graph with one node per row (keyed by ``id_column``,
            carrying the time, position and optional segmentation-id columns as
            attributes) and one edge parent -> child per row that has a parent.
    """
    # Use the same truthiness test as the loading branch below, so an empty
    # path without a dataframe is reported here instead of failing later on
    # ``None.columns``.
    if not path and df is None:
        raise ValueError("Must provide either a path or a dataframe")

    if path:
        # Only forward ``sep`` when explicitly set: passing ``sep=None`` to
        # pandas switches it to delimiter sniffing with the python engine.
        read_kwargs = {} if sep is None else {"sep": sep}
        df = pd.read_csv(path, **read_kwargs)

    # At this point, df is guaranteed to be dataframe not None
    df = cast("pd.DataFrame", df)

    if parent_column not in df.columns:
        raise ValueError(f"Specified parent_column {parent_column} not present")
    if id_column not in df.columns:
        raise ValueError(f"Specified id_column {id_column} not present")
    # NOTE: zero is accepted here, i.e. ids must be non-negative integers.
    if not pd.api.types.is_integer_dtype(df[id_column]) or not np.all(df[id_column] >= 0):
        raise ValueError(f"Specified id_column {id_column} must contain positive integers.")
    if not df[id_column].is_unique:
        raise ValueError(f"Specified id_column {id_column} must contain unique values.")
    if not all(c in df.columns for c in pos_columns):
        raise ValueError(f"Specified pos_columns {pos_columns} not present")
    if time_column not in df.columns:
        raise ValueError(f"Specified time_column {time_column} not present")
    if seg_id_column and seg_id_column not in df.columns:
        raise ValueError(f"Specified seg_id_column {seg_id_column} not present")

    node_attr_cols = [time_column, *pos_columns]
    if seg_id_column:
        node_attr_cols.append(seg_id_column)

    # Extract whole columns instead of iterating with ``iterrows``: the row
    # Series produced by ``iterrows`` is coerced to a single common dtype, which
    # turned integer ids and frames into floats whenever a position column was
    # float. Column-wise extraction keeps each column's own dtype.
    node_ids = df[id_column].tolist()
    node_attrs = df[node_attr_cols].to_dict(orient="records")
    parents = df[parent_column].tolist()

    G: nx.DiGraph = nx.DiGraph()
    G.add_nodes_from(zip(node_ids, node_attrs))
    # -1 (or a missing value) marks a node without a parent.
    G.add_edges_from(
        (parent, node_id)
        for parent, node_id in zip(parents, node_ids)
        if not pd.isna(parent) and parent != -1
    )

    if seg_id_column:
        return TrackingGraph(
            G,
            frame_key=time_column,
            location_keys=pos_columns,
            label_key=seg_id_column,
            name=name,
        )
    return TrackingGraph(G, frame_key=time_column, location_keys=pos_columns, name=name)