Spatial Data Structures
This example demonstrates spatial data structures for efficient queries.
Overview
Spatial data structures enable fast nearest-neighbor and range queries:
KD-Tree: k-dimensional binary search tree
R-Tree: Rectangle tree for bounding box queries
Ball Tree: Metric tree for arbitrary metrics
Cover Tree: Efficient for data with low intrinsic dimensionality
Key Concepts
Nearest neighbor queries: Find k closest points
Range queries: Find all points within radius
Bulk loading: Efficient tree construction
Metric spaces: Distance-based operations
Data Structures
- KD-Tree
Best for low-dimensional data (d < 20)
O(log n) average query time
Standard Euclidean distance
- R-Tree
Designed for spatial indexing
Handles bounding boxes well
Good for GIS applications
- Ball Tree
Works in any metric space
Better for high dimensions
Supports custom distance functions
Code Highlights
The example demonstrates:
Building a KD-tree with KDTree()
Nearest neighbor queries with query()
Range queries with query_radius()
Bulk operations for efficiency
Source Code
1"""
2Spatial Data Structures Example
3===============================
4
5This example demonstrates spatial data structures in PyTCL for
6efficient nearest neighbor queries and spatial indexing:
7
8K-D Tree:
9- Construction and querying
10- K-nearest neighbor search
11- Radius/range queries
12
13Ball Tree:
14- Alternative to K-D tree for high dimensions
15- Similar query interface
16
17R-Tree:
18- Spatial indexing for bounding boxes
19- Rectangle intersection queries
20
21VP-Tree (Vantage Point Tree):
22- Metric space indexing
23- Works with any distance metric
24
25Cover Tree:
26- Approximate nearest neighbor search
27- O(c^12 log n) query complexity
28
29These data structures are essential for efficient data association
30in multi-target tracking and spatial analysis applications.
31"""
32
33from pathlib import Path
34
35import numpy as np
36import plotly.graph_objects as go
37
# Output directory for generated plots: <parent of this file's directory>/docs/_static/images/examples.
# The directory (and any missing parents) is created as a side effect of importing this module.
OUTPUT_DIR = Path(__file__).parent.parent / "docs" / "_static" / "images" / "examples"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Global flag to control plotting: when True, the demos write HTML figures into OUTPUT_DIR.
SHOW_PLOTS = True
44
45
46from pytcl.containers import ( # K-D Tree; Ball Tree; R-Tree; VP-Tree; Cover Tree
47 BallTree,
48 BoundingBox,
49 CoverTree,
50 KDTree,
51 NearestNeighborResult,
52 RTree,
53 VPTree,
54 box_from_point,
55 box_from_points,
56 merge_boxes,
57)
58
59
def _plot_knn_query(points, query, neighbor_indices, search_radius, k):
    """Write the k-NN query figure to ``OUTPUT_DIR / "spatial_kdtree.html"``.

    Parameters
    ----------
    points : array of 2D points that were indexed.
    query : the 2D query point.
    neighbor_indices : indices into ``points`` of the neighbors found.
    search_radius : distance from ``query`` to the farthest returned neighbor.
    k : number of neighbors requested (used only in the legend label).
    """
    fig = go.Figure()

    # All points
    fig.add_trace(
        go.Scatter(
            x=points[:, 0],
            y=points[:, 1],
            mode="markers",
            marker=dict(color="lightblue", size=8, opacity=0.6),
            name="Points",
        )
    )

    # K nearest neighbors
    fig.add_trace(
        go.Scatter(
            x=points[neighbor_indices, 0],
            y=points[neighbor_indices, 1],
            mode="markers",
            marker=dict(color="green", size=12, opacity=0.8),
            name=f"{k} nearest neighbors",
        )
    )

    # Query point
    fig.add_trace(
        go.Scatter(
            x=[query[0]],
            y=[query[1]],
            mode="markers",
            marker=dict(color="red", size=15, symbol="star"),
            name="Query",
        )
    )

    # Circle through the farthest returned neighbor, centered on the query
    theta = np.linspace(0, 2 * np.pi, 100)
    fig.add_trace(
        go.Scatter(
            x=query[0] + search_radius * np.cos(theta),
            y=query[1] + search_radius * np.sin(theta),
            mode="lines",
            line=dict(color="green", dash="dash", width=2),
            name="Search radius",
            showlegend=True,
        )
    )

    fig.update_layout(
        title="K-D Tree: K-Nearest Neighbor Query",
        xaxis_title="x",
        yaxis_title="y",
        height=600,
        width=600,
        showlegend=True,
        # Lock the aspect ratio so the search circle is drawn round
        xaxis=dict(scaleanchor="y", scaleratio=1),
    )
    fig.write_html(str(OUTPUT_DIR / "spatial_kdtree.html"))


