# Source code for pytcl.trackers.multi_target

"""
Multi-target tracker implementation.

This module provides a multi-target tracker using GNN data association
and Kalman filtering with track management (initiation, maintenance, deletion).
"""

from enum import Enum
from typing import Callable, List, NamedTuple, Optional

import numpy as np
from numpy.typing import ArrayLike, NDArray

from pytcl.assignment_algorithms import chi2_gate_threshold, gnn_association


class TrackStatus(Enum):
    """Track status enumeration.

    Each member's value is its lowercase name as a string.
    """

    # Newly initiated track that has not been confirmed yet.
    TENTATIVE = "tentative"
    # Track that has passed the tracker's confirmation test.
    CONFIRMED = "confirmed"
    # Track marked for removal from the tracker.
    DELETED = "deleted"
class Track(NamedTuple):
    """
    Read-only snapshot of a multi-target track.

    Attributes
    ----------
    id : int
        Unique track identifier.
    state : ndarray
        State estimate vector.
    covariance : ndarray
        State covariance matrix.
    status : TrackStatus
        Track status.
    hits : int
        Number of measurements associated with the track, including the
        measurement that initiated it.
    misses : int
        Number of consecutive missed detections.
    time : float
        Tracker time at which the track was last initiated, predicted,
        or updated.
    """

    # Identity
    id: int
    # Kinematic estimate
    state: NDArray[np.float64]
    covariance: NDArray[np.float64]
    # Lifecycle bookkeeping
    status: TrackStatus
    hits: int
    misses: int
    time: float
