Source code for pytcl.containers.kd_tree

"""
K-D Tree implementation.

A k-dimensional tree (k-d tree) is a space-partitioning data structure
for organizing points in k-dimensional space. Useful for efficient
nearest neighbor searches in tracking applications.

References
----------
.. [1] J. L. Bentley, "Multidimensional binary search trees used for
       associative searching," Communications of the ACM, 1975.
.. [2] J. H. Friedman, J. L. Bentley, R. A. Finkel, "An Algorithm for
       Finding Best Matches in Logarithmic Expected Time," ACM TOMS, 1977.
"""

import logging
from typing import List, Optional, Tuple

import numpy as np
from numpy.typing import ArrayLike, NDArray

from pytcl.containers.base import NearestNeighborResult  # Backward compatibility alias
from pytcl.containers.base import (
    BaseSpatialIndex,
    NeighborResult,
    validate_query_input,
)

# Module logger
_logger = logging.getLogger("pytcl.containers.kd_tree")


[docs] class KDNode: """A node in the k-d tree. Attributes ---------- point : ndarray The point stored at this node. index : int Original index of this point in the data array. split_dim : int The dimension used for splitting at this node. left : KDNode or None Left child (points with smaller split dimension value). right : KDNode or None Right child (points with larger split dimension value). """ __slots__ = ["point", "index", "split_dim", "left", "right"]
[docs] def __init__( self, point: NDArray[np.floating], index: int, split_dim: int, ): self.point = point self.index = index self.split_dim = split_dim self.left: Optional["KDNode"] = None self.right: Optional["KDNode"] = None
[docs] class KDTree(BaseSpatialIndex): """ K-D Tree for efficient spatial queries. A k-d tree recursively partitions space by splitting along different dimensions at each level. This enables O(log n) average-case nearest neighbor queries. Parameters ---------- data : array_like Data points of shape (n_samples, n_features). leaf_size : int, optional Maximum number of points in a leaf node. Default 10. Examples -------- >>> import numpy as np >>> points = np.array([[0, 0], [1, 0], [0, 1], [1, 1]]) >>> tree = KDTree(points) >>> result = tree.query(np.array([[0.1, 0.1]]), k=2) >>> result.indices array([[0, 1]]) Notes ----- The tree is balanced by choosing the median point along the split dimension at each level, giving O(n log n) construction time. Query complexity is O(log n) on average for nearest neighbor search, though worst case is O(n) for highly unbalanced queries. See Also -------- BaseSpatialIndex : Abstract base class defining the spatial index interface. BallTree : Alternative spatial index using hyperspheres. """
[docs] def __init__( self, data: ArrayLike, leaf_size: int = 10, ): super().__init__(data) self.leaf_size = leaf_size # Build the tree indices = np.arange(self.n_samples) self.root = self._build_tree(indices, depth=0) _logger.debug("KDTree built with leaf_size=%d", leaf_size)
def _build_tree( self, indices: NDArray[np.intp], depth: int, ) -> Optional[KDNode]: """Recursively build the k-d tree.""" if len(indices) == 0: return None # Choose split dimension (cycle through dimensions) split_dim = depth % self.n_features # Sort indices by split dimension and find median sorted_indices = indices[np.argsort(self.data[indices, split_dim])] median_idx = len(sorted_indices) // 2 # Create node node_index = sorted_indices[median_idx] node = KDNode( point=self.data[node_index], index=node_index, split_dim=split_dim, ) # Recursively build subtrees node.left = self._build_tree(sorted_indices[:median_idx], depth + 1) node.right = self._build_tree(sorted_indices[median_idx + 1 :], depth + 1) return node
[docs] def query( self, X: ArrayLike, k: int = 1, ) -> NeighborResult: """ Query the tree for k nearest neighbors. Parameters ---------- X : array_like Query points of shape (n_queries, n_features) or (n_features,). k : int, optional Number of nearest neighbors. Default 1. Returns ------- result : NeighborResult Indices and distances of k nearest neighbors for each query. Examples -------- >>> tree = KDTree(np.array([[0, 0], [1, 1], [2, 2]])) >>> result = tree.query([[0.5, 0.5]], k=2) >>> result.indices array([[0, 1]]) """ X = validate_query_input(X, self.n_features) n_queries = X.shape[0] _logger.debug("KDTree.query: %d queries, k=%d", n_queries, k) all_indices = np.zeros((n_queries, k), dtype=np.intp) all_distances = np.full((n_queries, k), np.inf) for i in range(n_queries): neighbors = self._query_single(X[i], k) n_found = len(neighbors) if n_found > 0: indices, distances = zip(*neighbors) all_indices[i, :n_found] = indices all_distances[i, :n_found] = distances return NeighborResult(indices=all_indices, distances=all_distances)
def _query_single( self, query: NDArray[np.floating], k: int, ) -> List[Tuple[int, float]]: """Find k nearest neighbors for a single query point.""" # List of (index, distance) tuples, maintained as a bounded heap neighbors: List[Tuple[int, float]] = [] def _search(node: Optional[KDNode], depth: int) -> None: if node is None: return # Distance to current node dist = np.sqrt(np.sum((query - node.point) ** 2)) # Add to neighbors if room or better than worst if len(neighbors) < k: neighbors.append((node.index, dist)) neighbors.sort(key=lambda x: x[1]) elif dist < neighbors[-1][1]: neighbors[-1] = (node.index, dist) neighbors.sort(key=lambda x: x[1]) # Decide which subtree to search first split_dim = node.split_dim diff = query[split_dim] - node.point[split_dim] if diff <= 0: first, second = node.left, node.right else: first, second = node.right, node.left # Search closer subtree first _search(first, depth + 1) # Check if we need to search the other subtree # Only if the splitting plane is closer than our worst neighbor if len(neighbors) < k or abs(diff) < neighbors[-1][1]: _search(second, depth + 1) _search(self.root, 0) return neighbors
[docs] def query_radius( self, X: ArrayLike, r: float, ) -> List[List[int]]: """ Query the tree for all points within radius r. Parameters ---------- X : array_like Query points of shape (n_queries, n_features) or (n_features,). r : float Query radius. Returns ------- indices : list of lists For each query, a list of indices of points within radius r. Examples -------- >>> tree = KDTree(np.array([[0, 0], [1, 0], [0, 1], [5, 5]])) >>> tree.query_radius([[0, 0]], r=1.5) [[0, 1, 2]] """ X = validate_query_input(X, self.n_features) n_queries = X.shape[0] results: List[List[int]] = [] for i in range(n_queries): indices = self._query_radius_single(X[i], r) results.append(indices) return results
def _query_radius_single( self, query: NDArray[np.floating], r: float, ) -> List[int]: """Find all points within radius r of query point.""" indices: List[int] = [] def _search(node: Optional[KDNode]) -> None: if node is None: return # Distance to current node dist = np.sqrt(np.sum((query - node.point) ** 2)) if dist <= r: indices.append(node.index) # Check if subtrees might contain points in radius split_dim = node.split_dim diff = query[split_dim] - node.point[split_dim] # Search appropriate subtrees if diff - r <= 0: _search(node.left) if diff + r >= 0: _search(node.right) _search(self.root) return indices
# query_ball_point inherited from BaseSpatialIndex
[docs] class BallTree(BaseSpatialIndex): """ Ball Tree for efficient spatial queries. A ball tree partitions space using hyperspheres (balls), which can be more efficient than k-d trees for high-dimensional data or non-Euclidean metrics. Parameters ---------- data : array_like Data points of shape (n_samples, n_features). leaf_size : int, optional Maximum number of points in a leaf node. Default 10. Examples -------- >>> import numpy as np >>> points = np.array([[0, 0], [1, 0], [0, 1], [1, 1]]) >>> tree = BallTree(points) >>> result = tree.query(np.array([[0.1, 0.1]]), k=2) Notes ----- Ball trees have O(n log n) construction and O(log n) average-case query time. They can outperform k-d trees in high dimensions. See Also -------- BaseSpatialIndex : Abstract base class defining the spatial index interface. KDTree : Alternative spatial index using axis-aligned splits. """
[docs] def __init__( self, data: ArrayLike, leaf_size: int = 10, ): super().__init__(data) self.leaf_size = leaf_size # Build tree using indices self._indices = np.arange(self.n_samples) self._centroids: List[NDArray[np.floating]] = [] self._radii: List[float] = [] self._left: List[int] = [] self._right: List[int] = [] self._is_leaf: List[bool] = [] self._leaf_indices: List[Optional[NDArray[np.intp]]] = [] self._build_tree(self._indices) _logger.debug("BallTree built with leaf_size=%d", leaf_size)
def _build_tree( self, indices: NDArray[np.intp], ) -> int: """Build the ball tree recursively.""" node_id = len(self._centroids) points = self.data[indices] centroid = np.mean(points, axis=0) radius = float(np.max(np.sqrt(np.sum((points - centroid) ** 2, axis=1)))) self._centroids.append(centroid) self._radii.append(radius) if len(indices) <= self.leaf_size: # Leaf node self._is_leaf.append(True) self._leaf_indices.append(indices.copy()) self._left.append(-1) self._right.append(-1) else: # Internal node - split along dimension with largest spread spread = np.max(points, axis=0) - np.min(points, axis=0) split_dim = np.argmax(spread) # Split at median split_values = self.data[indices, split_dim] median = np.median(split_values) left_mask = split_values <= median right_mask = ~left_mask # Ensure non-empty splits if not np.any(left_mask) or not np.any(right_mask): mid = len(indices) // 2 sorted_idx = indices[np.argsort(split_values)] left_indices = sorted_idx[:mid] right_indices = sorted_idx[mid:] else: left_indices = indices[left_mask] right_indices = indices[right_mask] self._is_leaf.append(False) self._leaf_indices.append(None) self._left.append(-1) # Placeholder self._right.append(-1) # Placeholder # Build subtrees left_id = self._build_tree(left_indices) right_id = self._build_tree(right_indices) self._left[node_id] = left_id self._right[node_id] = right_id return node_id
[docs] def query( self, X: ArrayLike, k: int = 1, ) -> NeighborResult: """ Query the tree for k nearest neighbors. Parameters ---------- X : array_like Query points of shape (n_queries, n_features) or (n_features,). k : int, optional Number of nearest neighbors. Default 1. Returns ------- result : NeighborResult Indices and distances of k nearest neighbors. """ X = validate_query_input(X, self.n_features) n_queries = X.shape[0] all_indices = np.zeros((n_queries, k), dtype=np.intp) all_distances = np.full((n_queries, k), np.inf) for i in range(n_queries): neighbors = self._query_single(X[i], k) n_found = len(neighbors) if n_found > 0: indices, distances = zip(*neighbors) all_indices[i, :n_found] = indices all_distances[i, :n_found] = distances return NeighborResult(indices=all_indices, distances=all_distances)
# query_ball_point inherited from BaseSpatialIndex
[docs] def query_radius( self, X: ArrayLike, r: float, ) -> List[List[int]]: """ Query the tree for all points within radius r. Parameters ---------- X : array_like Query points of shape (n_queries, n_features) or (n_features,). r : float Query radius. Returns ------- indices : list of lists For each query, a list of indices of points within radius r. """ X = validate_query_input(X, self.n_features) n_queries = X.shape[0] results: List[List[int]] = [] for i in range(n_queries): indices = self._query_radius_single(X[i], r) results.append(indices) return results
def _query_radius_single( self, query: NDArray[np.floating], r: float, ) -> List[int]: """Find all points within radius r of query point.""" indices: List[int] = [] def _search(node_id: int) -> None: if node_id < 0: return centroid = self._centroids[node_id] radius = self._radii[node_id] # Distance to ball surface dist_to_center = np.sqrt(np.sum((query - centroid) ** 2)) # Prune if ball is farther than radius if dist_to_center - radius > r: return if self._is_leaf[node_id]: # Check all points in leaf leaf_indices = self._leaf_indices[node_id] if leaf_indices is not None: for idx in leaf_indices: dist = np.sqrt(np.sum((query - self.data[idx]) ** 2)) if dist <= r: indices.append(idx) else: # Visit both children _search(self._left[node_id]) _search(self._right[node_id]) _search(0) return indices def _query_single( self, query: NDArray[np.floating], k: int, ) -> List[Tuple[int, float]]: """Find k nearest neighbors for a single point.""" neighbors: List[Tuple[int, float]] = [] def _search(node_id: int) -> None: if node_id < 0: return centroid = self._centroids[node_id] radius = self._radii[node_id] # Distance to ball surface dist_to_center = np.sqrt(np.sum((query - centroid) ** 2)) # Prune if ball is farther than current worst neighbor if len(neighbors) >= k and dist_to_center - radius >= neighbors[-1][1]: return if self._is_leaf[node_id]: # Check all points in leaf leaf_indices = self._leaf_indices[node_id] if leaf_indices is not None: for idx in leaf_indices: dist = np.sqrt(np.sum((query - self.data[idx]) ** 2)) if len(neighbors) < k: neighbors.append((idx, dist)) neighbors.sort(key=lambda x: x[1]) elif dist < neighbors[-1][1]: neighbors[-1] = (idx, dist) neighbors.sort(key=lambda x: x[1]) else: # Visit closer child first left_id = self._left[node_id] right_id = self._right[node_id] left_dist = ( np.sqrt(np.sum((query - self._centroids[left_id]) ** 2)) if left_id >= 0 else np.inf ) right_dist = ( np.sqrt(np.sum((query - self._centroids[right_id]) ** 2)) if right_id >= 0 else np.inf ) if left_dist <= right_dist: _search(left_id) _search(right_id) else: _search(right_id) _search(left_id) _search(0) return neighbors
__all__ = [ "KDNode", "NeighborResult", "NearestNeighborResult", # Backward compatibility alias "KDTree", "BallTree", ]