def demo_kdtree_basics():
    """Demonstrate K-D tree construction and a basic k-nearest-neighbor query.

    Builds a ``KDTree`` over 100 seeded random 2D points, queries the 5
    nearest neighbors of the origin, prints them and, when ``SHOW_PLOTS`` is
    true, saves a figure of the result.
    """
    print("=" * 70)
    print("K-D Tree Basics Demo")
    print("=" * 70)

    np.random.seed(42)

    # Generate 2D point cloud
    n_points = 100
    points = np.random.randn(n_points, 2) * 10

    print(f"\nBuilding K-D tree with {n_points} 2D points")

    # Build tree
    tree = KDTree(points)

    print("Tree built successfully")
    print("Point cloud bounds:")
    print(f" X: [{points[:, 0].min():.2f}, {points[:, 0].max():.2f}]")
    print(f" Y: [{points[:, 1].min():.2f}, {points[:, 1].max():.2f}]")

    # Query point
    query = np.array([0.0, 0.0])
    print(f"\nQuery point: {query}")

    # K-nearest neighbors
    k = 5
    result = tree.query(query, k=k)

    # indices/distances are 2D arrays; row 0 belongs to the single query point
    neighbor_indices = result.indices[0]
    neighbor_distances = result.distances[0]

    print(f"\n{k} nearest neighbors:")
    for rank, (idx, dist) in enumerate(zip(neighbor_indices, neighbor_distances), start=1):
        print(f" {rank}. Point {idx}: {points[idx]} (distance={dist:.4f})")

    # Plot KDTree result
    if SHOW_PLOTS:
        # The last column of the distance row is used as the search-circle radius
        _plot_knn_query(points, query, neighbor_indices, result.distances[0, -1], k)
        print("\n [Plot saved to spatial_kdtree.html]")
160
161
def demo_kdtree_queries():
    """Show k-NN and radius queries on a K-D tree built from a jittered grid."""
    banner = "=" * 70
    print("\n" + banner)
    print("K-D Tree Query Types Demo")
    print(banner)

    np.random.seed(42)

    # 11 x 11 lattice spanning [-10, 10] on both axes, then Gaussian jitter
    grid_x, grid_y = np.meshgrid(np.linspace(-10, 10, 11), np.linspace(-10, 10, 11))
    points = np.column_stack([grid_x.ravel(), grid_y.ravel()])
    points += np.random.randn(*points.shape) * 0.3

    tree = KDTree(points)

    print(f"\nGrid-based point cloud: {len(points)} points")

    # Every query below is issued from the origin
    origin = np.array([0.0, 0.0])

    print("\n--- K-Nearest Neighbors ---")
    for k in (1, 5, 10, 20):
        knn = tree.query(origin, k=k)
        # distances is indexed [query_idx, neighbor_idx]; take the last neighbor
        farthest = knn.distances[0, -1]
        print(f" k={k:>2}: max distance = {farthest:.4f}")

    print("\n--- Radius Queries ---")
    for radius in (1.0, 2.0, 5.0, 10.0):
        # query_radius yields one index array per query point
        hits_per_query = tree.query_radius(origin, radius)
        print(f" radius={radius:.1f}: {len(hits_per_query[0])} points found")
200
201
def demo_balltree():
    """Build a Ball Tree over 10-dimensional points and query the origin."""
    banner = "=" * 70
    print("\n" + banner)
    print("Ball Tree Demo")
    print(banner)

    np.random.seed(42)

    # 500 standard-normal samples in a 10-dimensional space
    n_points, n_dims = 500, 10
    points = np.random.randn(n_points, n_dims)

    print(f"\nBuilding Ball Tree with {n_points} points in {n_dims}D space")

    tree = BallTree(points)

    # Look up the 5 nearest neighbors of the origin
    k = 5
    result = tree.query(np.zeros(n_dims), k=k)

    print(f"\nQuery: origin in {n_dims}D")
    print(f"{k} nearest neighbors:")
    rank = 0
    for idx, dist in zip(result.indices[0], result.distances[0]):
        rank += 1
        print(f" {rank}. Point {idx}: distance = {dist:.4f}")

    print("\nNote: Ball Tree is often more efficient than K-D Tree")
    print("for higher dimensional data (curse of dimensionality).")
