Spatial Data Structures

This example demonstrates spatial data structures for efficient queries.

Overview

Spatial data structures enable fast nearest-neighbor and range queries:

  • KD-Tree: k-dimensional binary search tree

  • R-Tree: Rectangle tree for bounding box queries

  • Ball Tree: Metric tree supporting arbitrary distance metrics

  • VP-Tree: Vantage-point tree for general metric spaces

  • Cover Tree: Efficient when the data has low intrinsic dimensionality

Key Concepts

  • Nearest neighbor queries: Find k closest points

  • Range queries: Find all points within radius

  • Bulk loading: Efficient tree construction

  • Metric spaces: Distance-based operations

Data Structures

KD-Tree
  • Best for low-dimensional data (d < 20)

  • O(log n) average query time

  • Standard Euclidean distance

R-Tree
  • Designed for spatial indexing

  • Handles bounding boxes well

  • Good for GIS applications

Ball Tree
  • Works in any metric space

  • Often faster than a KD-Tree in higher dimensions

  • Supports custom distance functions

Code Highlights

The example demonstrates:

  • Building KD-tree with KDTree()

  • Nearest neighbor queries with query()

  • Range queries with query_radius()

  • Bulk construction from point arrays and a build/query timing comparison across structures

Source Code

  1"""
  2Spatial Data Structures Example
  3===============================
  4
  5This example demonstrates spatial data structures in PyTCL for
  6efficient nearest neighbor queries and spatial indexing:
  7
  8K-D Tree:
  9- Construction and querying
 10- K-nearest neighbor search
 11- Radius/range queries
 12
 13Ball Tree:
 14- Alternative to K-D tree for high dimensions
 15- Similar query interface
 16
 17R-Tree:
 18- Spatial indexing for bounding boxes
 19- Rectangle intersection queries
 20
 21VP-Tree (Vantage Point Tree):
 22- Metric space indexing
 23- Works with any distance metric
 24
 25Cover Tree:
 26- Approximate nearest neighbor search
 27- O(c^12 log n) query complexity
 28
 29These data structures are essential for efficient data association
 30in multi-target tracking and spatial analysis applications.
 31"""
 32
 33from pathlib import Path
 34
 35import numpy as np
 36import plotly.graph_objects as go
 37
 38# Output directory for generated plots
 39OUTPUT_DIR = Path(__file__).parent.parent / "docs" / "_static" / "images" / "examples"
 40OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
 41
 42# Global flag to control plotting
 43SHOW_PLOTS = True
 44
 45
 46from pytcl.containers import (  # K-D Tree; Ball Tree; R-Tree; VP-Tree; Cover Tree
 47    BallTree,
 48    BoundingBox,
 49    CoverTree,
 50    KDTree,
 51    NearestNeighborResult,
 52    RTree,
 53    VPTree,
 54    box_from_point,
 55    box_from_points,
 56    merge_boxes,
 57)
 58
 59
 60def demo_kdtree_basics():
 61    """Demonstrate K-D tree construction and basic queries."""
 62    print("=" * 70)
 63    print("K-D Tree Basics Demo")
 64    print("=" * 70)
 65
 66    np.random.seed(42)
 67
 68    # Generate 2D point cloud
 69    n_points = 100
 70    points = np.random.randn(n_points, 2) * 10
 71
 72    print(f"\nBuilding K-D tree with {n_points} 2D points")
 73
 74    # Build tree
 75    tree = KDTree(points)
 76
 77    print(f"Tree built successfully")
 78    print(f"Point cloud bounds:")
 79    print(f"  X: [{points[:, 0].min():.2f}, {points[:, 0].max():.2f}]")
 80    print(f"  Y: [{points[:, 1].min():.2f}, {points[:, 1].max():.2f}]")
 81
 82    # Query point
 83    query = np.array([0.0, 0.0])
 84    print(f"\nQuery point: {query}")
 85
 86    # K-nearest neighbors
 87    k = 5
 88    result = tree.query(query, k=k)
 89
 90    print(f"\n{k} nearest neighbors:")
 91    # indices/distances are 2D arrays, extract the first row for single query
 92    for i, (idx, dist) in enumerate(zip(result.indices[0], result.distances[0])):
 93        print(f"  {i+1}. Point {idx}: {points[idx]} (distance={dist:.4f})")
 94
 95    # Plot KDTree result
 96    if SHOW_PLOTS:
 97        fig = go.Figure()
 98
 99        # All points
