Source code for pytcl.containers.track_list

"""
Track list container.

This module provides a collection class for managing multiple tracks
with filtering, querying, and batch operations.
"""

from __future__ import annotations

from typing import (
    TYPE_CHECKING,
    Callable,
    Iterable,
    Iterator,
    List,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)

import numpy as np
from numpy.typing import ArrayLike, NDArray

from pytcl.trackers.multi_target import Track, TrackStatus

if TYPE_CHECKING:
    from pytcl.trackers.multi_target import MultiTargetTracker


[docs] class TrackQuery(NamedTuple): """ Result of a track list query. Attributes ---------- tracks : List[Track] List of tracks matching the query. indices : List[int] Original indices of the matching tracks. """ tracks: List[Track] indices: List[int]
[docs] class TrackListStats(NamedTuple): """ Statistics about a track list. Attributes ---------- n_tracks : int Total number of tracks. n_confirmed : int Number of confirmed tracks. n_tentative : int Number of tentative tracks. n_deleted : int Number of deleted tracks. mean_hits : float Average number of hits per track. mean_misses : float Average number of misses per track. """ n_tracks: int n_confirmed: int n_tentative: int n_deleted: int mean_hits: float mean_misses: float
[docs] class TrackList: """ Collection of tracks with query and filter capabilities. Provides: - Filtering by status, time, region - Iteration and indexing - Batch state extraction - Statistics computation Parameters ---------- tracks : Iterable[Track], optional Initial tracks to add to the list. Examples -------- >>> import numpy as np >>> from pytcl.trackers.multi_target import Track, TrackStatus >>> # Create some tracks >>> t1 = Track(id=0, state=np.array([1, 0, 2, 0]), ... covariance=np.eye(4), status=TrackStatus.CONFIRMED, ... hits=5, misses=0, time=1.0) >>> t2 = Track(id=1, state=np.array([5, 0, 6, 0]), ... covariance=np.eye(4), status=TrackStatus.TENTATIVE, ... hits=2, misses=1, time=1.0) >>> # Create track list >>> tracks = TrackList([t1, t2]) >>> len(tracks) 2 >>> # Filter by status >>> confirmed = tracks.confirmed >>> len(confirmed) 1 >>> # Get positions >>> positions = tracks.positions(indices=(0, 2)) >>> positions.shape (2, 2) """
[docs] def __init__(self, tracks: Optional[Iterable[Track]] = None) -> None: """Initialize track list.""" if tracks is None: self._tracks: List[Track] = [] else: self._tracks = list(tracks) # Build ID lookup for fast access self._id_to_idx: dict[int, int] = {t.id: i for i, t in enumerate(self._tracks)}
[docs] @classmethod def from_tracker(cls, tracker: MultiTargetTracker) -> TrackList: """ Create a TrackList from a MultiTargetTracker. Parameters ---------- tracker : MultiTargetTracker The tracker to extract tracks from. Returns ------- TrackList A new TrackList containing the tracker's current tracks. """ return cls(tracker.tracks)
[docs] def __len__(self) -> int: """Return number of tracks.""" return len(self._tracks)
[docs] def __iter__(self) -> Iterator[Track]: """Iterate over tracks.""" return iter(self._tracks)
[docs] def __getitem__(self, idx: Union[int, slice]) -> Union[Track, TrackList]: """ Get track by index or slice. Parameters ---------- idx : int or slice Index or slice to retrieve. Returns ------- Track or TrackList Single track if int, TrackList if slice. """ if isinstance(idx, int): return self._tracks[idx] else: return TrackList(self._tracks[idx])
[docs] def __contains__(self, track_id: int) -> bool: """Check if track ID exists in list.""" return track_id in self._id_to_idx
[docs] def __repr__(self) -> str: """String representation.""" return f"TrackList(n_tracks={len(self)})"
[docs] def get_by_id(self, track_id: int) -> Optional[Track]: """ Get track by ID. Parameters ---------- track_id : int Track ID to find. Returns ------- Track or None The track if found, None otherwise. """ idx = self._id_to_idx.get(track_id) if idx is not None: return self._tracks[idx] return None
[docs] def get_by_ids(self, track_ids: Iterable[int]) -> TrackList: """ Get multiple tracks by their IDs. Parameters ---------- track_ids : Iterable[int] Track IDs to find. Returns ------- TrackList New TrackList containing the found tracks. """ tracks = [] for tid in track_ids: track = self.get_by_id(tid) if track is not None: tracks.append(track) return TrackList(tracks)
[docs] def filter_by_status(self, status: TrackStatus) -> TrackList: """ Filter tracks by status. Parameters ---------- status : TrackStatus Status to filter by. Returns ------- TrackList New TrackList containing only tracks with the given status. """ return TrackList([t for t in self._tracks if t.status == status])
[docs] def filter_by_time( self, min_time: Optional[float] = None, max_time: Optional[float] = None, ) -> TrackList: """ Filter tracks by time range. Parameters ---------- min_time : float, optional Minimum time (inclusive). max_time : float, optional Maximum time (inclusive). Returns ------- TrackList New TrackList containing tracks within the time range. """ tracks = [] for t in self._tracks: if min_time is not None and t.time < min_time: continue if max_time is not None and t.time > max_time: continue tracks.append(t) return TrackList(tracks)
[docs] def filter_by_region( self, center: ArrayLike, radius: float, state_indices: Tuple[int, int] = (0, 2), ) -> TrackList: """ Filter tracks by spatial region. Parameters ---------- center : array_like Center point [x, y]. radius : float Radius of the region. state_indices : tuple of int, optional Indices of x and y in state vector (default: (0, 2)). Returns ------- TrackList New TrackList containing tracks within the region. """ center = np.asarray(center, dtype=np.float64) ix, iy = state_indices tracks = [] for t in self._tracks: pos = np.array([t.state[ix], t.state[iy]]) dist = np.linalg.norm(pos - center) if dist <= radius: tracks.append(t) return TrackList(tracks)
[docs] def filter_by_predicate(self, predicate: Callable[[Track], bool]) -> TrackList: """ Filter tracks using a custom predicate. Parameters ---------- predicate : callable Function that takes a Track and returns True to include it. Returns ------- TrackList New TrackList containing tracks that pass the predicate. """ return TrackList([t for t in self._tracks if predicate(t)])
@property def confirmed(self) -> TrackList: """Get confirmed tracks only.""" return self.filter_by_status(TrackStatus.CONFIRMED) @property def tentative(self) -> TrackList: """Get tentative tracks only.""" return self.filter_by_status(TrackStatus.TENTATIVE) @property def track_ids(self) -> List[int]: """Get list of all track IDs.""" return [t.id for t in self._tracks]
[docs] def states(self) -> NDArray[np.float64]: """ Extract all track states as array. Returns ------- ndarray Array of shape (n_tracks, state_dim) containing all states. """ if len(self._tracks) == 0: return np.zeros((0, 0)) return np.array([t.state for t in self._tracks])
[docs] def covariances(self) -> NDArray[np.float64]: """ Extract all track covariances as array. Returns ------- ndarray Array of shape (n_tracks, state_dim, state_dim). """ if len(self._tracks) == 0: return np.zeros((0, 0, 0)) return np.array([t.covariance for t in self._tracks])
[docs] def positions(self, indices: Tuple[int, int] = (0, 2)) -> NDArray[np.float64]: """ Extract track positions. Parameters ---------- indices : tuple of int, optional Indices of x and y in state vector (default: (0, 2)). Returns ------- ndarray Array of shape (n_tracks, 2) containing positions. """ if len(self._tracks) == 0: return np.zeros((0, 2)) ix, iy = indices return np.array([[t.state[ix], t.state[iy]] for t in self._tracks])
[docs] def stats(self) -> TrackListStats: """ Compute statistics about the track list. Returns ------- TrackListStats Statistics about the tracks. """ n_tracks = len(self._tracks) if n_tracks == 0: return TrackListStats( n_tracks=0, n_confirmed=0, n_tentative=0, n_deleted=0, mean_hits=0.0, mean_misses=0.0, ) n_confirmed = sum(1 for t in self._tracks if t.status == TrackStatus.CONFIRMED) n_tentative = sum(1 for t in self._tracks if t.status == TrackStatus.TENTATIVE) n_deleted = sum(1 for t in self._tracks if t.status == TrackStatus.DELETED) mean_hits = sum(t.hits for t in self._tracks) / n_tracks mean_misses = sum(t.misses for t in self._tracks) / n_tracks return TrackListStats( n_tracks=n_tracks, n_confirmed=n_confirmed, n_tentative=n_tentative, n_deleted=n_deleted, mean_hits=mean_hits, mean_misses=mean_misses, )
[docs] def add(self, track: Track) -> TrackList: """ Add a track and return a new TrackList. Parameters ---------- track : Track Track to add. Returns ------- TrackList New TrackList with the track added. """ return TrackList(self._tracks + [track])
[docs] def remove(self, track_id: int) -> TrackList: """ Remove a track by ID and return a new TrackList. Parameters ---------- track_id : int ID of track to remove. Returns ------- TrackList New TrackList without the specified track. """ return TrackList([t for t in self._tracks if t.id != track_id])
[docs] def merge(self, other: TrackList) -> TrackList: """ Merge with another TrackList. Parameters ---------- other : TrackList TrackList to merge with. Returns ------- TrackList New TrackList containing tracks from both lists. Notes ----- Duplicate track IDs are not handled specially - both tracks will be included. Use get_by_id on the result to access specific tracks. """ return TrackList(list(self._tracks) + list(other._tracks))
[docs] def copy(self) -> TrackList: """ Create a copy of this TrackList. Returns ------- TrackList A new TrackList with the same tracks. """ return TrackList(self._tracks)
__all__ = [ "TrackList", "TrackQuery", "TrackListStats", ]