"""
Base classes for spatial data structures.
This module provides abstract base classes that define the common interface
for spatial indexing data structures like KD-trees, VP-trees, R-trees, and
Cover trees.
The unified interface ensures all spatial indices provide consistent:
- Constructor patterns (data, optional parameters)
- Query methods (query, query_radius, query_ball_point)
- Return types (NeighborResult)
"""
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, List, NamedTuple, Optional
import numpy as np
from numpy.typing import ArrayLike, NDArray
# Module logger, registered under the fixed name "pytcl.containers".
# BaseSpatialIndex.__init__ uses it to report index construction at DEBUG level.
_logger: logging.Logger = logging.getLogger("pytcl.containers")
class NeighborResult(NamedTuple):
    """
    Unified result type for spatial index queries.

    All spatial index implementations (KDTree, BallTree, VPTree, CoverTree,
    RTree) return this type from their query methods, ensuring a consistent
    interface across the library.

    Attributes
    ----------
    indices : ndarray of shape (n_queries, k) or (n_queries,)
        Indices of the k nearest neighbors in the original data array.
        For k=1 an implementation may return a 1-D array; for k>1 the
        shape is (n_queries, k).
    distances : ndarray of shape (n_queries, k) or (n_queries,)
        Distances to the k nearest neighbors, with the same shape as
        ``indices``.

    Examples
    --------
    The query point below is chosen so that no two data points are
    equidistant from it, which keeps the neighbor order unambiguous.

    >>> from pytcl.containers import KDTree
    >>> import numpy as np
    >>> points = np.array([[0, 0], [1, 0], [0, 1], [1, 1]])
    >>> tree = KDTree(points)
    >>> result = tree.query([[0.1, 0.2]], k=2)
    >>> result.indices
    array([[0, 2]])
    >>> result.distances
    array([[0.2236068 , 0.80622577]])

    Being a ``NamedTuple``, a result can also be unpacked directly:

    >>> indices, distances = result

    See Also
    --------
    BaseSpatialIndex : Abstract base class for spatial indices.
    """

    indices: NDArray[np.intp]
    distances: NDArray[np.floating]