233
234
def demo_rtree():
    """Demonstrate R-Tree indexing of bounding boxes and an intersection query.

    Creates 50 seeded random boxes, inserts them into an ``RTree``, queries
    the boxes intersecting the region (-10, -10) to (10, 10), prints the
    result and, when ``SHOW_PLOTS`` is true, writes ``spatial_rtree.html``
    into ``OUTPUT_DIR``.
    """
    print("\n" + "=" * 70)
    print("R-Tree Demo")
    print("=" * 70)

    np.random.seed(42)

    # Create bounding boxes (e.g., for spatial objects)
    n_boxes = 50
    boxes = []

    for _ in range(n_boxes):
        # Random center and size; center is drawn before size on every
        # iteration so the seeded sequence stays the same.
        center = np.random.uniform(-50, 50, 2)
        size = np.random.uniform(2, 10, 2)

        boxes.append(
            BoundingBox(min_coords=center - size / 2, max_coords=center + size / 2)
        )

    print(f"\nCreated {n_boxes} bounding boxes")

    # Build R-Tree, using each box's list position as its identifier
    tree = RTree()
    for i, box in enumerate(boxes):
        tree.insert(box, i)

    print("R-Tree built successfully")

    # Query: find boxes intersecting a search region
    search_min = np.array([-10, -10])
    search_max = np.array([10, 10])
    search_box = BoundingBox(min_coords=search_min, max_coords=search_max)

    print(f"\nSearch region: ({search_min} to {search_max})")

    result = tree.query_intersect(search_box)
    n_hits = len(result.indices)

    print(f"Found {n_hits} intersecting boxes")

    # Show some results
    if n_hits > 0:
        print("\nFirst 5 intersecting boxes:")
        for idx in result.indices[:5]:
            box = boxes[idx]
            print(f" Box {idx}: ({box.min_coords} to {box.max_coords})")

    # Plot R-Tree result
    if SHOW_PLOTS:
        fig = go.Figure()

        # Build the membership set once so each box test below is O(1)
        # instead of a linear scan of result.indices.
        hit_indices = {int(idx) for idx in result.indices}

        # Draw all boxes, highlighting those returned by the query
        for i, box in enumerate(boxes):
            is_intersecting = i in hit_indices
            color = "green" if is_intersecting else "lightblue"
            opacity = 0.6 if is_intersecting else 0.3

            # Create rectangle as a filled shape
            x0, y0 = box.min_coords
            x1, y1 = box.max_coords

            fig.add_shape(
                type="rect",
                x0=x0,
                y0=y0,
                x1=x1,
                y1=y1,
                fillcolor=color,
                line=dict(color="black", width=1),
                opacity=opacity,
            )

        # Draw search region
        fig.add_shape(
            type="rect",
            x0=search_min[0],
            y0=search_min[1],
            x1=search_max[0],
            y1=search_max[1],
            fillcolor="rgba(0,0,0,0)",
            line=dict(color="red", width=3, dash="dash"),
        )

        # Shapes do not appear in the legend, so add data-less traces that
        # only contribute legend entries.
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                marker=dict(size=15, color="green", opacity=0.6),
                name="Intersecting",
            )
        )
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers",
                marker=dict(size=15, color="lightblue", opacity=0.3),
                name="Non-intersecting",
            )
        )
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="lines",
                line=dict(color="red", width=3, dash="dash"),
                name="Search region",
            )
        )

        fig.update_layout(
            title=f"R-Tree: {n_hits} boxes intersecting search region",
            xaxis_title="x",
            yaxis_title="y",
            height=700,
            width=700,
            showlegend=True,
            xaxis=dict(range=[-60, 60], scaleanchor="y", scaleratio=1),
            yaxis=dict(range=[-60, 60]),
        )
        fig.write_html(str(OUTPUT_DIR / "spatial_rtree.html"))
        print("\n [Plot saved to spatial_rtree.html]")
