# Source code for pytcl.containers.vptree

"""
Vantage Point Tree (VP-tree) implementation.

VP-trees are metric trees that partition data based on distance to
selected vantage points. They are effective for nearest neighbor
search in metric spaces, particularly with high-dimensional data.

References
----------
.. [1] P. N. Yianilos, "Data structures and algorithms for nearest
       neighbor search in general metric spaces," SODA 1993.
"""

import logging
from typing import Any, Callable, List, Optional, Tuple

import numpy as np
from numpy.typing import ArrayLike, NDArray

from pytcl.containers.base import VPTreeResult  # Backward compatibility alias
from pytcl.containers.base import (
    MetricSpatialIndex,
    NeighborResult,
    validate_query_input,
)

# Module-level logger. The dotted name is spelled out explicitly so records
# are always attributed to ``pytcl.containers.vptree``.
_logger: logging.Logger = logging.getLogger("pytcl.containers.vptree")


class VPNode:
    """
    Single node of a VP-tree.

    Attributes
    ----------
    index : int
        Row of this node's vantage point in the data the tree was built from.
    radius : float
        Splitting threshold around the vantage point (0.0 by default, as
        used for single-point leaves).
    left : VPNode or None
        Subtree of points whose distance to the vantage point is at most
        ``radius``.
    right : VPNode or None
        Subtree of points whose distance to the vantage point is at least
        ``radius``.
    """

    __slots__ = ["index", "radius", "left", "right"]

    def __init__(self, index: int, radius: float = 0.0):
        # Children start out unlinked; the tree builder attaches them later.
        self.left: Optional["VPNode"] = None
        self.right: Optional["VPNode"] = None
        self.index = index
        self.radius = radius
