Source code for pytcl.containers.covertree

"""
Cover Tree implementation for nearest neighbor search.

Cover trees are data structures for nearest neighbor search in general
metric spaces. The original algorithm [1] guarantees O(c^12 log n) query
time, where c is the expansion constant of the data; the simplified
variant implemented here is not guaranteed to preserve that bound.

References
----------
.. [1] A. Beygelzimer, S. Kakade, and J. Langford, "Cover Trees for Nearest
       Neighbor," Proc. 23rd Int. Conf. on Machine Learning (ICML), 2006.
"""

import heapq
import logging
from typing import Any, Callable, List, Optional, Set, Tuple

import numpy as np
from numpy.typing import ArrayLike, NDArray

from pytcl.containers.base import CoverTreeResult  # Backward compatibility alias
from pytcl.containers.base import (
    MetricSpatialIndex,
    NeighborResult,
    validate_query_input,
)

# Module logger
_logger = logging.getLogger("pytcl.containers.covertree")


class CoverTreeNode:
    """Single node of a cover tree.

    Attributes
    ----------
    index : int
        Row index of this node's point in the tree's original data array.
    level : int
        Level of the node. The owning tree uses ``base**level`` as the
        covering radius at that level.
    children : dict
        Child nodes bucketed by level, i.e. ``{level: [node, ...]}``, in
        the order they were added.
    """

    __slots__ = ["index", "level", "children"]

    def __init__(self, index: int, level: int):
        self.index, self.level = index, level
        # Empty until add_child is called; keys are child levels.
        self.children: dict[int, List["CoverTreeNode"]] = {}

    def add_child(self, level: int, child: "CoverTreeNode") -> None:
        """Append ``child`` to the bucket for ``level``, creating it on first use."""
        bucket = self.children.setdefault(level, [])
        bucket.append(child)
