Multi-Target Tracking
This example demonstrates GNN-based multi-target tracking with track management.
Overview
Multi-target tracking (MTT) addresses:
Data association: Matching measurements to tracks
Track initiation: Detecting new targets
Track maintenance: Updating confirmed tracks
Track termination: Removing lost targets
Key Concepts
Global Nearest Neighbor (GNN): Optimal measurement-to-track assignment
Gating: Reducing assignment candidates using statistical tests
Track scoring: M/N logic and likelihood-based confirmation
Clutter modeling: False alarm rate estimation
Data Association: The Hungarian algorithm finds the optimal measurement-to-track assignment by minimizing total cost.
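Performance Metrics: OSPA (Optimal Sub-Pattern Assignment) measures tracking accuracy by combining localization error and cardinality error into a single distance; labeling error is only captured by labeled extensions of the metric (e.g. OSPA for tracks).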
Code Highlights
The example demonstrates:
Track initialization from unassigned measurements
GNN assignment using Hungarian algorithm
Kalman filter updates for each track
Track state machine (tentative, confirmed, deleted)
OSPA metric computation for performance evaluation (not part of the listing below; see Performance Evaluation)
Source Code
1"""
2Multi-target tracking example.
3
4This example demonstrates:
51. Simulating multiple crossing targets
62. Using the MultiTargetTracker for GNN-based tracking
73. Track initiation, confirmation, and deletion
8
9Run with: python examples/multi_target_tracking.py
10"""
11
12# Add parent directory to path for development
13import sys
14from pathlib import Path
15
16sys.path.insert(0, str(Path(__file__).parent.parent))
17
18# Output directory for generated plots
19OUTPUT_DIR = Path(__file__).parent.parent / "docs" / "_static" / "images" / "examples"
20OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
21
22from typing import List, Tuple # noqa: E402
23
24import numpy as np # noqa: E402
25import plotly.graph_objects as go # noqa: E402
26
27from pytcl.trackers import ( # noqa: E402
28 MultiTargetTracker,
29 TrackStatus,
30)
31
32
def simulate_targets(
    n_steps: int = 50,
    dt: float = 1.0,
    *,
    pd: float = 0.95,
    false_alarm_prob: float = 0.1,
) -> Tuple[List[np.ndarray], List[List[np.ndarray]]]:
    """
    Simulate two crossing targets with position measurements.

    Both targets move with constant velocity. Target 1 starts at (0, 0)
    with velocity (2, 1); target 2 starts at (100, 0) with velocity
    (-2, 1.5). At every step each target is detected independently with
    probability ``pd`` and, when detected, yields a position measurement
    corrupted by zero-mean Gaussian noise with covariance ``2 * I``.
    With probability ``false_alarm_prob`` one extra false alarm, uniform
    over x in [-10, 110] and y in [-10, 60], is appended to the scan.

    All random draws use NumPy's global random state, so calling
    ``np.random.seed`` beforehand makes the output reproducible.

    Parameters
    ----------
    n_steps : int
        Number of time steps to simulate.
    dt : float
        Time between consecutive steps.
    pd : float
        Per-target, per-step detection probability, in [0, 1].
    false_alarm_prob : float
        Probability that a step contains one false alarm, in [0, 1].

    Returns
    -------
    true_states : list of ndarray
        Ground truth states [x1, y1, x2, y2] at each step.
    measurements : list of list of ndarray
        Noisy position measurements at each step.

    Raises
    ------
    ValueError
        If ``pd`` or ``false_alarm_prob`` lies outside [0, 1].
    """
    if not 0.0 <= pd <= 1.0:
        raise ValueError(f"pd must be in [0, 1], got {pd}")
    if not 0.0 <= false_alarm_prob <= 1.0:
        raise ValueError(
            f"false_alarm_prob must be in [0, 1], got {false_alarm_prob}"
        )

    # Target 1: Moving right and up
    x1_0, y1_0 = 0.0, 0.0
    vx1, vy1 = 2.0, 1.0

    # Target 2: Moving left and up
    x2_0, y2_0 = 100.0, 0.0
    vx2, vy2 = -2.0, 1.5

    R = np.eye(2) * 2.0  # Measurement noise covariance

    true_states = []
    measurements = []

    for k in range(n_steps):
        t = k * dt

        # True positions
        x1 = x1_0 + vx1 * t
        y1 = y1_0 + vy1 * t
        x2 = x2_0 + vx2 * t
        y2 = y2_0 + vy2 * t
        true_states.append(np.array([x1, y1, x2, y2]))

        # The order of the random draws below (detection test, noise,
        # detection test, noise, false-alarm test, false-alarm position)
        # is kept fixed so that seeded runs stay reproducible.
        meas = []
        for x, y in ((x1, y1), (x2, y2)):
            if np.random.rand() < pd:
                noise = np.random.multivariate_normal([0, 0], R)
                meas.append(np.array([x, y]) + noise)

        # Add occasional false alarms somewhere in the scene
        if np.random.rand() < false_alarm_prob:
            fa = np.array(
                [np.random.uniform(-10, 110), np.random.uniform(-10, 60)]
            )
            meas.append(fa)

        measurements.append(meas)

    return true_states, measurements
93
94
def run_tracker(
    measurements: List[List[np.ndarray]],
    dt: float = 1.0,
) -> List[List]:
    """
    Feed every scan of measurements through a MultiTargetTracker.

    The tracker is configured with a 4-state constant-velocity model
    (state = [x, vx, y, vy]) and a 2-D position measurement model.

    Parameters
    ----------
    measurements : list of list of ndarray
        One list of position measurements per time step.
    dt : float
        Time step handed to the tracker together with each scan.

    Returns
    -------
    list
        The result of ``tracker.process`` for each scan, in order.
    """

    def transition(dt):
        # Constant velocity: position advances by velocity * dt on each axis.
        mat = np.eye(4, dtype=np.float64)
        mat[0, 1] = dt
        mat[2, 3] = dt
        return mat

    def process_noise(dt):
        accel_std = 0.5  # Acceleration noise std
        # Per-axis 2x2 block, repeated on the diagonal for x and y.
        axis_block = np.array(
            [
                [dt**4 / 4, dt**3 / 2],
                [dt**3 / 2, dt**2],
            ]
        )
        full = np.zeros((4, 4))
        full[:2, :2] = axis_block
        full[2:, 2:] = axis_block
        return full * accel_std**2

    # Only x and y (state indices 0 and 2) are observed.
    observation = np.array([[1, 0, 0, 0], [0, 0, 1, 0]], dtype=np.float64)

    meas_noise = np.eye(2) * 2.0

    # Covariance assigned to newly started tracks.
    initial_cov = np.diag([10.0, 5.0, 10.0, 5.0])

    tracker = MultiTargetTracker(
        state_dim=4,
        meas_dim=2,
        F=transition,
        H=observation,
        Q=process_noise,
        R=meas_noise,
        gate_probability=0.99,
        confirm_hits=3,
        max_misses=5,
        init_covariance=initial_cov,
    )

    # One tracker call per scan; results are kept in scan order.
    return [tracker.process(scan, dt) for scan in measurements]
157
158
def plot_results(
    true_states: List[np.ndarray],
    measurements: List[List[np.ndarray]],
    track_history: List[List],
) -> None:
    """
    Draw truth, measurements and confirmed tracks in one Plotly figure.

    The figure is written as an interactive HTML file into OUTPUT_DIR,
    the path is printed, and the figure is then shown.

    Parameters
    ----------
    true_states : list of ndarray
        Ground truth states [x1, y1, x2, y2] per step.
    measurements : list of list of ndarray
        Position measurements per step.
    track_history : list of list
        Tracks per step; each track exposes ``id``, ``status`` and
        ``state`` (x at index 0, y at index 2).
    """
    figure = go.Figure()

    # Ground-truth trajectories, one line per target.
    truth = np.array(true_states)
    for x_col, y_col, colour, label in (
        (0, 1, "green", "Target 1 (truth)"),
        (2, 3, "blue", "Target 2 (truth)"),
    ):
        figure.add_trace(
            go.Scatter(
                x=truth[:, x_col],
                y=truth[:, y_col],
                mode="lines",
                line=dict(color=colour, width=2),
                name=label,
            )
        )

    # All measurements from every scan go into a single marker trace.
    points = [z for scan in measurements for z in scan]
    figure.add_trace(
        go.Scatter(
            x=[z[0] for z in points],
            y=[z[1] for z in points],
            mode="markers",
            marker=dict(color="black", size=3, opacity=0.5),
            name="Measurements",
        )
    )

    # Group (x, y) positions of confirmed tracks by track ID.
    confirmed_paths: dict = {}
    for step_tracks in track_history:
        for trk in step_tracks:
            if trk.status != TrackStatus.CONFIRMED:
                continue
            confirmed_paths.setdefault(trk.id, []).append(
                (trk.state[0], trk.state[2])
            )

    # Plotly color palette (similar to tab10)
    palette = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]

    # One trace per track; tracks with a single point are not drawn but
    # still consume a palette slot.
    for idx, (track_id, path) in enumerate(confirmed_paths.items()):
        if len(path) <= 1:
            continue
        colour = palette[idx % len(palette)]
        path_arr = np.array(path)
        figure.add_trace(
            go.Scatter(
                x=path_arr[:, 0],
                y=path_arr[:, 1],
                mode="lines+markers",
                line=dict(color=colour, width=1.5),
                marker=dict(color=colour, size=4),
                name=f"Track {track_id}",
            )
        )

    figure.update_layout(
        title="Multi-Target Tracking with GNN Data Association",
        xaxis_title="X Position",
        yaxis_title="Y Position",
        xaxis=dict(scaleanchor="y", scaleratio=1),
        width=1200,
        height=800,
        showlegend=True,
    )

    # Save as interactive HTML, then display.
    output_path = OUTPUT_DIR / "multi_target_tracking_result.html"
    figure.write_html(str(output_path))
    print(f"Interactive plot saved to {output_path}")
    figure.show()
262
263
def main():
    """Simulate two targets, track them, print a summary and plot."""
    print("Multi-Target Tracking Example")
    print("=" * 50)

    # Fixed seed so every run produces the same scenario.
    np.random.seed(42)

    print("Simulating two crossing targets...")
    true_states, measurements = simulate_targets(n_steps=50, dt=1.0)
    n_measurements = sum(len(m) for m in measurements)
    print(f" Generated {len(true_states)} time steps")
    print(f" Total measurements: {n_measurements}")

    print("\nRunning multi-target tracker...")
    track_history = run_tracker(measurements, dt=1.0)

    # Distinct track IDs seen at any step, and those ever confirmed.
    seen_ids = {trk.id for step in track_history for trk in step}
    confirmed_ids = {
        trk.id
        for step in track_history
        for trk in step
        if trk.status == TrackStatus.CONFIRMED
    }

    print(f" Total tracks initiated: {len(seen_ids)}")
    print(f" Confirmed tracks: {len(confirmed_ids)}")

    # Summary of the tracks reported at the last step.
    last_step = track_history[-1]
    print(f"\nFinal active tracks: {len(last_step)}")
    for trk in last_step:
        px, py = trk.state[0], trk.state[2]
        vx, vy = trk.state[1], trk.state[3]
        print(
            f" Track {trk.id}: pos=({px:.1f}, {py:.1f}), "
            f"vel=({vx:.1f}, {vy:.1f}), status={trk.status.value}"
        )

    # Plotting is best-effort: a failure is reported but does not abort.
    try:
        plot_results(true_states, measurements, track_history)
    except Exception as err:
        print(f"\nCould not generate plot: {err}")

    print("\nDone!")


if __name__ == "__main__":
    main()
Running the Example
python examples/multi_target_tracking.py
See Also
Assignment Algorithms - Assignment algorithm details
Performance Evaluation - OSPA and tracking metrics
3D Target Tracking - 3D target tracking
Tracking Containers - Track and measurement data structures