362
363
def demo_bounding_box_operations():
    """Exercise the BoundingBox helpers: properties, tests, merging, fitting."""
    banner = "=" * 70
    print("\n" + banner)
    print("Bounding Box Operations Demo")
    print(banner)

    # Two overlapping boxes plus one that is far from both
    first = BoundingBox(min_coords=np.array([0, 0]), max_coords=np.array([5, 5]))
    second = BoundingBox(min_coords=np.array([3, 3]), max_coords=np.array([8, 8]))
    third = BoundingBox(min_coords=np.array([10, 10]), max_coords=np.array([12, 12]))

    print("\nBox 1: (0,0) to (5,5)")
    print(f" Center: {first.center}")
    print(f" Dimensions: {first.dimensions}")
    print(f" Volume: {first.volume}")

    print("\nBox 2: (3,3) to (8,8)")
    print(f" Center: {second.center}")

    print("\nBox 3: (10,10) to (12,12)")
    print(f" Center: {third.center}")

    # Pairwise intersection checks
    print("\n--- Intersection Tests ---")
    print(f" Box1 intersects Box2: {first.intersects(second)}")
    print(f" Box1 intersects Box3: {first.intersects(third)}")
    print(f" Box2 intersects Box3: {second.intersects(third)}")

    # Containment of three sample points in the first two boxes
    print("\n--- Point Containment Tests ---")
    for probe in (np.array([2.5, 2.5]), np.array([4.0, 4.0]), np.array([7.0, 7.0])):
        print(f" Point {probe}:")
        print(f" In Box1: {first.contains_point(probe)}")
        print(f" In Box2: {second.contains_point(probe)}")

    # merge_boxes takes a list of boxes
    union_box = merge_boxes([first, second])
    print("\n--- Merged Box (Box1 + Box2) ---")
    print(f" Min: {union_box.min_coords}")
    print(f" Max: {union_box.max_coords}")

    # Box fitted around a small point set
    points = np.array([[1, 2], [5, 3], [2, 8], [7, 4]])
    fitted = box_from_points(points)
    print("\n--- Bounding Box of Points ---")
    print(f" Points:\n{points}")
    print(f" Bounding box: ({fitted.min_coords} to {fitted.max_coords})")
419
420
def demo_vptree():
    """Build a VP-Tree over random 3D points and run one k-NN query."""
    banner = "=" * 70
    print("\n" + banner)
    print("VP-Tree Demo")
    print(banner)

    np.random.seed(42)

    # 200 seeded Gaussian points in 3D, scaled by 5
    n_points = 200
    points = np.random.randn(n_points, 3) * 5

    print(f"\nBuilding VP-Tree with {n_points} 3D points")

    tree = VPTree(points)

    # Five nearest neighbors of (1, 1, 1)
    k = 5
    query = np.array([1.0, 1.0, 1.0])
    result = tree.query(query, k=k)

    print(f"\nQuery point: {query}")
    print(f"{k} nearest neighbors:")
    neighbors = zip(result.indices[0], result.distances[0])
    for rank, (idx, dist) in enumerate(neighbors, start=1):
        print(f" {rank}. Point {idx}: distance = {dist:.4f}")

    print("\nNote: VP-Tree works with any distance metric,")
    print("not just Euclidean distance.")
450
451
def demo_covertree():
    """Demonstrate Cover Tree for approximate nearest neighbor.

    Builds a ``CoverTree`` over 300 seeded random 4D points and prints the
    5 nearest neighbors of the origin.
    """
    print("\n" + "=" * 70)
    print("Cover Tree Demo")
    print("=" * 70)

    np.random.seed(42)

    # Generate points; the dimensionality is named once so the data, the
    # query and the printed messages cannot drift apart.
    n_points = 300
    n_dims = 4
    points = np.random.randn(n_points, n_dims) * 3

    print(f"\nBuilding Cover Tree with {n_points} {n_dims}D points")

    tree = CoverTree(points)

    # Query
    query = np.zeros(n_dims)
    k = 5

    result = tree.query(query, k=k)

    print(f"\nQuery: origin in {n_dims}D")
    print(f"{k} nearest neighbors:")
    # indices/distances are indexed per query point; row 0 is the only query
    for i, (idx, dist) in enumerate(zip(result.indices[0], result.distances[0])):
        print(f" {i+1}. Point {idx}: distance = {dist:.4f}")

    print("\nNote: Cover Tree provides O(c^12 log n) query complexity")
    print("where c is the expansion constant of the data.")