class VPTree(MetricSpatialIndex):
    """
    Vantage Point Tree for metric space nearest neighbor search.

    A VP-tree recursively partitions space by selecting a vantage point
    and dividing the remaining points into those whose distance to it is
    at most a threshold (the median distance) and those at or beyond it.

    Parameters
    ----------
    data : array_like
        Data points of shape (n_samples, n_features).
    metric : callable, optional
        Distance function metric(x, y) -> float. Default is Euclidean
        distance.

    Examples
    --------
    >>> import numpy as np
    >>> points = np.random.rand(100, 3)
    >>> tree = VPTree(points)
    >>> result = tree.query(points[:5], k=3)
    >>> result.indices.shape
    (5, 3)

    Notes
    -----
    VP-trees can use any metric distance function, making them useful for
    non-Euclidean spaces (e.g., edit distance for strings, geodesic
    distance on manifolds).

    Points are split by distance rank rather than by a plain ``<= median``
    test. Each subtree therefore receives at most half of the remaining
    points even when many distances tie (e.g. duplicated data), so the
    tree depth is O(log n).

    Query complexity is O(log n) on average but can degrade to O(n)
    for pathological distance distributions.

    See Also
    --------
    MetricSpatialIndex : Abstract base class for metric-based spatial indices.
    CoverTree : Alternative metric space index with theoretical guarantees.
    """

    def __init__(
        self,
        data: ArrayLike,
        metric: Optional[
            Callable[[np.ndarray[Any, Any], np.ndarray[Any, Any]], float]
        ] = None,
    ):
        super().__init__(data, metric)

        # Build tree
        indices = np.arange(self.n_samples)
        self.root = self._build_tree(indices)

        # Not every callable has ``__name__`` (functools.partial objects,
        # instances defining __call__, ...). Fall back to repr() so that a
        # debug log line can never make construction fail.
        if metric is None:
            metric_name = "euclidean"
        else:
            metric_name = getattr(metric, "__name__", repr(metric))
        _logger.debug("VPTree built with metric=%s", metric_name)

    def _build_tree(self, indices: NDArray[np.intp]) -> Optional[VPNode]:
        """
        Recursively build the VP-tree over ``indices``.

        Parameters
        ----------
        indices : ndarray of intp
            Row indices into ``self.data`` to place in this subtree.

        Returns
        -------
        node : VPNode or None
            Root of the subtree, or None when ``indices`` is empty.
        """
        n = len(indices)
        if n == 0:
            return None
        if n == 1:
            return VPNode(int(indices[0]), 0.0)

        # Select vantage point (use first point for simplicity).
        # Better strategies: random selection, spread-based selection.
        vp_idx = int(indices[0])
        vp = self.data[vp_idx]

        # Distances from the vantage point to every other point.
        remaining = indices[1:]
        distances = np.array([self.metric(vp, self.data[i]) for i in remaining])

        # Splitting threshold: the median distance.
        median_dist = float(np.median(distances))

        # Split by rank instead of ``distances <= median_dist``. With ties at
        # the median (e.g. duplicate points) the threshold test can send
        # every point to one side. The tree then degenerates into a list and
        # hits Python's recursion limit.
        #
        # Assigning the ceil(m/2) closest points to the left keeps each child
        # at most half the size while preserving the invariant both searches
        # rely on:
        #     left distances <= median_dist <= right distances.
        #
        # Without ties this is exactly the partition of the threshold test.
        # A mask is used so each side keeps the original index order, which
        # determines the children's vantage points.
        order = np.argsort(distances, kind="stable")
        n_left = (len(remaining) + 1) // 2
        left_mask = np.zeros(len(remaining), dtype=bool)
        left_mask[order[:n_left]] = True
        left_indices = remaining[left_mask]
        right_indices = remaining[~left_mask]

        node = VPNode(vp_idx, median_dist)
        node.left = self._build_tree(left_indices)
        node.right = self._build_tree(right_indices)
        return node

    def query(
        self,
        X: ArrayLike,
        k: int = 1,
    ) -> NeighborResult:
        """
        Query the tree for k nearest neighbors.

        Parameters
        ----------
        X : array_like
            Query points of shape (n_queries, n_features) or (n_features,).
        k : int, optional
            Number of nearest neighbors. Default 1. Must be at least 1.

        Returns
        -------
        result : NeighborResult
            Indices and distances of k nearest neighbors, each of shape
            (n_queries, k), ordered by increasing distance. If the tree holds
            fewer than ``k`` points, the unused trailing slots keep index 0
            and distance ``inf``.

        Raises
        ------
        ValueError
            If ``k`` is less than 1.
        """
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k}")

        X = validate_query_input(X, self.n_features)
        n_queries = X.shape[0]

        all_indices = np.zeros((n_queries, k), dtype=np.intp)
        all_distances = np.full((n_queries, k), np.inf)

        for i in range(n_queries):
            neighbors = self._query_single(X[i], k)
            n_found = len(neighbors)
            if n_found > 0:
                indices, distances = zip(*neighbors)
                all_indices[i, :n_found] = indices
                all_distances[i, :n_found] = distances

        return NeighborResult(indices=all_indices, distances=all_distances)

    # query_ball_point inherited from BaseSpatialIndex

    def _query_single(
        self,
        query: NDArray[np.floating],
        k: int,
    ) -> List[Tuple[int, float]]:
        """
        Find the k nearest neighbors of a single query point.

        Parameters
        ----------
        query : ndarray
            One query point.
        k : int
            Number of neighbors to find (at least 1).

        Returns
        -------
        neighbors : list of (int, float)
            Up to ``k`` (index, distance) pairs sorted by increasing
            distance. Fewer than ``k`` pairs come back only when the tree
            holds fewer than ``k`` points.
        """
        # (index, distance) pairs, kept sorted by distance.
        neighbors: List[Tuple[int, float]] = []
        tau = np.inf  # Distance of the current k-th nearest neighbor.

        def search(node: Optional[VPNode]) -> None:
            nonlocal tau
            if node is None:
                return

            # Distance to vantage point
            dist = self.metric(query, self.data[node.index])

            # Admit the vantage point if it beats the current k-th best.
            if dist < tau:
                if len(neighbors) < k:
                    neighbors.append((node.index, dist))
                    neighbors.sort(key=lambda x: x[1])
                    if len(neighbors) == k:
                        tau = neighbors[-1][1]
                else:
                    neighbors[-1] = (node.index, dist)
                    neighbors.sort(key=lambda x: x[1])
                    tau = neighbors[-1][1]

            # Triangle-inequality pruning.
            # Left-subtree points p satisfy d(vp, p) <= radius, so
            # d(query, p) >= dist - radius. Right-subtree points satisfy
            # d(vp, p) >= radius, so d(query, p) >= radius - dist.
            # A subtree can only hold a closer neighbor if that lower bound
            # is below tau. tau may shrink while the first subtree is
            # searched, which tightens the check for the second.
            if dist < node.radius:
                # Query lies inside the ball: search left (closer) first.
                if dist - tau <= node.radius:
                    search(node.left)
                if dist + tau >= node.radius:
                    search(node.right)
            else:
                # Query lies on/outside the ball: search right (farther) first.
                if dist + tau >= node.radius:
                    search(node.right)
                if dist - tau <= node.radius:
                    search(node.left)

        search(self.root)
        return neighbors

    def query_radius(
        self,
        X: ArrayLike,
        r: float,
    ) -> List[List[int]]:
        """
        Find all points within radius r of query points.

        Parameters
        ----------
        X : array_like
            Query points of shape (n_queries, n_features) or (n_features,).
        r : float
            Query radius (inclusive: points at distance exactly ``r`` match).

        Returns
        -------
        indices : list of lists
            For each query, list of indices within radius, in tree
            traversal order (not sorted by distance).
        """
        X = validate_query_input(X, self.n_features)
        n_queries = X.shape[0]

        results: List[List[int]] = []
        for i in range(n_queries):
            indices = self._query_radius_single(X[i], r)
            results.append(indices)

        return results

    def _query_radius_single(
        self,
        query: NDArray[np.floating],
        r: float,
    ) -> List[int]:
        """
        Find all points within distance ``r`` (inclusive) of one query.

        Parameters
        ----------
        query : ndarray
            One query point.
        r : float
            Search radius.

        Returns
        -------
        indices : list of int
            Matching data indices in traversal order.
        """
        indices: List[int] = []

        def search(node: Optional[VPNode]) -> None:
            if node is None:
                return

            dist = self.metric(query, self.data[node.index])

            # Check vantage point
            if dist <= r:
                indices.append(node.index)

            # Same triangle-inequality bounds as the k-NN search, with the
            # fixed radius r in place of tau.
            if dist - r <= node.radius:
                search(node.left)
            if dist + r >= node.radius:
                search(node.right)

        search(self.root)
        return indices
# Names exported by ``from pytcl.containers.vptree import *``.
# ``VPTreeResult`` is re-exported only as a backward-compatibility alias.
__all__ = [
    "NeighborResult",
    "VPTreeResult",
    "VPNode",
    "VPTree",
]