class MultiTargetTracker:
    """
    Multi-target tracker with GNN data association.

    This tracker maintains multiple tracks and handles:
    - Track initiation from unassociated measurements
    - Track update via GNN data association
    - Track confirmation once a track has accumulated ``confirm_hits`` hits
    - Track deletion after ``max_misses`` consecutive missed detections

    Parameters
    ----------
    state_dim : int
        Dimension of state vector.
    meas_dim : int
        Dimension of measurement vector.
    F : callable or ndarray
        State transition matrix or function F(dt) -> ndarray.
    H : ndarray
        Measurement matrix.
    Q : callable or ndarray
        Process noise covariance or function Q(dt) -> ndarray.
    R : ndarray
        Measurement noise covariance.
    gate_probability : float, optional
        Gate probability for association (default: 0.99).
    confirm_hits : int, optional
        Hits needed to confirm track (default: 3). The measurement that
        initiates a track counts as its first hit.
    confirm_window : int, optional
        Intended window N for M-of-N confirmation (default: 5).
        NOTE: the value is stored as ``self.confirm_window`` but the
        confirmation logic in this class does not currently read it;
        confirmation only requires ``confirm_hits`` total hits.
    max_misses : int, optional
        Consecutive misses before deletion (default: 5).
    init_covariance : ndarray, optional
        Initial covariance for new tracks. If None, ``100 * I`` of size
        ``state_dim`` x ``state_dim`` is used.

    Examples
    --------
    >>> import numpy as np
    >>> # Constant velocity model
    >>> F = lambda dt: np.array([[1, dt, 0, 0],
    ...                          [0, 1, 0, 0],
    ...                          [0, 0, 1, dt],
    ...                          [0, 0, 0, 1]])
    >>> H = np.array([[1, 0, 0, 0],
    ...               [0, 0, 1, 0]])
    >>> Q = lambda dt: 0.1 * np.eye(4)
    >>> R = np.eye(2) * 0.5
    >>> tracker = MultiTargetTracker(4, 2, F, H, Q, R)
    >>> # Process measurements
    >>> measurements = [np.array([1, 2]), np.array([5, 6])]
    >>> tracks = tracker.process(measurements, dt=1.0)
    """

    def __init__(
        self,
        state_dim: int,
        meas_dim: int,
        F: Callable[[float], NDArray[np.float64]] | NDArray[np.float64],
        H: NDArray[np.float64],
        Q: Callable[[float], NDArray[np.float64]] | NDArray[np.float64],
        R: NDArray[np.float64],
        gate_probability: float = 0.99,
        confirm_hits: int = 3,
        confirm_window: int = 5,
        max_misses: int = 5,
        init_covariance: Optional[NDArray[np.float64]] = None,
    ) -> None:
        """Store the motion/measurement models and tracker settings."""
        self.state_dim = state_dim
        self.meas_dim = meas_dim

        # Constant models are converted to float arrays once, so list
        # inputs work with ``@`` and every call returns the same array.
        if callable(F):
            self._F = F
        else:
            F_const = np.asarray(F, dtype=np.float64)
            self._F = lambda dt: F_const
        self.H = np.asarray(H, dtype=np.float64)
        if callable(Q):
            self._Q = Q
        else:
            Q_const = np.asarray(Q, dtype=np.float64)
            self._Q = lambda dt: Q_const
        self.R = np.asarray(R, dtype=np.float64)

        # Gate applied to the squared Mahalanobis distances built in
        # _associate (presumably the chi-square quantile of gate_probability
        # with meas_dim degrees of freedom -- see chi2_gate_threshold).
        self.gate_threshold = chi2_gate_threshold(gate_probability, meas_dim)
        self.confirm_hits = confirm_hits
        # NOTE(review): not read by the confirmation logic below.
        self.confirm_window = confirm_window
        self.max_misses = max_misses

        if init_covariance is not None:
            self.init_covariance = np.asarray(init_covariance, dtype=np.float64)
        else:
            # Default: large, isotropic uncertainty in every state component.
            self.init_covariance = np.eye(state_dim) * 100.0

        # Track storage
        self._tracks: List[_InternalTrack] = []
        self._next_id: int = 0
        self._time: float = 0.0

    @property
    def tracks(self) -> List[Track]:
        """Get list of active tracks."""
        return [t.to_track() for t in self._tracks if t.status != TrackStatus.DELETED]

    @property
    def confirmed_tracks(self) -> List[Track]:
        """Get list of confirmed tracks only."""
        return [t.to_track() for t in self._tracks if t.status == TrackStatus.CONFIRMED]

    def process(
        self,
        measurements: List[ArrayLike],
        dt: float,
    ) -> List[Track]:
        """
        Process measurements at new time step.

        Parameters
        ----------
        measurements : list of array_like
            List of measurement vectors. Each entry is flattened and must
            contain exactly ``meas_dim`` values.
        dt : float
            Time step since last update.

        Returns
        -------
        list of Track
            Active tracks after update.

        Raises
        ------
        ValueError
            If a measurement does not contain ``meas_dim`` values. The
            tracker state is left untouched in that case.
        """
        # Validate/normalize input before mutating any tracker state, so a
        # bad measurement does not leave time advanced and tracks predicted.
        Z = self._stack_measurements(measurements)
        n_meas = Z.shape[0]

        self._time += dt

        # Predict all tracks
        self._predict_all(dt)

        # Data association
        if self._tracks and n_meas > 0:
            associations = self._associate(Z)
        else:
            associations = {}

        # Update associated tracks
        associated_meas = set()
        for track_idx, meas_idx in associations.items():
            self._update_track(track_idx, Z[meas_idx])
            associated_meas.add(meas_idx)

        # Handle missed tracks: count consecutive misses and delete tracks
        # that reach max_misses (applies to tentative and confirmed alike).
        for i, track in enumerate(self._tracks):
            if i not in associations and track.status != TrackStatus.DELETED:
                track.misses += 1
                if track.misses >= self.max_misses:
                    track.status = TrackStatus.DELETED

        # Initiate new tracks from unassociated measurements. This runs after
        # miss handling so new tracks are not charged a miss on creation.
        for j in range(n_meas):
            if j not in associated_meas:
                self._initiate_track(Z[j])

        # Remove deleted tracks
        self._tracks = [t for t in self._tracks if t.status != TrackStatus.DELETED]

        return self.tracks

    def _stack_measurements(self, measurements: List[ArrayLike]) -> NDArray[np.float64]:
        """Stack measurements into an ``(n, meas_dim)`` float64 array.

        Raises
        ------
        ValueError
            If any measurement does not contain ``meas_dim`` values.
        """
        Z = np.empty((len(measurements), self.meas_dim), dtype=np.float64)
        for j, m in enumerate(measurements):
            # Flatten so scalars (meas_dim == 1) and column vectors work too.
            z = np.asarray(m, dtype=np.float64).reshape(-1)
            if z.size != self.meas_dim:
                raise ValueError(
                    f"measurement {j} has {z.size} values; expected {self.meas_dim}"
                )
            Z[j] = z
        return Z

    def _predict_all(self, dt: float) -> None:
        """Propagate every non-deleted track forward by ``dt``."""
        F = np.asarray(self._F(dt), dtype=np.float64)
        Q = np.asarray(self._Q(dt), dtype=np.float64)
        for track in self._tracks:
            if track.status != TrackStatus.DELETED:
                track.state = F @ track.state
                track.covariance = F @ track.covariance @ F.T + Q
                track.time = self._time

    def _associate(self, Z: NDArray[np.float64]) -> dict[int, int]:
        """
        Associate measurements to tracks using GNN.

        The cost of pairing track i with measurement j is the squared
        Mahalanobis distance of the innovation; deleted tracks keep an
        infinite cost.

        Returns dict mapping track_idx -> meas_idx.
        """
        n_tracks = len(self._tracks)
        n_meas = Z.shape[0]

        # Build cost matrix
        cost_matrix = np.full((n_tracks, n_meas), np.inf)
        for i, track in enumerate(self._tracks):
            if track.status == TrackStatus.DELETED:
                continue
            z_pred = self.H @ track.state
            S = self.H @ track.covariance @ self.H.T + self.R
            innovations = Z - z_pred  # (n_meas, meas_dim)
            # d2_j = nu_j^T S^{-1} nu_j for every j, via one linear solve
            # instead of forming S^{-1} explicitly.
            weighted = np.linalg.solve(S, innovations.T)  # (meas_dim, n_meas)
            cost_matrix[i, :] = np.einsum("ij,ji->i", innovations, weighted)

        # Run GNN
        result = gnn_association(
            cost_matrix,
            gate_threshold=self.gate_threshold,
            cost_of_non_assignment=self.gate_threshold,
        )

        # Build association dict; negative indices mean "not assigned".
        associations: dict[int, int] = {}
        for i in range(n_tracks):
            meas_idx = result.track_to_measurement[i]
            if meas_idx >= 0:
                associations[i] = int(meas_idx)
        return associations

    def _update_track(self, track_idx: int, measurement: NDArray[np.float64]) -> None:
        """Kalman-update one track with its associated measurement."""
        track = self._tracks[track_idx]

        # Innovation
        z_pred = self.H @ track.state
        innovation = measurement - z_pred
        S = self.H @ track.covariance @ self.H.T + self.R

        # Kalman gain K = P H^T S^{-1}, obtained by solving S^T K^T = (P H^T)^T
        PHt = track.covariance @ self.H.T
        K = np.linalg.solve(S.T, PHt.T).T

        # Update. Joseph form: algebraically equal to (I - K H) P for this
        # gain, but better preserves symmetry/positive semi-definiteness
        # under floating-point rounding.
        track.state = track.state + K @ innovation
        I_KH = np.eye(self.state_dim) - K @ self.H
        track.covariance = I_KH @ track.covariance @ I_KH.T + K @ self.R @ K.T

        # Update counts
        track.hits += 1
        track.misses = 0

        # Check confirmation (total hits, including the initiating one)
        if track.status == TrackStatus.TENTATIVE:
            if track.hits >= self.confirm_hits:
                track.status = TrackStatus.CONFIRMED

    def _initiate_track(self, measurement: NDArray[np.float64]) -> None:
        """Initiate a new tentative track from an unassociated measurement."""
        # Initialize state from measurement.
        # Use pseudoinverse of H to map measurement to state; components not
        # observed by H (e.g. velocities for a position-only H) start at the
        # minimum-norm value.
        H_pinv = np.linalg.pinv(self.H)
        state = H_pinv @ measurement

        # Create track; the initiating measurement counts as the first hit.
        track = _InternalTrack(
            id=self._next_id,
            state=state,
            covariance=self.init_covariance.copy(),
            status=TrackStatus.TENTATIVE,
            hits=1,
            misses=0,
            time=self._time,
        )
        self._tracks.append(track)
        self._next_id += 1
class _InternalTrack:
    """Mutable, tracker-private counterpart of :class:`Track`.

    The tracker reassigns these attributes as it predicts and updates;
    :meth:`to_track` produces the public, immutable snapshot.
    """

    def __init__(
        self,
        id: int,
        state: NDArray[np.float64],
        covariance: NDArray[np.float64],
        status: TrackStatus,
        hits: int,
        misses: int,
        time: float,
    ) -> None:
        # Identity and kinematic estimate.
        self.id = id
        self.state = state
        self.covariance = covariance
        # Lifecycle bookkeeping.
        self.status = status
        self.hits = hits
        self.misses = misses
        self.time = time

    def to_track(self) -> Track:
        """Return a :class:`Track` snapshot of this track.

        The state and covariance arrays are copied, so the snapshot does not
        share memory with this object.
        """
        return Track(
            self.id,
            self.state.copy(),
            self.covariance.copy(),
            self.status,
            self.hits,
            self.misses,
            self.time,
        )


__all__ = ["MultiTargetTracker", "Track", "TrackStatus"]