100        fig.add_trace(
101            go.Scatter(
102                x=points[:, 0],
103                y=points[:, 1],
104                mode="markers",
105                marker=dict(color="lightblue", size=8, opacity=0.6),
106                name="Points",
107            )
108        )
109
110        # K nearest neighbors
111        nn_indices = result.indices[0]
112        fig.add_trace(
113            go.Scatter(
114                x=points[nn_indices, 0],
115                y=points[nn_indices, 1],
116                mode="markers",
117                marker=dict(color="green", size=12, opacity=0.8),
118                name=f"{k} nearest neighbors",
119            )
120        )
121
122        # Query point
123        fig.add_trace(
124            go.Scatter(
125                x=[query[0]],
126                y=[query[1]],
127                mode="markers",
128                marker=dict(color="red", size=15, symbol="star"),
129                name="Query",
130            )
131        )
132
133        # Draw circle for max distance
134        max_dist = result.distances[0, -1]
135        theta = np.linspace(0, 2 * np.pi, 100)
136        circle_x = query[0] + max_dist * np.cos(theta)
137        circle_y = query[1] + max_dist * np.sin(theta)
138        fig.add_trace(
139            go.Scatter(
140                x=circle_x,
141                y=circle_y,
142                mode="lines",
143                line=dict(color="green", dash="dash", width=2),
144                name="Search radius",
145                showlegend=True,
146            )
147        )
148
149        fig.update_layout(
150            title="K-D Tree: K-Nearest Neighbor Query",
151            xaxis_title="x",
152            yaxis_title="y",
153            height=600,
154            width=600,
155            showlegend=True,
156            xaxis=dict(scaleanchor="y", scaleratio=1),
157        )
158        fig.write_html(str(OUTPUT_DIR / "spatial_kdtree.html"))
159        print("\n  [Plot saved to spatial_kdtree.html]")
160
161
def demo_kdtree_queries():
    """Compare k-NN and fixed-radius queries on a jittered grid.

    Builds a K-D tree over an 11 x 11 lattice spanning [-10, 10] on both axes.
    Each node is perturbed by N(0, 0.3^2) noise (seed 42). For a query at the
    origin, it prints the farthest-neighbor distance for several k and the
    number of points found within several radii.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("K-D Tree Query Types Demo")
    print(rule)

    np.random.seed(42)

    # Same tick marks on both axes, then jitter every lattice node
    ticks = np.linspace(-10, 10, 11)
    gx, gy = np.meshgrid(ticks, ticks)
    points = np.column_stack([gx.ravel(), gy.ravel()])
    points += np.random.randn(*points.shape) * 0.3

    kd = KDTree(points)

    print(f"\nGrid-based point cloud: {len(points)} points")

    origin = np.array([0.0, 0.0])

    print("\n--- K-Nearest Neighbors ---")
    for k in (1, 5, 10, 20):
        # Row 0 is our single query; the last column is the farthest neighbor
        max_dist = kd.query(origin, k=k).distances[0, -1]
        print(f"  k={k:>2}: max distance = {max_dist:.4f}")

    print("\n--- Radius Queries ---")
    for r in (1.0, 2.0, 5.0, 10.0):
        # query_radius yields one index array per query point
        result = kd.query_radius(origin, r)
        print(f"  radius={r:.1f}: {len(result[0])} points found")
200
201
def demo_balltree():
    """Run a 5-nearest-neighbor query against a 10-dimensional Ball Tree.

    Builds the tree from 500 standard-normal points (seed 42) and queries the
    origin, printing each neighbor's index and distance.
    """
    divider = "=" * 70
    print("\n" + divider)
    print("Ball Tree Demo")
    print(divider)

    np.random.seed(42)

    # Higher dimensional data
    n_points, n_dims = 500, 10
    points = np.random.randn(n_points, n_dims)

    print(f"\nBuilding Ball Tree with {n_points} points in {n_dims}D space")

    ball = BallTree(points)

    k = 5
    result = ball.query(np.zeros(n_dims), k=k)

    print(f"\nQuery: origin in {n_dims}D")
    print(f"{k} nearest neighbors:")
    # Row 0 of indices/distances holds the neighbors of our single query
    neighbors = zip(result.indices[0], result.distances[0])
    for rank, (idx, dist) in enumerate(neighbors, start=1):
        print(f"  {rank}. Point {idx}: distance = {dist:.4f}")

    print("\nNote: Ball Tree is often more efficient than K-D Tree")
    print("for higher dimensional data (curse of dimensionality).")
233
234
def demo_rtree():
    """Index random 2D rectangles in an R-Tree and run an intersection query.

    Creates 50 axis-aligned boxes (seed 42). Box centers are uniform in
    [-50, 50]^2 and side lengths are uniform in [2, 10]. Each box is inserted
    with its list index as the payload. The tree is then queried for boxes
    intersecting the square [-10, 10]^2. When ``SHOW_PLOTS`` is true, writes
    ``OUTPUT_DIR / "spatial_rtree.html"`` with the hits highlighted.
    """
    print("\n" + "=" * 70)
    print("R-Tree Demo")
    print("=" * 70)

    np.random.seed(42)

    # Create bounding boxes (e.g., for spatial objects). For each box the
    # center is drawn before the size, preserving the seeded random sequence.
    n_boxes = 50
    boxes = []
    for _ in range(n_boxes):
        center = np.random.uniform(-50, 50, 2)
        size = np.random.uniform(2, 10, 2)
        boxes.append(
            BoundingBox(min_coords=center - size / 2, max_coords=center + size / 2)
        )

    print(f"\nCreated {n_boxes} bounding boxes")

    # Build R-Tree; the payload stored with each box is its index in `boxes`
    tree = RTree()
    for i, box in enumerate(boxes):
        tree.insert(box, i)

    print("R-Tree built successfully")

    # Query: find boxes intersecting a search region
    search_min = np.array([-10, -10])
    search_max = np.array([10, 10])
    search_box = BoundingBox(min_coords=search_min, max_coords=search_max)

    print(f"\nSearch region: ({search_min} to {search_max})")

    result = tree.query_intersect(search_box)
    n_hits = len(result.indices)

    print(f"Found {n_hits} intersecting boxes")

    # Show some results
    if n_hits > 0:
        print("\nFirst 5 intersecting boxes:")
        for idx in result.indices[:5]:
            hit = boxes[idx]
            print(f"  Box {idx}: ({hit.min_coords} to {hit.max_coords})")

    # Plot R-Tree result
    if SHOW_PLOTS:
        # Build the hit set once so each membership test below is O(1)
        # instead of a linear scan of result.indices for every box.
        hits = {int(idx) for idx in result.indices}

        fig = go.Figure()

        # Draw all boxes
        for i, box in enumerate(boxes):
            is_intersecting = i in hits
            color = "green" if is_intersecting else "lightblue"
            opacity = 0.6 if is_intersecting else 0.3

            # Create rectangle as a filled shape
            x0, y0 = box.min_coords
            x1, y1 = box.max_coords

            fig.add_shape(
                type="rect",
                x0=x0,
                y0=y0,
                x1=x1,
                y1=y1,
                fillcolor=color,
                line=dict(color="black", width=1),
                opacity=opacity,
            )

        # Draw search region
        fig.add_shape(
            type="rect",
            x0=search_min[0],
            y0=search_min[1],
            x1=search_max[0],
            y1=search_max[1],
            fillcolor="rgba(0,0,0,0)",
            line=dict(color="red", width=3, dash="dash"),
        )

        # Shapes do not appear in the legend, so add invisible placeholder
        # traces that carry the legend entries.
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                marker=dict(size=15, color="green", opacity=0.6),
                name="Intersecting",
            )
        )
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                marker=dict(size=15, color="lightblue", opacity=0.3),
                name="Non-intersecting",
            )
        )
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="lines",
                line=dict(color="red", width=3, dash="dash"),
                name="Search region",
            )
        )

        fig.update_layout(
            title=f"R-Tree: {n_hits} boxes intersecting search region",
            xaxis_title="x",
            yaxis_title="y",
            height=700,
            width=700,
            showlegend=True,
            xaxis=dict(range=[-60, 60], scaleanchor="y", scaleratio=1),
            yaxis=dict(range=[-60, 60]),
        )
        fig.write_html(str(OUTPUT_DIR / "spatial_rtree.html"))
        print("\n  [Plot saved to spatial_rtree.html]")
362
363
def demo_bounding_box_operations():
    """Exercise the BoundingBox helpers on a few fixed 2D boxes.

    Prints box properties (center, dimensions, volume) and pairwise
    intersection results. It then prints point-containment checks, the merge
    of two boxes, and the bounding box of a small point set.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("Bounding Box Operations Demo")
    print(rule)

    def make_box(lo, hi):
        # Local shorthand: corner lists -> BoundingBox with array corners
        return BoundingBox(min_coords=np.array(lo), max_coords=np.array(hi))

    box1 = make_box([0, 0], [5, 5])
    box2 = make_box([3, 3], [8, 8])
    box3 = make_box([10, 10], [12, 12])

    print("\nBox 1: (0,0) to (5,5)")
    print(f"  Center: {box1.center}")
    print(f"  Dimensions: {box1.dimensions}")
    print(f"  Volume: {box1.volume}")

    print("\nBox 2: (3,3) to (8,8)")
    print(f"  Center: {box2.center}")

    print("\nBox 3: (10,10) to (12,12)")
    print(f"  Center: {box3.center}")

    print("\n--- Intersection Tests ---")
    pairs = (
        ("Box1", box1, "Box2", box2),
        ("Box1", box1, "Box3", box3),
        ("Box2", box2, "Box3", box3),
    )
    for name_a, first, name_b, second in pairs:
        print(f"  {name_a} intersects {name_b}: {first.intersects(second)}")

    print("\n--- Point Containment Tests ---")
    for p in (np.array([2.5, 2.5]), np.array([4.0, 4.0]), np.array([7.0, 7.0])):
        print(f"  Point {p}:")
        print(f"    In Box1: {box1.contains_point(p)}")
        print(f"    In Box2: {box2.contains_point(p)}")

    # merge_boxes accepts a list of boxes
    merged = merge_boxes([box1, box2])
    print("\n--- Merged Box (Box1 + Box2) ---")
    print(f"  Min: {merged.min_coords}")
    print(f"  Max: {merged.max_coords}")

    points = np.array([[1, 2], [5, 3], [2, 8], [7, 4]])
    bbox = box_from_points(points)
    print("\n--- Bounding Box of Points ---")
    print(f"  Points:\n{points}")
    print(f"  Bounding box: ({bbox.min_coords} to {bbox.max_coords})")