# Backward-compatibility aliases. Every name below is bound to the very same
# class object as NeighborResult (not a subclass or copy), so isinstance()
# checks and identity comparisons give the same answer whichever name is used.
SpatialQueryResult = NearestNeighborResult = NeighborResult
VPTreeResult = CoverTreeResult = RTreeResult = NeighborResult
class BaseSpatialIndex(ABC):
    """
    Abstract base class for spatial indexing data structures.

    Concrete indices (KDTree, VPTree, RTree, CoverTree) should inherit from
    this class and implement :meth:`query` and :meth:`query_radius`. The
    base class supplies the shared plumbing:

    - conversion of the input points to a ``float64`` array and a check that
      the result is 2-dimensional,
    - the ``n_samples`` / ``n_features`` introspection attributes,
    - ``len()`` and ``repr()`` support,
    - :meth:`query_ball_point`, a scipy-style alias for :meth:`query_radius`.

    Parameters
    ----------
    data : array_like
        Data points of shape (n_samples, n_features).

    Attributes
    ----------
    data : ndarray of float64
        The indexed data points.
    n_samples : int
        Number of data points.
    n_features : int
        Dimensionality of data points.
    """

    def __init__(self, data: ArrayLike):
        """
        Store ``data`` as a float64 array and record its dimensions.

        Raises
        ------
        ValueError
            If ``data`` does not convert to a 2-dimensional array.
        """
        self.data = np.asarray(data, dtype=np.float64)
        shape = self.data.shape
        if len(shape) != 2:
            raise ValueError(
                "Data must be 2-dimensional (n_samples, n_features), "
                f"got shape {shape}"
            )
        self.n_samples, self.n_features = shape
        # %-style arguments are passed separately so the message is only
        # formatted when DEBUG output is actually enabled.
        _logger.debug(
            "%s initialized with %d points in %d dimensions",
            type(self).__name__,
            self.n_samples,
            self.n_features,
        )

    @abstractmethod
    def query(self, X: ArrayLike, k: int = 1) -> NeighborResult:
        """
        Find the ``k`` nearest indexed points to each query point.

        Parameters
        ----------
        X : array_like
            Query points of shape (n_queries, n_features) or (n_features,).
        k : int, optional
            Number of nearest neighbors to return. Default is 1.

        Returns
        -------
        result : NeighborResult
            Named tuple holding, for every query point, the indices of and
            distances to its ``k`` nearest neighbors.
        """

    @abstractmethod
    def query_radius(self, X: ArrayLike, r: float) -> List[List[int]]:
        """
        Find all indexed points within distance ``r`` of each query point.

        Parameters
        ----------
        X : array_like
            Query points of shape (n_queries, n_features) or (n_features,).
        r : float
            Search radius.

        Returns
        -------
        indices : list of list of int
            One list per query point, containing the indices of the data
            points that lie within distance ``r`` of it.
        """

    def query_ball_point(self, X: ArrayLike, r: float) -> List[List[int]]:
        """
        Find all indexed points within distance ``r`` of each query point.

        This simply forwards to :meth:`query_radius`; the name mirrors the
        equivalent method of ``scipy.spatial.KDTree`` for compatibility.

        Parameters
        ----------
        X : array_like
            Query points of shape (n_queries, n_features) or (n_features,).
        r : float
            Search radius.

        Returns
        -------
        indices : list of list of int
            One list per query point, containing the indices of the data
            points that lie within distance ``r`` of it.

        See Also
        --------
        query_radius : The underlying implementation.
        """
        return self.query_radius(X, r)

    def __len__(self) -> int:
        """Return the number of indexed points."""
        return self.n_samples

    def __repr__(self) -> str:
        name = type(self).__name__
        return f"{name}(n_samples={self.n_samples}, n_features={self.n_features})"
class MetricSpatialIndex(BaseSpatialIndex):
    """
    Base class for spatial indices over an arbitrary metric space.

    Extends :class:`BaseSpatialIndex` by keeping a user-selectable distance
    function in ``self.metric``. Used by VP-trees and Cover trees, which can
    work with any metric.

    Parameters
    ----------
    data : array_like
        Data points of shape (n_samples, n_features).
    metric : callable, optional
        Distance function with signature ``metric(x, y) -> float``.
        When omitted (``None``), Euclidean distance is used.

    Attributes
    ----------
    metric : callable
        The distance function in use: either the one supplied or the
        built-in Euclidean distance.
    """

    def __init__(
        self,
        data: ArrayLike,
        metric: Optional[Callable[[NDArray[Any], NDArray[Any]], float]] = None,
    ):
        super().__init__(data)
        # Fall back to the built-in Euclidean metric when none is supplied.
        self.metric = self._euclidean_distance if metric is None else metric

    @staticmethod
    def _euclidean_distance(x: NDArray[Any], y: NDArray[Any]) -> float:
        """Return the Euclidean (L2) distance between ``x`` and ``y``."""
        delta = x - y
        return float(np.sqrt(np.sum(delta ** 2)))
# Public API: the names bound by ``from <this module> import *``.
__all__ = [
    # Primary types
    "NeighborResult",
    "BaseSpatialIndex",
    "MetricSpatialIndex",
    # NOTE(review): ``validate_query_input`` is not defined anywhere in the
    # visible part of this module. If it really is absent, a star-import of
    # this module raises AttributeError -- confirm it is defined here, or
    # drop this entry.
    "validate_query_input",
    # Backward compatibility aliases
    "SpatialQueryResult",
    "NearestNeighborResult",
    "VPTreeResult",
    "CoverTreeResult",
    "RTreeResult",
]