Throughout this chapter, we've examined how fundamental data structures underpin efficient machine learning algorithms. Now, we put theory into practice by implementing a k-d tree (k-dimensional tree) from scratch in Python. K-d trees are particularly useful for accelerating nearest neighbor searches, a common operation in algorithms like k-Nearest Neighbors (k-NN), density estimation, and certain clustering methods. While libraries like Scikit-learn provide optimized implementations, building one yourself offers valuable insight into its mechanics and trade-offs.
A brute-force nearest neighbor search compares a query point to every other point in the dataset, resulting in O(n⋅k) time per query for n points in k dimensions. For large datasets, this becomes computationally prohibitive. K-d trees aim to reduce this search time, often achieving average-case O(log n) query complexity in low-dimensional spaces, by recursively partitioning the data space.
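To make that baseline concrete, here is a minimal brute-force search, written as a sketch under the same assumption as the rest of this section (points stored as NumPy arrays). The helper name brute_force_nearest is ours, not part of the tree implementation below; we reference it again later only as a sanity check.

import numpy as np

def brute_force_nearest(points, query_point):
    """Linear scan: compare the query point to every point and keep the closest."""
    best_point, best_dist_sq = None, float('inf')
    for p in points:
        d = float(np.sum((p - query_point) ** 2))  # squared Euclidean distance
        if d < best_dist_sq:
            best_point, best_dist_sq = p, d
    return best_point, best_dist_sq ** 0.5  # return the actual distance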
A k-d tree is a binary tree where each node represents an axis-aligned hyperplane splitting the space into two half-spaces. Points are stored implicitly within these regions or explicitly at the nodes.
Let's define a simple structure to represent a node in our tree. A namedtuple is convenient for this:
import collections
import numpy as np
from typing import List, Tuple, Optional, Any
Point = np.ndarray # Assuming points are NumPy arrays
class Node(collections.namedtuple('Node', ['point', 'split_dim', 'left', 'right'])):
    """
    Represents a node in the k-d tree.

    Attributes:
        point (Optional[Point]): The data point stored at this node (if not None).
            For internal nodes, this is the splitting point.
        split_dim (Optional[int]): The dimension used for splitting at this node.
        left (Optional[Node]): The left child node.
        right (Optional[Node]): The right child node.
    """
    pass

# Make attributes optional for clarity, though split_dim and point
# will typically be set for internal nodes. NamedTuple defaults aren't
# easily mutable, so we handle Optional during creation logic.
Node.__new__.__defaults__ = (None,) * len(Node._fields)
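With these defaults in place, a node can be created with only the fields it needs. For example (illustrative only, not part of the tree-building code):

# Illustrative only: a leaf-like node holding a single 2D point; left/right default to None
leaf = Node(point=np.array([8, 1]), split_dim=0)
print(leaf.left, leaf.right)  # -> None None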
The core of the k-d tree is the recursive building function. It takes a list of points and the current depth (to determine the splitting dimension).
def build_kdtree(points: List[Point], depth: int = 0) -> Optional[Node]:
    """
    Recursively builds a k-d tree from a list of points.

    Args:
        points: A list of k-dimensional points (NumPy arrays).
        depth: The current depth in the tree (used for cycling dimensions).

    Returns:
        The root node of the constructed k-d tree, or None if points is empty.
    """
    if not points:
        return None

    k = len(points[0])  # Dimensionality of the data
    split_dim = depth % k

    # Sort points along the splitting dimension and choose the median
    points.sort(key=lambda p: p[split_dim])
    median_idx = len(points) // 2
    median_point = points[median_idx]

    # Recursively build left and right subtrees
    left_subtree = build_kdtree(points[:median_idx], depth + 1)
    right_subtree = build_kdtree(points[median_idx + 1:], depth + 1)

    return Node(
        point=median_point,
        split_dim=split_dim,
        left=left_subtree,
        right=right_subtree
    )
# Example Usage (assuming 2D points):
# points_list = [np.array([2, 3]), np.array([5, 4]), np.array([9, 6]),
# np.array([4, 7]), np.array([8, 1]), np.array([7, 2])]
# root_node = build_kdtree(points_list.copy()) # Pass a copy if original list needed
This implementation uses the median point to partition the data, aiming for a balanced tree. Sorting at each step gives O(n log n) complexity per level. Since there are log n levels, the total build time is O(k ⋅ n log² n). This can be improved to O(k ⋅ n log n) using a median-finding algorithm like introselect (which numpy.partition uses internally) instead of full sorting, but for simplicity, we use sorting here.
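As a sketch of that optimization (the helper split_by_median below is ours, not part of the implementation above), the sort inside build_kdtree could be replaced by a partial partition around the median using numpy.argpartition:

def split_by_median(points: List[Point], split_dim: int) -> Tuple[Point, List[Point], List[Point]]:
    """Partition points around the median along split_dim without a full sort.

    np.argpartition performs an introselect-based partial sort, so only the
    median element is guaranteed to land in its sorted position (average O(n)).
    """
    pts = np.array(points)
    median_idx = len(pts) // 2
    order = np.argpartition(pts[:, split_dim], median_idx)
    median_point = pts[order[median_idx]]
    left = [pts[i] for i in order[:median_idx]]       # values <= median along split_dim
    right = [pts[i] for i in order[median_idx + 1:]]  # values >= median along split_dim
    return median_point, left, right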
Finding the nearest neighbor involves traversing the tree efficiently, pruning branches that cannot contain a closer point than the best one found so far.
The search proceeds as follows:

1. Descend the tree toward the region containing the query point, keeping track of the closest point found so far in best_guess and its distance in best_distance.
2. At each node visited, compute the distance from the query point to the node's point; if it is smaller than best_distance, update best_guess and best_distance.
3. Check whether the hypersphere centered on the query point with radius best_distance intersects the splitting hyperplane defined by the current node. The distance from the query point to the hyperplane is simply the absolute difference in the split_dim coordinate: abs(query_point[split_dim] - node.point[split_dim]).
4. If that distance is less than best_distance, the other subtree (the one not initially visited during descent) might contain a closer point, so recursively search that subtree as well. Otherwise, prune that branch.

We'll use a helper function to manage the search state. Using a mutable object like a list or a custom class instance for best_guess and best_distance avoids issues with Python's handling of immutable types in recursion. We'll use a simple list [best_point, best_dist_sq] to store the best point found and its squared distance (avoiding square roots until the very end is more efficient).
import math
def squared_distance(p1: Point, p2: Point) -> float:
    """Calculates the squared Euclidean distance between two points."""
    return np.sum((p1 - p2)**2)
def find_nearest_neighbor_recursive(node: Optional[Node], query_point: Point, depth: int, best: List[Any]):
    """
    Recursive helper function for nearest neighbor search.

    Args:
        node: The current node being visited.
        query_point: The point for which to find the nearest neighbor.
        depth: The current depth in the tree.
        best: A list containing [best_point_found, best_squared_distance_found].
            Updated in place.
    """
    if node is None:
        return

    k = len(query_point)
    split_dim = depth % k  # Or node.split_dim if stored

    # Calculate squared distance from query point to current node's point
    dist_sq = squared_distance(query_point, node.point)

    # Update best if current node is closer
    if dist_sq < best[1]:
        best[0] = node.point
        best[1] = dist_sq

    # Determine which subtree to explore first
    diff = query_point[split_dim] - node.point[split_dim]
    if diff < 0:
        close_branch, away_branch = node.left, node.right
    else:
        close_branch, away_branch = node.right, node.left

    # Recursively search the closer subtree first
    find_nearest_neighbor_recursive(close_branch, query_point, depth + 1, best)

    # Check if the 'away' branch needs to be explored (pruning step)
    # Compare squared distance along the split dimension to best distance
    if (diff**2) < best[1]:
        find_nearest_neighbor_recursive(away_branch, query_point, depth + 1, best)
def find_nearest_neighbor(root: Optional[Node], query_point: Point) -> Tuple[Optional[Point], float]:
    """
    Finds the nearest neighbor to a query point in the k-d tree.

    Args:
        root: The root node of the k-d tree.
        query_point: The point for which to find the nearest neighbor.

    Returns:
        A tuple containing (nearest_point, distance), or (None, infinity) if the tree is empty.
    """
    if root is None:
        return None, float('inf')

    # Initialize best with large distance
    # We store squared distance for efficiency during search
    best_state = [None, float('inf')]
    find_nearest_neighbor_recursive(root, query_point, 0, best_state)

    # Return the point and the actual Euclidean distance
    return best_state[0], math.sqrt(best_state[1])
# Example Usage:
# query = np.array([6, 5])
# nearest, distance = find_nearest_neighbor(root_node, query)
# print(f"Query point: {query}")
# print(f"Nearest point: {nearest}, Distance: {distance:.4f}")
Adapting the search to find the k nearest neighbors requires maintaining a list (or preferably, a bounded priority queue/max-heap) of the k best candidates found so far. The pruning condition also changes: we only prune a branch once we already hold k candidates and the distance from the query point to the splitting hyperplane is greater than the distance to the farthest of those candidates.
We can use Python's heapq module, which implements a min-heap. To simulate a max-heap of size k for distances, we store negative squared distances in the heap. Once the heap holds k candidates, a closer point replaces the heap's root, which holds the smallest negative squared distance (corresponding to the largest actual squared distance among the candidates).
import heapq
def find_k_nearest_neighbors(root: Optional[Node], query_point: Point, k: int) -> List[Tuple[float, Point]]:
    """
    Finds the k nearest neighbors to a query point in the k-d tree.

    Args:
        root: The root node of the k-d tree.
        query_point: The point for which to find the nearest neighbors.
        k: The number of neighbors to find.

    Returns:
        A list of tuples (distance, point), sorted by distance.
        Returns an empty list if the tree is empty or k <= 0.
    """
    if root is None or k <= 0:
        return []

    # Use a min-heap storing (-squared_distance, id(point), point) to simulate a
    # max-heap of distances. This keeps the k points with the *smallest* distances
    # (largest negative distances). The id() tie-breaker ensures heapq never has to
    # compare two NumPy arrays directly when two distances are exactly equal.
    knn_heap: List[Tuple[float, int, Point]] = []

    _find_knn_recursive(root, query_point, 0, k, knn_heap)

    # Convert heap contents to (squared_distance, point) and sort by distance
    result = [(-neg_dist_sq, point) for neg_dist_sq, _, point in knn_heap]
    result.sort(key=lambda x: x[0])

    # Calculate actual distances
    final_result = [(math.sqrt(dist_sq), point) for dist_sq, point in result]
    return final_result
def _find_knn_recursive(node: Optional[Node], query_point: Point, depth: int, k: int, knn_heap: List[Tuple[float, int, Point]]):
    """Recursive helper for k-NN search using a heap."""
    if node is None:
        return

    num_dims = len(query_point)
    split_dim = depth % num_dims

    # Calculate squared distance to current node's point
    dist_sq = squared_distance(query_point, node.point)
    neg_dist_sq = -dist_sq

    # Add current point to heap / update heap
    # (id(node.point) breaks ties so NumPy arrays are never compared directly)
    if len(knn_heap) < k:
        heapq.heappush(knn_heap, (neg_dist_sq, id(node.point), node.point))
    elif neg_dist_sq > knn_heap[0][0]:  # Current point is closer than the farthest in heap
        heapq.heapreplace(knn_heap, (neg_dist_sq, id(node.point), node.point))

    # Determine branches
    diff = query_point[split_dim] - node.point[split_dim]
    if diff < 0:
        close_branch, away_branch = node.left, node.right
    else:
        close_branch, away_branch = node.right, node.left

    # Search close branch
    _find_knn_recursive(close_branch, query_point, depth + 1, k, knn_heap)

    # Check if away branch needs searching
    # Pruning condition: compare distance to hyperplane with distance to k'th neighbor
    # The k'th neighbor's negative squared distance is knn_heap[0][0] (smallest element)
    farthest_dist_sq = -knn_heap[0][0] if knn_heap else float('inf')
    if (diff**2) < farthest_dist_sq or len(knn_heap) < k:
        _find_knn_recursive(away_branch, query_point, depth + 1, k, knn_heap)
# Example Usage:
# K = 3
# neighbors = find_k_nearest_neighbors(root_node, query, K)
# print(f"\nFinding {K} nearest neighbors for point: {query}")
# for dist, point in neighbors:
# print(f" Point: {point}, Distance: {dist:.4f}")
While k-d trees significantly speed up searches in low-dimensional spaces (e.g., 2D, 3D), their performance degrades as the number of dimensions (k) increases. This is often referred to as the "curse of dimensionality." In high dimensions, the hyperrectangles defined by the splits become less effective at pruning the search space, and the search complexity can approach the brute-force O(n⋅k). For dimensions roughly greater than 10-20, other structures like Ball Trees or approximate methods (LSH, HNSW) often perform better.
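One rough way to observe this degradation on your own machine is to time queries at different dimensionalities. The sketch below reuses build_kdtree and find_nearest_neighbor from above; the helper names brute_force_sq and compare_dimensions are ours, and absolute timings will vary.

import time

def brute_force_sq(points, q):
    """Linear scan returning the smallest squared distance (timing baseline)."""
    return min(float(np.sum((p - q) ** 2)) for p in points)

def compare_dimensions(n=2000, dims=(2, 20), n_queries=200, seed=0):
    """Rough timing comparison of k-d tree vs. linear scan as dimensionality grows."""
    rng = np.random.default_rng(seed)
    for d in dims:
        data = [rng.random(d) for _ in range(n)]
        queries = [rng.random(d) for _ in range(n_queries)]
        tree = build_kdtree(data.copy())

        start = time.perf_counter()
        for q in queries:
            find_nearest_neighbor(tree, q)
        tree_time = time.perf_counter() - start

        start = time.perf_counter()
        for q in queries:
            brute_force_sq(data, q)
        brute_time = time.perf_counter() - start

        print(f"d={d}: kd-tree {tree_time:.3f}s vs brute force {brute_time:.3f}s")

# compare_dimensions()  # Expect the k-d tree's advantage to shrink as d grows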
Let's visualize the partitions created by a simple 2D k-d tree. We can represent the splits as lines on a scatter plot.
import plotly.graph_objects as go
def get_bounding_box(points: List[Point]) -> Tuple[Point, Point]:
    """Calculates the bounding box of a set of points."""
    all_points_np = np.array(points)
    min_coords = np.min(all_points_np, axis=0)
    max_coords = np.max(all_points_np, axis=0)
    return min_coords, max_coords

def plot_kdtree_partitions(node: Optional[Node], min_coords: Point, max_coords: Point, depth: int = 0, fig=None):
    """Recursively adds partition lines to a Plotly figure."""
    if node is None:
        return

    split_dim = node.split_dim  # Use stored split dim
    split_val = node.point[split_dim]

    # Draw the splitting line/plane
    if split_dim == 0:  # Vertical line
        fig.add_shape(type="line",
                      x0=split_val, y0=min_coords[1], x1=split_val, y1=max_coords[1],
                      line=dict(color="#adb5bd", width=1))
        # Recurse on children with updated bounds
        plot_kdtree_partitions(node.left, min_coords, np.array([split_val, max_coords[1]]), depth + 1, fig)
        plot_kdtree_partitions(node.right, np.array([split_val, min_coords[1]]), max_coords, depth + 1, fig)
    elif split_dim == 1:  # Horizontal line
        fig.add_shape(type="line",
                      x0=min_coords[0], y0=split_val, x1=max_coords[0], y1=split_val,
                      line=dict(color="#adb5bd", width=1))
        # Recurse on children with updated bounds
        plot_kdtree_partitions(node.left, min_coords, np.array([max_coords[0], split_val]), depth + 1, fig)
        plot_kdtree_partitions(node.right, np.array([min_coords[0], split_val]), max_coords, depth + 1, fig)
    # Extend for higher dimensions if needed, though visualization becomes hard
# --- Example Plotting ---
# Generate some 2D data
np.random.seed(42)
points_list_2d = [np.random.rand(2) * 10 for _ in range(50)]
# Build the tree
root_2d = build_kdtree(points_list_2d.copy())
# Create scatter plot
points_np = np.array(points_list_2d)
fig = go.Figure(data=go.Scatter(x=points_np[:, 0], y=points_np[:, 1], mode='markers',
                                marker=dict(color='#228be6', size=8), name='Data Points'))
# Add query point (optional)
query_pt = np.array([5, 5])
nearest_pt, _ = find_nearest_neighbor(root_2d, query_pt)
fig.add_trace(go.Scatter(x=[query_pt[0]], y=[query_pt[1]], mode='markers',
                         marker=dict(color='#fa5252', size=10, symbol='x'), name='Query Point'))
if nearest_pt is not None:
    fig.add_trace(go.Scatter(x=[nearest_pt[0]], y=[nearest_pt[1]], mode='markers',
                             marker=dict(color='#e64980', size=10, symbol='star'), name='Nearest'))
# Add partition lines
min_bound, max_bound = get_bounding_box(points_list_2d)
# Add a small margin to bounds for visualization
margin = 1.0
min_bound -= margin
max_bound += margin
plot_kdtree_partitions(root_2d, min_bound, max_bound, fig=fig)
fig.update_layout(title="k-d Tree Partitions (2D Example)",
                  xaxis_title="Dimension 0", yaxis_title="Dimension 1",
                  xaxis=dict(range=[min_bound[0], max_bound[0]]),
                  yaxis=dict(range=[min_bound[1], max_bound[1]]),
                  width=700, height=600,
                  showlegend=True)
# To display the plot (e.g., in a Jupyter environment):
# fig.show()
# To get the JSON representation:
# print(fig.to_json()) # This might be very long for many points/partitions
Scatter plot showing 50 randomly generated 2D points, the query point (red 'x'), its calculated nearest neighbor (pink star), and the first few partitioning lines (gray) generated by the k-d tree build process. Vertical lines split along dimension 0, horizontal lines split along dimension 1.
This hands-on exercise demonstrates how a relatively simple recursive data structure can provide significant performance improvements for a common ML task like nearest neighbor search. By implementing it yourself, you gain a deeper appreciation for the algorithmic choices involved in optimizing ML workflows beyond relying solely on pre-built library functions. Understanding these structures is valuable when you need to customize algorithms or diagnose performance issues in spatial data problems.