419
420
def demo_vptree():
    """Query a VP-Tree over 200 random 3D points for 5 nearest neighbors.

    Each coordinate is drawn from N(0, 5^2) (seed 42). The query point is
    (1, 1, 1), and each neighbor's index and distance are printed.
    """
    sep = "=" * 70
    print("\n" + sep)
    print("VP-Tree Demo")
    print(sep)

    np.random.seed(42)

    n_points = 200
    points = np.random.randn(n_points, 3) * 5

    print(f"\nBuilding VP-Tree with {n_points} 3D points")

    vp = VPTree(points)

    query, k = np.array([1.0, 1.0, 1.0]), 5
    result = vp.query(query, k=k)

    print(f"\nQuery point: {query}")
    print(f"{k} nearest neighbors:")
    # Row 0 of indices/distances corresponds to our single query
    ranked = enumerate(zip(result.indices[0], result.distances[0]), start=1)
    for rank, (idx, dist) in ranked:
        print(f"  {rank}. Point {idx}: distance = {dist:.4f}")

    print("\nNote: VP-Tree works with any distance metric,")
    print("not just Euclidean distance.")
450
451
def demo_covertree():
    """Query a Cover Tree over 300 random 4D points for 5 nearest neighbors.

    Each coordinate is drawn from N(0, 3^2) (seed 42). The query is the
    origin, and each neighbor's index and distance are printed.

    NOTE(review): this demo was previously described as *approximate*
    nearest-neighbor search. Whether pytcl's CoverTree returns exact or
    approximate neighbors is not visible here; confirm against its docs.
    """
    print("\n" + "=" * 70)
    print("Cover Tree Demo")
    print("=" * 70)

    np.random.seed(42)

    # Generate points; the dimensionality is named once and reused below
    # instead of being hard-coded in several places.
    n_points = 300
    n_dims = 4
    points = np.random.randn(n_points, n_dims) * 3

    print(f"\nBuilding Cover Tree with {n_points} {n_dims}D points")

    tree = CoverTree(points)

    # Query
    query = np.zeros(n_dims)
    k = 5

    result = tree.query(query, k=k)

    print(f"\nQuery: origin in {n_dims}D")
    print(f"{k} nearest neighbors:")
    # Row 0 of indices/distances corresponds to our single query
    for i, (idx, dist) in enumerate(zip(result.indices[0], result.distances[0])):
        print(f"  {i+1}. Point {idx}: distance = {dist:.4f}")

    print("\nNote: Cover Tree provides O(c^12 log n) query complexity")
    print("where c is the expansion constant of the data.")