class CoverTree(MetricSpatialIndex):
    """
    Cover Tree for metric space nearest neighbor search.

    A cover tree maintains a hierarchy of nested coverings of the data,
    where points at level i are a subset of points at level i-1 and a node
    at level i is meant to cover nearby points within distance ``base**i``.

    Parameters
    ----------
    data : array_like
        Data points of shape (n_samples, n_features).
    metric : callable, optional
        Distance function metric(x, y) -> float. Default is Euclidean
        distance.
    base : float, optional
        Base for the exponential scale; must be greater than 1.
        Default 2.0.

    Raises
    ------
    ValueError
        If ``base`` is not greater than 1.

    Examples
    --------
    >>> import numpy as np
    >>> points = np.random.rand(100, 3)
    >>> tree = CoverTree(points)
    >>> result = tree.query(points[:5], k=3)

    Notes
    -----
    Insertion uses a simplified version of the original cover tree
    algorithm (Beygelzimer, Kakade & Langford, 2006) for clarity, so the
    theoretical query-time guarantees of the original structure are not
    claimed here.

    To keep queries exact anyway, every node records the largest distance
    from its point to any point in its subtree. k-nearest-neighbor and
    radius queries prune whole subtrees with that radius via the triangle
    inequality. Results are therefore exact whenever ``metric`` satisfies
    the triangle inequality.

    See Also
    --------
    MetricSpatialIndex : Abstract base class for metric-based spatial indices.
    VPTree : Alternative metric space index using vantage points.
    """

    def __init__(
        self,
        data: ArrayLike,
        metric: Optional[
            Callable[[np.ndarray[Any, Any], np.ndarray[Any, Any]], float]
        ] = None,
        base: float = 2.0,
    ):
        # ``not base > 1`` also rejects NaN. A base <= 1 would make
        # log(base) zero or negative and invert the level scale.
        if not base > 1.0:
            raise ValueError(f"base must be greater than 1, got {base!r}")
        super().__init__(data, metric)
        self.base = base

        # Memo of data-to-data distances keyed by (min(i, j), max(i, j)).
        # Filled by _distance during construction; entries are never evicted.
        self._distance_cache: dict[Tuple[int, int], float] = {}

        # point index -> parent point index (None for the root).
        self._parent: dict[int, Optional[int]] = {}
        # point index -> largest distance from that point to any point in
        # its subtree (0.0 for leaves). This is the exact pruning radius
        # used by the queries.
        self._max_desc: dict[int, float] = {}

        # Build tree
        self.root: Optional[CoverTreeNode] = None
        self.max_level = 0
        self.min_level = 0

        if self.n_samples > 0:
            self._build_tree()

        _logger.debug(
            "CoverTree built with base=%.1f, levels=%d to %d",
            base,
            self.min_level,
            self.max_level,
        )

    def _distance(self, i: int, j: int) -> float:
        """Return the distance between data points i and j (memoized)."""
        if i == j:
            return 0.0
        key = (min(i, j), max(i, j))
        if key not in self._distance_cache:
            self._distance_cache[key] = self.metric(self.data[i], self.data[j])
        return self._distance_cache[key]

    def _distance_to_point(self, idx: int, query: NDArray[np.floating]) -> float:
        """Return the distance from data point ``idx`` to ``query`` (not cached)."""
        return self.metric(self.data[idx], query)

    def _cover_distance(self, level: int) -> float:
        """Return the cover distance for a level (base**level)."""
        return self.base**level

    def _register(
        self, node: CoverTreeNode, parent: Optional[CoverTreeNode]
    ) -> None:
        """Record a newly created leaf and widen its ancestors' subtree radii."""
        idx = node.index
        self._parent[idx] = None if parent is None else parent.index
        self._max_desc[idx] = 0.0
        # Every ancestor's subtree now contains idx. Walk the whole path,
        # because an ancestor's radius is independent of its child's.
        ancestor = self._parent[idx]
        while ancestor is not None:
            d = self._distance(ancestor, idx)
            if d > self._max_desc[ancestor]:
                self._max_desc[ancestor] = d
            ancestor = self._parent[ancestor]

    def _build_tree(self) -> None:
        """Build the cover tree by inserting points one at a time.

        Point 0 becomes the root. Its level is chosen from the largest
        root-to-point distance over the WHOLE dataset, so the root's
        covering radius base**max_level spans every point. The one extra
        level gives margin against log/ceil rounding.
        """
        max_dist = 0.0
        for i in range(1, self.n_samples):
            max_dist = max(max_dist, self._distance(0, i))

        if max_dist > 0:
            self.max_level = int(np.ceil(np.log(max_dist) / np.log(self.base))) + 1
        else:
            self.max_level = 0
        self.min_level = self.max_level

        self.root = CoverTreeNode(0, self.max_level)
        self._register(self.root, None)

        for i in range(1, self.n_samples):
            self._insert(i)

    def _insert(self, point_idx: int) -> None:
        """Insert data point ``point_idx`` into the cover tree."""
        if self.root is None:
            self.root = CoverTreeNode(point_idx, self.max_level)
            self._register(self.root, None)
            return

        # Descend from max_level. At each level keep the nodes (and their
        # next-level children) within that level's cover distance of the
        # new point. Stop at the first level where nothing covers it.
        level = self.max_level
        cover_sets: dict[int, List[CoverTreeNode]] = {level: [self.root]}

        while level > self.min_level - 1:
            cover_dist = self._cover_distance(level)
            next_level = level - 1
            next_cover: List[CoverTreeNode] = []

            for node in cover_sets.get(level, []):
                d = self._distance(node.index, point_idx)
                if d <= cover_dist:
                    next_cover.append(node)
                    for child in node.children.get(next_level, []):
                        if self._distance(child.index, point_idx) <= cover_dist:
                            next_cover.append(child)

            if not next_cover:
                break

            cover_sets[next_level] = next_cover
            level = next_level

        # Attach the point under the closest node of the last cover set.
        min_dist = np.inf
        parent = self.root
        for node in cover_sets.get(level, [self.root]):
            d = self._distance(node.index, point_idx)
            if d < min_dist:
                min_dist = d
                parent = node

        new_level = level - 1
        new_node = CoverTreeNode(point_idx, new_level)
        parent.add_child(new_level, new_node)
        self._register(new_node, parent)

        self.min_level = min(self.min_level, new_level)

    def query(
        self,
        X: ArrayLike,
        k: int = 1,
    ) -> NeighborResult:
        """
        Query the tree for k nearest neighbors.

        Parameters
        ----------
        X : array_like
            Query points of shape (n_queries, n_features) or (n_features,).
        k : int, optional
            Number of nearest neighbors; must be at least 1. Default 1.

        Returns
        -------
        result : NeighborResult
            Indices and distances of the k nearest neighbors. Each row is
            sorted by increasing distance. When fewer than k data points
            exist, the remaining entries have index 0 and distance inf.

        Raises
        ------
        ValueError
            If ``k`` is less than 1.
        """
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k!r}")
        X = validate_query_input(X, self.n_features)
        n_queries = X.shape[0]

        all_indices = np.zeros((n_queries, k), dtype=np.intp)
        all_distances = np.full((n_queries, k), np.inf)

        for i in range(n_queries):
            neighbors = self._query_single(X[i], k)
            n_found = len(neighbors)
            if n_found > 0:
                indices, distances = zip(*neighbors)
                all_indices[i, :n_found] = indices
                all_distances[i, :n_found] = distances

        return NeighborResult(indices=all_indices, distances=all_distances)

    # query_ball_point inherited from BaseSpatialIndex

    def _query_single(
        self,
        query: NDArray[np.floating],
        k: int,
    ) -> List[Tuple[int, float]]:
        """Find the k nearest data points to a single query point.

        This is a best-first branch and bound over the explicit tree. For a
        node p with subtree radius R_p, every point x in p's subtree
        satisfies d(q, x) >= d(q, p) - R_p by the triangle inequality.
        Subtrees whose lower bound exceeds the current k-th best distance
        are skipped. Each node is visited at most once, so no index can
        appear twice in the result.

        Returns
        -------
        list of (int, float)
            At most k (index, distance) pairs, sorted by increasing
            distance (ties broken by index).
        """
        if self.root is None:
            return []

        root = self.root
        root_dist = self._distance_to_point(root.index, query)
        # Min-heap of (lower_bound, distance, index, node). Indices are
        # unique, so tuple comparison never reaches the node object.
        frontier: List[Tuple[float, float, int, CoverTreeNode]] = [
            (max(root_dist - self._max_desc[root.index], 0.0), root_dist, root.index, root)
        ]
        # Max-heap (negated distances) holding the best k found so far.
        best: List[Tuple[float, int]] = []

        while frontier:
            bound, dist, idx, node = heapq.heappop(frontier)
            if len(best) == k and bound > -best[0][0]:
                # Every remaining subtree has bound >= this one: done.
                break

            if len(best) < k:
                heapq.heappush(best, (-dist, idx))
            elif dist < -best[0][0]:
                heapq.heapreplace(best, (-dist, idx))

            for children in node.children.values():
                for child in children:
                    child_dist = self._distance_to_point(child.index, query)
                    child_bound = max(child_dist - self._max_desc[child.index], 0.0)
                    if len(best) < k or child_bound <= -best[0][0]:
                        heapq.heappush(
                            frontier, (child_bound, child_dist, child.index, child)
                        )

        neighbors = [(idx, -neg_dist) for neg_dist, idx in best]
        neighbors.sort(key=lambda item: (item[1], item[0]))
        return neighbors

    def query_radius(
        self,
        X: ArrayLike,
        r: float,
    ) -> List[List[int]]:
        """
        Find all points within radius r of query points.

        Parameters
        ----------
        X : array_like
            Query points.
        r : float
            Query radius (inclusive).

        Returns
        -------
        indices : list of lists
            For each query, list of indices within radius.
        """
        X = validate_query_input(X, self.n_features)
        n_queries = X.shape[0]

        results: List[List[int]] = []
        for i in range(n_queries):
            indices = self._query_radius_single(X[i], r)
            results.append(indices)

        return results

    def _query_radius_single(
        self,
        query: NDArray[np.floating],
        r: float,
    ) -> List[int]:
        """Find all points within radius r of a single query point.

        This is an iterative preorder traversal, so deep trees cannot hit
        the recursion limit. A subtree is entered only if
        d(q, node) - R_node <= r, which cannot exclude a point within r.
        """
        if self.root is None:
            return []

        indices: List[int] = []
        stack: List[CoverTreeNode] = [self.root]

        while stack:
            node = stack.pop()
            dist = self._distance_to_point(node.index, query)

            if dist <= r:
                indices.append(node.index)

            if dist - self._max_desc[node.index] <= r:
                pending = [
                    child for children in node.children.values() for child in children
                ]
                # Reverse so children are popped in their stored order.
                stack.extend(reversed(pending))

        return indices
# Public API of this module. NeighborResult and CoverTreeResult are
# re-exported from pytcl.containers.base (imported at the top of the file).
__all__ = [
    "NeighborResult",
    "CoverTreeResult",  # Backward compatibility alias
    "CoverTreeNode",
    "CoverTree",
]