481
482
def demo_performance_comparison():
    """Compare build and query times of the point-based spatial structures.

    Times construction over 5000 seeded random 3D points and 100 sequential
    k-NN queries (k=10) for ``KDTree``, ``BallTree``, ``VPTree`` and
    ``CoverTree``, then prints a table of the timings in milliseconds.
    """
    print("\n" + "=" * 70)
    print("Performance Comparison Demo")
    print("=" * 70)

    import time

    np.random.seed(42)

    # Test data
    n_points = 5000
    n_queries = 100
    dims = 3
    k = 10

    points = np.random.randn(n_points, dims) * 10
    queries = np.random.randn(n_queries, dims) * 10

    print(f"\nDataset: {n_points} points in {dims}D")
    print(f"Queries: {n_queries} k-NN queries (k={k})")

    # Structures under test, in the order they appear in the output table
    structures = (
        ("K-D Tree", KDTree),
        ("Ball Tree", BallTree),
        ("VP-Tree", VPTree),
        ("Cover Tree", CoverTree),
    )

    results = {}
    for name, tree_cls in structures:
        # perf_counter is a monotonic, high-resolution clock intended for
        # measuring short durations; time.time() can jump if the system
        # clock is adjusted and may have coarse resolution.
        start = time.perf_counter()
        tree = tree_cls(points)
        build_time = time.perf_counter() - start

        start = time.perf_counter()
        for q in queries:
            tree.query(q, k=k)
        query_time = time.perf_counter() - start

        results[name] = (build_time, query_time)

    # Print results
    print("\n" + "-" * 50)
    print(f"{'Structure':<15} {'Build (ms)':>12} {'Query (ms)':>12}")
    print("-" * 50)
    for name, (build, query) in results.items():
        print(f"{name:<15} {build*1000:>12.2f} {query*1000:>12.2f}")

    print("\nNote: Performance depends on data distribution and dimensionality.")
563
564
def demo_tracking_application():
    """Use a K-D tree to gate measurements against predicted track positions."""
    banner = "=" * 70
    print("\n" + banner)
    print("Tracking Application Demo")
    print(banner)

    np.random.seed(42)

    # Scenario: a sensor reports measurements that must be associated with
    # predicted track positions.
    n_tracks = 20
    n_measurements = 30
    track_positions = np.random.uniform(-100, 100, (n_tracks, 2))

    measurements = np.zeros((n_measurements, 2))
    # Leading measurements are noisy copies of the track positions ...
    for row in range(min(n_tracks, n_measurements)):
        measurements[row] = track_positions[row] + np.random.randn(2) * 2.0
    # ... and the rest are uniformly distributed false alarms.
    for row in range(n_tracks, n_measurements):
        measurements[row] = np.random.uniform(-100, 100, 2)

    print(f"\n{n_tracks} track predictions")
    print(
        f"{n_measurements} measurements ({n_tracks} true + "
        f"{n_measurements - n_tracks} false alarms)"
    )

    # Index the track predictions so each measurement can look up its nearest track
    tree = KDTree(track_positions)

    print("\nMeasurement-to-track association using K-D tree:")
    print("-" * 50)

    gating_threshold = 5.0  # meters
    associations = []

    for m_idx, meas in enumerate(measurements):
        nearest = tree.query(meas, k=1)
        # Results are indexed [query_idx, neighbor_idx]
        track_idx = nearest.indices[0, 0]
        gap = nearest.distances[0, 0]
        if gap < gating_threshold:
            associations.append((m_idx, track_idx, gap))

    print(f"Gating threshold: {gating_threshold} m")
    print(f"Measurements passing gate: {len(associations)}/{n_measurements}")

    print("\nFirst 5 associations:")
    for m_idx, t_idx, dist in associations[:5]:
        # Simplified ground truth: measurement i was generated from track i
        if m_idx == t_idx:
            status = "+"
        else:
            status = "?"
        print(f" Meas {m_idx:>2} -> Track {t_idx:>2} (dist={dist:.2f}) {status}")

    # Alternative gating: fetch every track inside the gate in one radius query
    print("\n--- Using Radius Query for Gating ---")
    in_gate = tree.query_radius(measurements[0], gating_threshold)
    # query_radius yields one index array per query point
    print(f"Measurement 0: {len(in_gate[0])} tracks within gate")
631
632
def main():
    """Run every demo in sequence and print a closing summary."""
    hash_rule = "#" * 70
    print("\n" + hash_rule)
    print("# PyTCL Spatial Data Structures Example")
    print(hash_rule)

    # Demos run in the order listed here
    demos = (
        demo_kdtree_basics,
        demo_kdtree_queries,
        demo_balltree,
        demo_rtree,
        demo_bounding_box_operations,
        demo_vptree,
        demo_covertree,
        demo_performance_comparison,
        demo_tracking_application,
    )
    for run_demo in demos:
        run_demo()

    equals_rule = "=" * 70
    print("\n" + equals_rule)
    print("Example complete!")
    if SHOW_PLOTS:
        print("Plots saved: spatial_kdtree.html, spatial_rtree.html")
    print(equals_rule)
654
655
# Run the demos only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
Running the Example
python examples/spatial_data_structures.py
See Also
Gaussian Mixtures and Clustering - Clustering with spatial trees
multi_target_tracking - Gating with spatial queries