481
482
def demo_performance_comparison():
    """Compare build and k-NN query times of the point-based spatial trees.

    Each structure (K-D, Ball, VP, Cover) is built on the same seeded
    5000-point 3D dataset. Each is then queried with the same 100 query
    points, one query at a time, for k=10 neighbors.

    Timings use ``time.perf_counter`` and are printed in milliseconds. This
    clock is monotonic and high-resolution, which makes it suitable for
    measuring short intervals; ``time.time`` is neither. Absolute numbers are
    machine-dependent.
    """
    import time

    print("\n" + "=" * 70)
    print("Performance Comparison Demo")
    print("=" * 70)

    np.random.seed(42)

    # Test data
    n_points = 5000
    n_queries = 100
    dims = 3
    k = 10

    points = np.random.randn(n_points, dims) * 10
    queries = np.random.randn(n_queries, dims) * 10

    print(f"\nDataset: {n_points} points in {dims}D")
    print(f"Queries: {n_queries} k-NN queries (k={k})")

    # (display name, tree class); every class is built from the point array
    # and exposes query(point, k=...).
    structures = [
        ("K-D Tree", KDTree),
        ("Ball Tree", BallTree),
        ("VP-Tree", VPTree),
        ("Cover Tree", CoverTree),
    ]

    results = {}
    for name, tree_cls in structures:
        t0 = time.perf_counter()
        tree = tree_cls(points)
        build_time = time.perf_counter() - t0

        t0 = time.perf_counter()
        for q in queries:
            tree.query(q, k=k)
        query_time = time.perf_counter() - t0

        results[name] = (build_time, query_time)

    # Print results
    print("\n" + "-" * 50)
    print(f"{'Structure':<15} {'Build (ms)':>12} {'Query (ms)':>12}")
    print("-" * 50)
    for name, (build, query) in results.items():
        print(f"{name:<15} {build*1000:>12.2f} {query*1000:>12.2f}")

    print("\nNote: Performance depends on data distribution and dimensionality.")
