Multi-Target Tracking

This example demonstrates GNN-based multi-target tracking with track management.

Overview

Multi-target tracking (MTT) addresses:

  • Data association: Matching measurements to tracks

  • Track initiation: Detecting new targets

  • Track maintenance: Updating confirmed tracks

  • Track termination: Removing lost targets

Key Concepts

  • Global Nearest Neighbor (GNN): Optimal measurement-to-track assignment

  • Gating: Reducing assignment candidates using statistical tests

  • Track scoring: M/N logic and likelihood-based confirmation

  • Clutter modeling: False alarm rate estimation

Data Association: The Hungarian algorithm finds the optimal measurement-to-track assignment by minimizing total cost.

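Performance Metrics: OSPA (Optimal Sub-Pattern Assignment) measures tracking accuracy by combining localization error and cardinality error into a single distance, controlled by a cut-off distance c and an order p. Standard OSPA ignores track labels, so labeling (track-identity) errors require an extension such as OSPA-T.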

Code Highlights

The example demonstrates:

  • Track initialization from unassigned measurements

  • GNN assignment using Hungarian algorithm

  • Kalman filter updates for each track

  • Track state machine (tentative, confirmed, deleted)

  • OSPA metric computation for performance evaluation

Source Code

  1"""
  2Multi-target tracking example.
  3
  4This example demonstrates:
  51. Simulating multiple crossing targets
  62. Using the MultiTargetTracker for GNN-based tracking
  73. Track initiation, confirmation, and deletion
  8
  9Run with: python examples/multi_target_tracking.py
 10"""
 11
 12# Add parent directory to path for development
 13import sys
 14from pathlib import Path
 15
 16sys.path.insert(0, str(Path(__file__).parent.parent))
 17
 18# Output directory for generated plots
 19OUTPUT_DIR = Path(__file__).parent.parent / "docs" / "_static" / "images" / "examples"
 20OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
 21
 22from typing import List, Tuple  # noqa: E402
 23
 24import numpy as np  # noqa: E402
 25import plotly.graph_objects as go  # noqa: E402
 26
 27from pytcl.trackers import (  # noqa: E402
 28    MultiTargetTracker,
 29    TrackStatus,
 30)
 31
 32
 33def simulate_targets(
 34    n_steps: int = 50,
 35    dt: float = 1.0,
 36) -> Tuple[List[np.ndarray], List[List[np.ndarray]]]:
 37    """
 38    Simulate two crossing targets with position measurements.
 39
 40    Returns
 41    -------
 42    true_states : list of ndarray
 43        Ground truth states [x1, y1, x2, y2] at each step.
 44    measurements : list of list of ndarray
 45        Noisy position measurements at each step.
 46    """
 47    # Target 1: Moving right and up
 48    x1_0, y1_0 = 0.0, 0.0
 49    vx1, vy1 = 2.0, 1.0
 50
 51    # Target 2: Moving left and up
 52    x2_0, y2_0 = 100.0, 0.0
 53    vx2, vy2 = -2.0, 1.5
 54
 55    true_states = []
 56    measurements = []
 57    R = np.eye(2) * 2.0  # Measurement noise covariance
 58
 59    for k in range(n_steps):
 60        t = k * dt
 61
 62        # True positions
 63        x1 = x1_0 + vx1 * t
 64        y1 = y1_0 + vy1 * t
 65        x2 = x2_0 + vx2 * t
 66        y2 = y2_0 + vy2 * t
 67
 68        true_states.append(np.array([x1, y1, x2, y2]))
 69
 70        # Generate noisy measurements
 71        meas = []
 72
 73        # Detection probability
 74        pd = 0.95
 75
 76        if np.random.rand() < pd:
 77            z1 = np.array([x1, y1]) + np.random.multivariate_normal([0, 0], R)
 78            meas.append(z1)
 79
 80        if np.random.rand() < pd:
 81            z2 = np.array([x2, y2]) + np.random.multivariate_normal([0, 0], R)
 82            meas.append(z2)
 83
 84        # Add occasional false alarms
 85        if np.random.rand() < 0.1:
 86            # Random false alarm in scene
 87            fa = np.array([np.random.uniform(-10, 110), np.random.uniform(-10, 60)])
 88            meas.append(fa)
 89
 90        measurements.append(meas)
 91
 92    return true_states, measurements
 93
 94
 95def run_tracker(
 96    measurements: List[List[np.ndarray]],
 97    dt: float = 1.0,
 98) -> List[List]:
 99    """
100    Run multi-target tracker on measurements.
101
102    Returns list of track histories at each step.
103    """
104
105    # Constant velocity model: state = [x, vx, y, vy]
106    def F(dt):
107        return np.array(
108            [[1, dt, 0, 0], [0, 1, 0, 0], [0, 0, 1, dt], [0, 0, 0, 1]], dtype=np.float64
109        )
110
111    # Measurement model: measure x and y
112    H = np.array([[1, 0, 0, 0], [0, 0, 1, 0]], dtype=np.float64)
113
114    # Process noise (acceleration noise)
115    def Q(dt):
116        q = 0.5  # Acceleration noise std
117        return (
118            np.array(
119                [
120                    [dt**4 / 4, dt**3 / 2, 0, 0],
121                    [dt**3 / 2, dt**2, 0, 0],
122                    [0, 0, dt**4 / 4, dt**3 / 2],
123                    [0, 0, dt**3 / 2, dt**2],
124                ]
125            )
126            * q**2
127        )
128
129    # Measurement noise
130    R = np.eye(2) * 2.0
131
132    # Initial covariance for new tracks
133    P0 = np.diag([10.0, 5.0, 10.0, 5.0])
134
135    # Create tracker
136    tracker = MultiTargetTracker(
137        state_dim=4,
138        meas_dim=2,
139        F=F,
140        H=H,
141        Q=Q,
142        R=R,
143        gate_probability=0.99,
144        confirm_hits=3,
145        max_misses=5,
146        init_covariance=P0,
147    )
148
149    # Process all measurements
150    track_history = []
151
152    for meas in measurements:
153        tracks = tracker.process(meas, dt)
154        track_history.append(tracks)
155
156    return track_history
157
158
def plot_results(
    true_states: List[np.ndarray],
    measurements: List[List[np.ndarray]],
    track_history: List[List],
    show: bool = True,
) -> None:
    """
    Plot ground truth, measurements, and confirmed track estimates.

    The figure is written as an interactive HTML file to
    ``OUTPUT_DIR / "multi_target_tracking_result.html"`` and, when ``show``
    is true, also displayed with ``fig.show()``.

    Parameters
    ----------
    true_states : list of ndarray
        Ground truth [x1, y1, x2, y2] at each step.
    measurements : list of list of ndarray
        Position measurements ``[x, y]`` at each step.
    track_history : list of list
        Tracks reported at each step. Each track must expose ``id``,
        ``status``, and ``state`` laid out as [x, vx, y, vy].
    show : bool, optional
        If False, only write the HTML file (useful for headless runs).
    """
    fig = go.Figure()

    # Ground-truth trajectories. The reshape keeps column indexing valid even
    # when no steps were simulated.
    true_arr = np.array(true_states).reshape(-1, 4)
    fig.add_trace(
        go.Scatter(
            x=true_arr[:, 0],
            y=true_arr[:, 1],
            mode="lines",
            line=dict(color="green", width=2),
            name="Target 1 (truth)",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=true_arr[:, 2],
            y=true_arr[:, 3],
            mode="lines",
            line=dict(color="blue", width=2),
            name="Target 2 (truth)",
        )
    )

    # All measurements (detections and false alarms) as a single trace.
    meas_x = [z[0] for meas in measurements for z in meas]
    meas_y = [z[1] for meas in measurements for z in meas]
    fig.add_trace(
        go.Scatter(
            x=meas_x,
            y=meas_y,
            mode="markers",
            marker=dict(color="black", size=3, opacity=0.5),
            name="Measurements",
        )
    )

    # Group confirmed-track (x, y) estimates by track ID, in step order.
    track_positions = {}
    for tracks in track_history:
        for track in tracks:
            if track.status == TrackStatus.CONFIRMED:
                track_positions.setdefault(track.id, []).append(
                    (track.state[0], track.state[2])
                )

    # Plotly color palette (similar to matplotlib's tab10)
    colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]

    # Only tracks with at least two confirmed points can be drawn as a line.
    # Colors are assigned over drawn tracks so the palette is not skipped.
    drawable = [
        (track_id, positions)
        for track_id, positions in track_positions.items()
        if len(positions) > 1
    ]
    for i, (track_id, positions) in enumerate(drawable):
        color = colors[i % len(colors)]
        pos_arr = np.array(positions)
        fig.add_trace(
            go.Scatter(
                x=pos_arr[:, 0],
                y=pos_arr[:, 1],
                mode="lines+markers",
                line=dict(color=color, width=1.5),
                marker=dict(color=color, size=4),
                name=f"Track {track_id}",
            )
        )

    fig.update_layout(
        title="Multi-Target Tracking with GNN Data Association",
        xaxis_title="X Position",
        yaxis_title="Y Position",
        xaxis=dict(scaleanchor="y", scaleratio=1),
        width=1200,
        height=800,
        showlegend=True,
    )

    # Save the interactive HTML version of the figure.
    output_path = OUTPUT_DIR / "multi_target_tracking_result.html"
    fig.write_html(str(output_path))
    print(f"Interactive plot saved to {output_path}")
    if show:
        fig.show()
262
263
264def main():
265    """Run multi-target tracking example."""
266    print("Multi-Target Tracking Example")
267    print("=" * 50)
268
269    # Set random seed for reproducibility
270    np.random.seed(42)
271
272    # Simulate targets
273    print("Simulating two crossing targets...")
274    true_states, measurements = simulate_targets(n_steps=50, dt=1.0)
275    print(f"  Generated {len(true_states)} time steps")
276    print(f"  Total measurements: {sum(len(m) for m in measurements)}")
277
278    # Run tracker
279    print("\nRunning multi-target tracker...")
280    track_history = run_tracker(measurements, dt=1.0)
281
282    # Count tracks
283    all_tracks = set()
284    confirmed_tracks = set()
285    for tracks in track_history:
286        for track in tracks:
287            all_tracks.add(track.id)
288            if track.status == TrackStatus.CONFIRMED:
289                confirmed_tracks.add(track.id)
290
291    print(f"  Total tracks initiated: {len(all_tracks)}")
292    print(f"  Confirmed tracks: {len(confirmed_tracks)}")
293
294    # Final track summary
295    final_tracks = track_history[-1]
296    print(f"\nFinal active tracks: {len(final_tracks)}")
297    for track in final_tracks:
298        pos = (track.state[0], track.state[2])
299        vel = (track.state[1], track.state[3])
300        print(
301            f"  Track {track.id}: pos=({pos[0]:.1f}, {pos[1]:.1f}), "
302            f"vel=({vel[0]:.1f}, {vel[1]:.1f}), status={track.status.value}"
303        )
304
305    # Plot if plotly is available
306    try:
307        plot_results(true_states, measurements, track_history)
308    except Exception as e:
309        print(f"\nCould not generate plot: {e}")
310
311    print("\nDone!")
312
313
314if __name__ == "__main__":
315    main()

Running the Example

python examples/multi_target_tracking.py

See Also