Particle Filters

This example demonstrates bootstrap particle filters with various resampling methods.

Overview

Particle filters (Sequential Monte Carlo) handle:

  • Nonlinear dynamics - Arbitrary state transition functions

  • Non-Gaussian noise - Any noise distribution

  • Multi-modal posteriors - Multiple hypotheses

Key Concepts

  • Importance sampling: Weighting particles by likelihood

  • Resampling: Eliminating low-weight particles

  • Effective sample size: Measuring particle degeneracy

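  • Roughening: Adding small jitter after resampling to prevent sample impoverishment (introduced as a concept; the example code below does not apply roughening)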

Resampling Methods

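The example code compares three of the following resampling strategies — multinomial, systematic, and residual. Stratified resampling is described for reference but is not run in the code below: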

  1. Multinomial - Standard random resampling

  2. Systematic - N evenly spaced points on the weight CDF, sharing a single random offset

  3. Stratified - One independent uniform draw inside each of N equal-width strata of the weight CDF

  4. Residual - Deterministic copies for the integer part of N·wᵢ, then random resampling of the remaining fractional weights

Code Highlights

The example demonstrates:

  • Bootstrap particle filter initialization

  • Weight computation from likelihoods

  • Different resampling implementations

  • Effective sample size monitoring

  • State estimation from weighted particles

Source Code

  1"""
  2Particle Filters Example.
  3
  4This example demonstrates particle filtering (Sequential Monte Carlo)
  5algorithms in PyTCL:
  6
  7- Bootstrap particle filter
  8- Importance sampling and resampling
  9- Different resampling strategies (multinomial, systematic, residual)
 10- Effective sample size monitoring
 11- Particle statistics computation
 12- Comparison with Kalman filter for linear systems
 13- Nonlinear system tracking
 14
 15Particle filters are essential for nonlinear, non-Gaussian state estimation
 16where Kalman filters cannot be directly applied.
 17
 18Run with: python examples/particle_filters.py
 19"""
 20
 21import sys
 22from pathlib import Path
 23
 24sys.path.insert(0, str(Path(__file__).parent.parent))
 25
 26# Output directory for generated plots
 27OUTPUT_DIR = Path(__file__).parent.parent / "docs" / "_static" / "images" / "examples"
 28OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
 29
 30# Global flag to control plotting
 31SHOW_PLOTS = True
 32
 33import numpy as np  # noqa: E402
 34import plotly.graph_objects as go  # noqa: E402
 35from plotly.subplots import make_subplots  # noqa: E402
 36
 37from pytcl.dynamic_estimation.kalman.linear import kf_predict, kf_update  # noqa: E402
 38from pytcl.dynamic_estimation.particle_filters import (  # noqa: E402
 39    ParticleState,
 40    bootstrap_pf_predict,
 41    bootstrap_pf_step,
 42    bootstrap_pf_update,
 43    effective_sample_size,
 44    gaussian_likelihood,
 45    initialize_particles,
 46    particle_covariance,
 47    particle_mean,
 48    resample_multinomial,
 49    resample_residual,
 50    resample_systematic,
 51)
 52
 53
 54def demo_particle_basics():
 55    """Demonstrate basic particle filter operations."""
 56    print("=" * 70)
 57    print("Particle Filter Basics Demo")
 58    print("=" * 70)
 59
 60    np.random.seed(42)
 61
 62    # Initialize particles for a 2D state [x, y]
 63    n_particles = 1000
 64    state_dim = 2
 65
 66    # Initial distribution: Gaussian centered at origin
 67    mean = np.array([0.0, 0.0])
 68    cov = np.eye(2) * 2.0
 69
 70    # initialize_particles returns a ParticleState object
 71    state = initialize_particles(mean, cov, n_particles)
 72    particles = state.particles
 73    weights = state.weights
 74
 75    print(f"\nInitialized {n_particles} particles")
 76    print(f"State dimension: {state_dim}")
 77    print(f"Initial mean: {particle_mean(particles, weights)}")
 78    print(f"Initial std: {np.sqrt(np.diag(particle_covariance(particles, weights)))}")
 79
 80    # Effective sample size
 81    ess = effective_sample_size(weights)
 82    print(f"Initial ESS: {ess:.1f} (should be ~{n_particles})")
 83
 84    # Demonstrate weight degeneracy
 85    print("\n--- Weight Degeneracy Example ---")
 86    # Create skewed weights
 87    skewed_weights = np.ones(n_particles)
 88    skewed_weights[0] = 100.0  # One dominant particle
 89    skewed_weights /= skewed_weights.sum()
 90
 91    ess_skewed = effective_sample_size(skewed_weights)
 92    print(f"With one dominant particle, ESS: {ess_skewed:.1f}")
 93    print("This indicates severe weight degeneracy - resampling needed!")
 94
 95
 96def demo_resampling_methods():
 97    """Demonstrate different resampling strategies."""
 98    print("\n" + "=" * 70)
 99    print("Resampling Methods Demo")
100    print("=" * 70)
101
102    np.random.seed(42)
103
104    n_particles = 1000
105
106    # Create particles with non-uniform weights
107    particles = np.random.randn(n_particles, 2)
108    weights = np.exp(-np.sum(particles**2, axis=1) / 4)  # Higher near origin
109    weights /= weights.sum()
110
111    print(f"\nOriginal particle distribution:")
112    print(f"  Mean: {particle_mean(particles, weights)}")
113    print(f"  ESS: {effective_sample_size(weights):.1f}")
114
115    # Multinomial resampling - returns resampled particles directly
116    particles_multi = resample_multinomial(particles, weights)
117    weights_multi = np.ones(n_particles) / n_particles
118
119    print("\n--- Multinomial Resampling ---")
120    print(f"  Mean: {particle_mean(particles_multi, weights_multi)}")
121    print(f"  ESS: {effective_sample_size(weights_multi):.1f}")
122
123    # Systematic resampling (lower variance)
124    particles_sys = resample_systematic(particles, weights)
125    weights_sys = np.ones(n_particles) / n_particles
126
127    print("\n--- Systematic Resampling ---")
128    print(f"  Mean: {particle_mean(particles_sys, weights_sys)}")
129    print(f"  ESS: {effective_sample_size(weights_sys):.1f}")
130
131    # Residual resampling
132    particles_res = resample_residual(particles, weights)
133    weights_res = np.ones(n_particles) / n_particles
134
135    print("\n--- Residual Resampling ---")
136    print(f"  Mean: {particle_mean(particles_res, weights_res)}")
137    print(f"  ESS: {effective_sample_size(weights_res):.1f}")
138
139    print("\nNote: Systematic resampling typically preserves more diversity")
140    print("and has lower variance than multinomial resampling.")
141
142    # Plot resampling comparison
143    if SHOW_PLOTS:
144        fig = make_subplots(
145            rows=2,
146            cols=2,
147            subplot_titles=(
148                f"Original Particles (ESS={effective_sample_size(weights):.0f})",
149                "After Multinomial Resampling",
150                "After Systematic Resampling",
151                "After Residual Resampling",
152            ),
153        )
154
155        # Original particles with weights
156        fig.add_trace(
157            go.Scatter(
158                x=particles[:, 0],
159                y=particles[:, 1],
160                mode="markers",
161                marker=dict(size=5, color=weights, colorscale="Viridis", opacity=0.6),
162                name="Original",
163            ),
164            row=1,
165            col=1,
166        )
167
168        # Multinomial resampling
169        fig.add_trace(
170            go.Scatter(
171                x=particles_multi[:, 0],
172                y=particles_multi[:, 1],
173                mode="markers",
174                marker=dict(size=5, color="blue", opacity=0.6),
175                name="Multinomial",
176            ),
177            row=1,
178            col=2,
179        )
180
181        # Systematic resampling
182        fig.add_trace(
183            go.Scatter(
184                x=particles_sys[:, 0],
185                y=particles_sys[:, 1],
186                mode="markers",
187                marker=dict(size=5, color="green", opacity=0.6),
188                name="Systematic",
189            ),
190            row=2,
191            col=1,
192        )
193
194        # Residual resampling
195        fig.add_trace(
196            go.Scatter(
197                x=particles_res[:, 0],
198                y=particles_res[:, 1],
199                mode="markers",
200                marker=dict(size=5, color="red", opacity=0.6),
201                name="Residual",
202            ),
203            row=2,
204            col=2,
205        )
206
207        fig.update_layout(
208            height=800,
209            width=1000,
210            title_text="Comparison of Resampling Methods",
211            showlegend=False,
212        )
213        fig.update_xaxes(title_text="x")
214        fig.update_yaxes(title_text="y")
215
216        fig.write_html(str(OUTPUT_DIR / "particle_resampling_comparison.html"))
217        print("\n  [Plot saved to particle_resampling_comparison.html]")
218
219
def demo_linear_tracking():
    """Compare a bootstrap particle filter with a Kalman filter.

    Simulates a constant-velocity target with state ``[x, vx, y, vy]``
    observed in position only. A Kalman filter and a 500-particle bootstrap
    filter run on the same measurements, and the position RMSE of each is
    reported. Because the system is linear-Gaussian, the Kalman filter is
    the optimal reference.

    Position error at each step is the Euclidean distance between the
    estimated and true ``(x, y)``. The printed RMSE is the root-mean-square
    of that same quantity, so it agrees with the error-over-time panel.
    """
    print("\n" + "=" * 70)
    print("Linear System Tracking Demo")
    print("=" * 70)

    np.random.seed(42)

    # Linear constant-velocity model, state [x, vx, y, vy]
    dt = 1.0
    F = np.array(
        [
            [1, dt, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 1, dt],
            [0, 0, 0, 1],
        ]
    )

    # Process noise (discretized white-noise-acceleration form)
    q = 0.1
    Q = q * np.array(
        [
            [dt**3 / 3, dt**2 / 2, 0, 0],
            [dt**2 / 2, dt, 0, 0],
            [0, 0, dt**3 / 3, dt**2 / 2],
            [0, 0, dt**2 / 2, dt],
        ]
    )

    # Measurement model (observe position only)
    H = np.array(
        [
            [1, 0, 0, 0],
            [0, 0, 1, 0],
        ]
    )
    R = np.eye(2) * 1.0

    # True trajectory, driven by one tenth of the filters' assumed noise
    n_steps = 20
    true_states = np.zeros((n_steps, 4))
    true_states[0] = [0, 1, 0, 0.5]  # Start at origin, moving diagonally

    for k in range(1, n_steps):
        true_states[k] = F @ true_states[k - 1] + np.random.multivariate_normal(
            np.zeros(4), Q * 0.1
        )

    # Generate measurements
    measurements = [
        H @ true_states[k] + np.random.multivariate_normal(np.zeros(2), R)
        for k in range(n_steps)
    ]

    print(f"\nSimulating {n_steps} time steps")
    print("True initial state: [x=0, vx=1, y=0, vy=0.5]")

    # Kalman filter
    x_kf = np.array([0.0, 0.0, 0.0, 0.0])
    P_kf = np.eye(4) * 10.0
    kf_estimates = []

    for z in measurements:
        pred = kf_predict(x_kf, P_kf, F, Q)
        upd = kf_update(pred.x, pred.P, z, H, R)
        x_kf, P_kf = upd.x, upd.P
        kf_estimates.append(x_kf.copy())

    # Particle filter
    n_particles = 500
    state = initialize_particles(np.zeros(4), np.eye(4) * 10.0, n_particles)
    particles = state.particles
    weights = state.weights.copy()
    pf_estimates = []

    def likelihood_fn(z, x):
        z_pred = H @ x
        return gaussian_likelihood(z, z_pred, R)

    for z in measurements:
        # Predict the whole cloud at once. One size=n draw consumes the
        # standard-normal stream in the same row-by-row order as one draw
        # per particle.
        noise = np.random.multivariate_normal(np.zeros(4), Q, size=n_particles)
        particles = particles @ F.T + noise

        # Update weights
        likelihoods = np.array([likelihood_fn(z, p) for p in particles])
        weights = weights * likelihoods
        total = weights.sum()
        if total > 0:
            weights /= total
        else:
            # All likelihoods underflowed: fall back to uniform weights
            # instead of dividing by zero.
            weights = np.ones(n_particles) / n_particles

        # Estimate
        pf_estimates.append(particle_mean(particles, weights))

        # Resample if needed
        ess = effective_sample_size(weights)
        if ess < n_particles / 2:
            particles = resample_systematic(particles, weights)
            weights = np.ones(n_particles) / n_particles

    kf_estimates = np.array(kf_estimates)
    pf_estimates = np.array(pf_estimates)

    def position_error(estimates):
        """Euclidean (x, y) error per step; x is column 0, y is column 2."""
        diff = estimates[:, [0, 2]] - true_states[:, [0, 2]]
        return np.sqrt(np.sum(diff**2, axis=1))

    kf_pos_err = position_error(kf_estimates)
    pf_pos_err = position_error(pf_estimates)

    # RMSE of the same Euclidean error that is plotted over time.
    kf_rmse = np.sqrt(np.mean(kf_pos_err**2))
    pf_rmse = np.sqrt(np.mean(pf_pos_err**2))

    print("\n--- Filter Comparison (Position RMSE) ---")
    print(f"  Kalman Filter: {kf_rmse:.4f}")
    print(f"  Particle Filter ({n_particles} particles): {pf_rmse:.4f}")
    print("\nNote: For linear Gaussian systems, KF is optimal.")
    print("PF approaches KF performance as particle count increases.")

    if not SHOW_PLOTS:
        return

    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=(
            "Trajectory Tracking: KF vs PF",
            "Position Error Over Time",
        ),
    )

    measurements_arr = np.array(measurements)

    # Trajectory plot
    fig.add_trace(
        go.Scatter(
            x=true_states[:, 0],
            y=true_states[:, 2],
            mode="lines",
            name="True trajectory",
            line=dict(color="black", width=2),
        ),
        row=1,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=kf_estimates[:, 0],
            y=kf_estimates[:, 2],
            mode="lines",
            name=f"Kalman Filter (RMSE={kf_rmse:.3f})",
            line=dict(color="blue", width=1.5, dash="dash"),
        ),
        row=1,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=pf_estimates[:, 0],
            y=pf_estimates[:, 2],
            mode="lines",
            name=f"Particle Filter (RMSE={pf_rmse:.3f})",
            line=dict(color="red", width=1.5, dash="dot"),
        ),
        row=1,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=measurements_arr[:, 0],
            y=measurements_arr[:, 1],
            mode="markers",
            name="Measurements",
            marker=dict(size=5, color="gray", opacity=0.5),
        ),
        row=1,
        col=1,
    )

    # Error comparison
    time_axis = np.arange(n_steps)
    fig.add_trace(
        go.Scatter(
            x=time_axis,
            y=kf_pos_err,
            mode="lines",
            name="Kalman Filter",
            line=dict(color="blue"),
        ),
        row=1,
        col=2,
    )
    fig.add_trace(
        go.Scatter(
            x=time_axis,
            y=pf_pos_err,
            mode="lines",
            name="Particle Filter",
            line=dict(color="red"),
        ),
        row=1,
        col=2,
    )

    fig.update_xaxes(title_text="x position", row=1, col=1)
    fig.update_yaxes(title_text="y position", row=1, col=1)
    fig.update_xaxes(title_text="Time step", row=1, col=2)
    fig.update_yaxes(title_text="Position error", row=1, col=2)

    fig.update_layout(height=500, width=1200)
    fig.write_html(str(OUTPUT_DIR / "particle_linear_tracking.html"))
    print("\n  [Plot saved to particle_linear_tracking.html]")
434
435
def demo_nonlinear_tracking():
    """Track circular motion from nonlinear range-bearing measurements.

    A target moves on a circle of radius 10 m around a sensor at the origin
    that reports noisy range and bearing. A bootstrap particle filter with a
    noisy constant-velocity motion model estimates the Cartesian state
    ``[x, y, vx, vy]``. Bearing residuals are wrapped into ``[-pi, pi]``
    before the likelihood is evaluated, and the filter resamples whenever
    ESS < N/2. Error and ESS statistics are printed, and a four-panel figure
    is saved when ``SHOW_PLOTS`` is set.
    """
    print("\n" + "=" * 70)
    print("Nonlinear System Tracking Demo")
    print("=" * 70)

    np.random.seed(42)

    # State: [x, y, vx, vy]
    # Measurement: [range, bearing] (nonlinear!)
    dt = 0.1
    n_steps = 50
    n_particles = 1000

    # True trajectory: uniform circular motion about the origin
    omega = 0.5  # angular velocity (rad/s)
    radius = 10.0
    true_states = np.zeros((n_steps, 4))

    for k in range(n_steps):
        t = k * dt
        true_states[k] = [
            radius * np.cos(omega * t),
            radius * np.sin(omega * t),
            -radius * omega * np.sin(omega * t),
            radius * omega * np.cos(omega * t),
        ]

    # Measurement noise
    sigma_range = 0.5
    sigma_bearing = np.radians(2.0)

    def measurement_model(state):
        """Range and bearing from the origin.

        Only ``state[0]`` (x) and ``state[1]`` (y) are read, so passing a
        transposed ``(4, n)`` particle array yields a ``(2, n)`` result.
        """
        x, y = state[0], state[1]
        r = np.sqrt(x**2 + y**2)
        theta = np.arctan2(y, x)
        return np.array([r, theta])

    # Generate measurements
    measurements = []
    for k in range(n_steps):
        z_true = measurement_model(true_states[k])
        noise = np.array(
            [np.random.randn() * sigma_range, np.random.randn() * sigma_bearing]
        )
        measurements.append(z_true + noise)

    print("\nSimulating circular motion with range-bearing sensor")
    print(f"  Radius: {radius} m, Angular velocity: {omega} rad/s")
    print(
        f"  Measurement noise: sigma_r={sigma_range} m, "
        f"sigma_theta={np.degrees(sigma_bearing):.1f} deg"
    )

    # Initialize particle filter near the true initial state
    state = initialize_particles(
        np.array([radius, 0.0, 0.0, radius * omega]),
        np.diag([1.0, 1.0, 0.5, 0.5]),
        n_particles,
    )
    particles = state.particles
    weights = state.weights.copy()

    R = np.diag([sigma_range**2, sigma_bearing**2])

    # Constant-velocity motion model. The velocity noise is deliberately
    # high so the filter can follow the turning motion it does not model.
    q_pos = 0.05  # position noise
    q_vel = 2.0  # velocity noise (scaled by dt)

    def propagate(cloud):
        """Apply the noisy constant-velocity model to every row of ``cloud``.

        An ``(n, 4)`` block of standard normals is consumed row by row
        (x, y, vx, vy per particle), the same order as four scalar draws
        per particle.
        """
        x, y, vx, vy = cloud.T
        eps = np.random.randn(len(cloud), 4)
        return np.column_stack(
            [
                x + vx * dt + eps[:, 0] * q_pos,
                y + vy * dt + eps[:, 1] * q_pos,
                vx + eps[:, 2] * q_vel * dt,
                vy + eps[:, 3] * q_vel * dt,
            ]
        )

    # Run particle filter
    pf_estimates = []
    ess_history = []

    for z in measurements:
        # Predict
        particles = propagate(particles)

        # Predicted [range, bearing] for every particle, shape (2, n)
        z_pred_all = measurement_model(particles.T)
        # Wrapped bearing residual so angles near +/-pi are not over-penalized
        bearing_diff = np.arctan2(
            np.sin(z[1] - z_pred_all[1]), np.cos(z[1] - z_pred_all[1])
        )

        likelihoods = np.empty(n_particles)
        for i in range(n_particles):
            z_pred = z_pred_all[:, i]
            # Shift the measured bearing so its difference to the prediction
            # is the wrapped residual.
            z_wrapped = np.array([z[0], z_pred[1] + bearing_diff[i]])
            likelihoods[i] = gaussian_likelihood(z_wrapped, z_pred, R)
        weights *= likelihoods

        # Normalize, falling back to uniform if every likelihood underflowed
        if weights.sum() > 0:
            weights /= weights.sum()
        else:
            weights = np.ones(n_particles) / n_particles

        # Estimate
        pf_estimates.append(particle_mean(particles, weights))
        ess_history.append(effective_sample_size(weights))

        # Resample
        if ess_history[-1] < n_particles / 2:
            particles = resample_systematic(particles, weights)
            weights = np.ones(n_particles) / n_particles

    pf_estimates = np.array(pf_estimates)

    # Euclidean position error per time step
    pos_errors = np.sqrt(
        (pf_estimates[:, 0] - true_states[:, 0]) ** 2
        + (pf_estimates[:, 1] - true_states[:, 1]) ** 2
    )

    print("\n--- Tracking Results ---")
    print(f"  Mean position error: {np.mean(pos_errors):.3f} m")
    print(f"  Max position error: {np.max(pos_errors):.3f} m")
    print(f"  Min ESS: {np.min(ess_history):.1f}")
    print(f"  Mean ESS: {np.mean(ess_history):.1f}")

    # Show trajectory snapshots
    print("\n--- Trajectory Snapshots ---")
    times = [0, n_steps // 4, n_steps // 2, 3 * n_steps // 4, n_steps - 1]
    for t in times:
        true_pos = true_states[t, :2]
        est_pos = pf_estimates[t, :2]
        err = pos_errors[t]
        print(
            f"  t={t*dt:.1f}s: True=({true_pos[0]:.2f}, {true_pos[1]:.2f}), "
            f"Est=({est_pos[0]:.2f}, {est_pos[1]:.2f}), Err={err:.3f}m"
        )

    if not SHOW_PLOTS:
        return

    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=(
            "Circular Motion Tracking with Range-Bearing Sensor",
            "Position Error Over Time",
            "ESS History (resampling when ESS < N/2)",
            "Range-Bearing Measurements (color=time)",
        ),
    )

    # Trajectory plot
    fig.add_trace(
        go.Scatter(
            x=true_states[:, 0],
            y=true_states[:, 1],
            mode="lines",
            name="True trajectory",
            line=dict(color="black", width=2),
        ),
        row=1,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=pf_estimates[:, 0],
            y=pf_estimates[:, 1],
            mode="lines",
            name="PF estimate",
            line=dict(color="red", width=1.5, dash="dash"),
        ),
        row=1,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=[true_states[0, 0]],
            y=[true_states[0, 1]],
            mode="markers",
            name="Start",
            marker=dict(size=15, color="green", symbol="circle"),
        ),
        row=1,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=[true_states[-1, 0]],
            y=[true_states[-1, 1]],
            mode="markers",
            name="End",
            marker=dict(size=15, color="blue", symbol="square"),
        ),
        row=1,
        col=1,
    )

    # Position error over time (named so the legend is not "trace N")
    time_axis = np.arange(n_steps) * dt
    fig.add_trace(
        go.Scatter(
            x=time_axis,
            y=pos_errors,
            mode="lines",
            name="Position error",
            line=dict(color="blue", width=1.5),
        ),
        row=1,
        col=2,
    )
    fig.add_hline(
        y=np.mean(pos_errors),
        line_dash="dash",
        line_color="red",
        annotation_text=f"Mean={np.mean(pos_errors):.3f}",
        row=1,
        col=2,
    )

    # ESS history
    fig.add_trace(
        go.Scatter(
            x=time_axis,
            y=ess_history,
            mode="lines",
            name="ESS",
            line=dict(color="green", width=1.5),
        ),
        row=2,
        col=1,
    )
    fig.add_hline(
        y=n_particles / 2,
        line_dash="dash",
        line_color="red",
        annotation_text="Resampling threshold",
        row=2,
        col=1,
    )

    # Measurements in polar form
    meas_arr = np.array(measurements)
    fig.add_trace(
        go.Scatter(
            x=np.degrees(meas_arr[:, 1]),
            y=meas_arr[:, 0],
            mode="markers",
            name="Measurements",
            marker=dict(
                size=5,
                color=time_axis,
                colorscale="Viridis",
                colorbar=dict(title="Time (s)", x=1.0),
            ),
        ),
        row=2,
        col=2,
    )

    fig.update_xaxes(title_text="x position (m)", row=1, col=1)
    fig.update_yaxes(title_text="y position (m)", row=1, col=1)
    fig.update_xaxes(title_text="Time (s)", row=1, col=2)
    fig.update_yaxes(title_text="Position error (m)", row=1, col=2)
    fig.update_xaxes(title_text="Time (s)", row=2, col=1)
    fig.update_yaxes(title_text="Effective Sample Size", row=2, col=1)
    fig.update_xaxes(title_text="Bearing (degrees)", row=2, col=2)
    fig.update_yaxes(title_text="Range (m)", row=2, col=2)

    fig.update_layout(height=800, width=1000, showlegend=True)
    fig.write_html(str(OUTPUT_DIR / "particle_nonlinear_tracking.html"))
    print("\n  [Plot saved to particle_nonlinear_tracking.html]")
712
713
def demo_multimodal():
    """Show a particle filter holding and resolving a bimodal belief.

    The prior is an equal mixture of two Gaussians centred at (5, 0) and
    (-5, 0). A single direct position measurement near the second mode is
    applied as an importance-weight update. The printout shows the weighted
    mean and how the weight splits between the two modes. When
    ``SHOW_PLOTS`` is set, the prior and the weighted posterior are saved
    side by side.
    """
    print("\n" + "=" * 70)
    print("Multimodal Distribution Demo")
    print("=" * 70)

    np.random.seed(42)

    # Scenario: Target could be at one of two locations
    # This is impossible for a Kalman filter but natural for particle filters

    n_particles = 2000

    # Prior: mixture of two Gaussians
    mode1 = np.array([5.0, 0.0])
    mode2 = np.array([-5.0, 0.0])
    cov = np.eye(2) * 0.5

    # Initialize with bimodal distribution (half the particles per mode)
    state1 = initialize_particles(mode1, cov, n_particles // 2)
    state2 = initialize_particles(mode2, cov, n_particles // 2)
    particles = np.vstack([state1.particles, state2.particles])
    weights = np.ones(n_particles) / n_particles

    print("\nBimodal prior distribution:")
    print(f"  Mode 1: {mode1}")
    print(f"  Mode 2: {mode2}")
    print(f"  Mean: {particle_mean(particles, weights)}")
    print("  (Mean is between modes - not representative!)")

    # Measurement that confirms mode 2
    z = np.array([-4.8, 0.1])
    R = np.eye(2) * 0.2

    print(f"\nMeasurement received: {z}")

    # Save prior particles for plotting
    prior_particles = particles.copy()

    # Update weights: the measurement observes position directly, so each
    # particle is its own predicted measurement.
    for i, z_pred in enumerate(particles):
        weights[i] *= gaussian_likelihood(z, z_pred, R)
    total = weights.sum()
    if total > 0:
        weights /= total
    else:
        # Every likelihood underflowed: keep a valid (uniform) distribution.
        weights = np.ones(n_particles) / n_particles

    # After update
    print("\nAfter measurement update:")
    print(f"  Mean: {particle_mean(particles, weights)}")
    print(f"  ESS: {effective_sample_size(weights):.1f}")

    # Analyze particle distribution, split at x = 0 between the modes
    in_mode1 = particles[:, 0] > 0
    in_mode2 = particles[:, 0] < 0
    near_mode1 = np.sum(in_mode1)
    near_mode2 = np.sum(in_mode2)
    weight_mode1 = np.sum(weights[in_mode1])
    weight_mode2 = np.sum(weights[in_mode2])

    print(f"\n  Particles near mode 1: {near_mode1} (weight: {weight_mode1:.4f})")
    print(f"  Particles near mode 2: {near_mode2} (weight: {weight_mode2:.4f})")
    print("\nNote: PF correctly concentrates probability on mode 2")
    print("after receiving the confirming measurement.")

    if not SHOW_PLOTS:
        return

    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=(
            "Prior: Bimodal Distribution",
            "Posterior: After Measurement Update",
        ),
    )

    # Prior distribution
    fig.add_trace(
        go.Scatter(
            x=prior_particles[:, 0],
            y=prior_particles[:, 1],
            mode="markers",
            marker=dict(size=3, color="blue", opacity=0.3),
            name="Prior particles",
        ),
        row=1,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=[mode1[0], mode2[0]],
            y=[mode1[1], mode2[1]],
            mode="markers",
            marker=dict(size=15, color="green", symbol="x", line=dict(width=3)),
            name="Modes",
        ),
        row=1,
        col=1,
    )

    # Marker size stays proportional to weight, scaled so the heaviest
    # particle is 20 px wide however concentrated the posterior is.
    marker_size = 20.0 * weights / weights.max()

    # Posterior distribution
    fig.add_trace(
        go.Scatter(
            x=particles[:, 0],
            y=particles[:, 1],
            mode="markers",
            marker=dict(
                size=marker_size,
                color=weights,
                colorscale="Reds",
                opacity=0.5,
            ),
            name="Posterior particles",
        ),
        row=1,
        col=2,
    )
    fig.add_trace(
        go.Scatter(
            x=[z[0]],
            y=[z[1]],
            mode="markers",
            marker=dict(size=20, color="blue", symbol="star"),
            name="Measurement",
        ),
        row=1,
        col=2,
    )

    fig.update_xaxes(range=[-10, 10], row=1, col=1)
    fig.update_yaxes(range=[-5, 5], row=1, col=1)
    fig.update_xaxes(range=[-10, 10], row=1, col=2)
    fig.update_yaxes(range=[-5, 5], row=1, col=2)

    fig.update_layout(
        height=500,
        width=1200,
        title_text="Particle Filter for Multimodal Distribution (Point size proportional to weight)",
    )
    fig.write_html(str(OUTPUT_DIR / "particle_multimodal.html"))
    print("\n  [Plot saved to particle_multimodal.html]")
851
852
def main():
    """Run each particle-filter demonstration in turn, then print a summary."""
    print("\n" + "#" * 70)
    print("# PyTCL Particle Filters Example")
    print("#" * 70)

    demos = (
        demo_particle_basics,
        demo_resampling_methods,
        demo_linear_tracking,
        demo_nonlinear_tracking,
        demo_multimodal,
    )
    for run_demo in demos:
        run_demo()

    print("\n" + "=" * 70)
    print("Example complete!")
    if SHOW_PLOTS:
        saved_lines = (
            "Plots saved: particle_resampling_comparison.html, ",
            "             particle_linear_tracking.html,",
            "             particle_nonlinear_tracking.html,",
            "             particle_multimodal.html",
        )
        for line in saved_lines:
            print(line)
    print("=" * 70)


if __name__ == "__main__":
    main()

Running the Example

python examples/particle_filters.py

See Also