563
564
def demo_tracking_application():
    """Gate simulated sensor measurements against predicted track positions.

    Places 20 track predictions uniformly in [-100, 100]^2 (seed 42). It then
    builds 30 measurements: one noisy detection per track, with N(0, 2^2)
    noise per axis, followed by uniformly placed false alarms. A K-D tree over
    the track predictions finds each measurement's nearest track, and a
    measurement is kept only when that distance is below the gate. The demo
    finishes with a fixed-radius query using the same gate.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("Tracking Application Demo")
    print(rule)

    np.random.seed(42)

    # Predicted track positions
    n_tracks = 20
    track_positions = np.random.uniform(-100, 100, (n_tracks, 2))

    n_measurements = 30
    measurements = np.zeros((n_measurements, 2))
    # True detections: the leading rows, each near its own track
    for row in range(min(n_tracks, n_measurements)):
        measurements[row] = track_positions[row] + np.random.randn(2) * 2.0
    # Clutter: the remaining rows are uniformly placed false alarms
    for row in range(n_tracks, n_measurements):
        measurements[row] = np.random.uniform(-100, 100, 2)

    print(f"\n{n_tracks} track predictions")
    print(
        f"{n_measurements} measurements ({n_tracks} true + "
        f"{n_measurements - n_tracks} false alarms)"
    )

    # Spatial index over the predictions, queried once per measurement
    tree = KDTree(track_positions)

    print("\nMeasurement-to-track association using K-D tree:")
    print("-" * 50)

    gating_threshold = 5.0  # meters
    associations = []
    for m_idx, meas in enumerate(measurements):
        nearest = tree.query(meas, k=1)
        # 2D arrays indexed [query_idx, neighbor_idx]; one query, one neighbor
        track_idx = nearest.indices[0, 0]
        distance = nearest.distances[0, 0]
        if distance < gating_threshold:
            associations.append((m_idx, track_idx, distance))

    print(f"Gating threshold: {gating_threshold} m")
    print(f"Measurements passing gate: {len(associations)}/{n_measurements}")

    print("\nFirst 5 associations:")
    for m_idx, t_idx, dist in associations[:5]:
        # Simplified ground truth: measurement i was generated from track i
        status = "+" if m_idx == t_idx else "?"
        print(f"  Meas {m_idx:>2} -> Track {t_idx:>2} (dist={dist:.2f}) {status}")

    print("\n--- Using Radius Query for Gating ---")
    # query_radius returns one index array per query point
    in_gate = tree.query_radius(measurements[0], gating_threshold)
    print(f"Measurement 0: {len(in_gate[0])} tracks within gate")
631
632
def main():
    """Run every spatial-data-structure demo in order, then print a summary."""
    print("\n" + "#" * 70)
    print("# PyTCL Spatial Data Structures Example")
    print("#" * 70)

    demos = (
        demo_kdtree_basics,
        demo_kdtree_queries,
        demo_balltree,
        demo_rtree,
        demo_bounding_box_operations,
        demo_vptree,
        demo_covertree,
        demo_performance_comparison,
        demo_tracking_application,
    )
    for demo in demos:
        demo()

    print("\n" + "=" * 70)
    print("Example complete!")
    if SHOW_PLOTS:
        print("Plots saved: spatial_kdtree.html, spatial_rtree.html")
    print("=" * 70)
654
655
# Run all demos only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

Running the Example

python examples/spatial_data_structures.py

See Also