As we've discussed, distributing the retrieval workload is fundamental to building RAG systems that can operate over massive datasets. Sharding your vector index is a primary technique to achieve this distribution, allowing for parallel processing of queries and storage of vector embeddings across multiple nodes or processes. This practical exercise will guide you through a simplified implementation of a sharded vector index using FAISS to simulate the core mechanics. While we'll use a local setup for clarity, the principles demonstrated are directly applicable to large-scale distributed vector databases and search systems.
Our goal is to understand how to partition data, route indexing and query operations to the appropriate shards, and aggregate results. This hands-on experience will solidify your understanding of the distributed retrieval strategies covered earlier.
For this practical, we'll use Python with the FAISS library from Facebook AI Research for efficient similarity search, and NumPy for numerical operations. FAISS allows us to create and manage vector indexes in memory, making it ideal for demonstrating sharding logic without the overhead of a full distributed database setup.
Imagine we have a large collection of document embeddings that we need to index. Instead of a single, monolithic index, we will create N_SHARDS smaller indexes. Each document embedding will be assigned to one of these shards.
Let's assume our embeddings are D-dimensional vectors.
First, we need to initialize our shards. In a real distributed system, each shard might be a separate server or a process managing a FAISS index instance. Here, we'll simulate this with a list of FAISS indexes.
import faiss
import numpy as np
# Configuration
N_SHARDS = 4
D_EMBEDDING = 128 # Dimensionality of embeddings
K_NEIGHBORS = 5 # Number of neighbors to retrieve
# Initialize shards
# For simplicity, we use IndexFlatL2, suitable for smaller datasets.
# In practice, you'd use more advanced index types like IndexIVFPQ for large scale.
shards = [faiss.IndexFlatL2(D_EMBEDDING) for _ in range(N_SHARDS)]
shard_doc_ids = [[] for _ in range(N_SHARDS)] # To store original doc IDs per shard
print(f"Initialized {N_SHARDS} shards, each for {D_EMBEDDING}-dimensional vectors.")
Next, we define a sharding function. A simple modulo operation on a document's unique identifier is a common approach for distributing data relatively evenly, assuming IDs are well-distributed.
def get_shard_index(doc_id_numeric, num_shards):
    """Determines the shard index for a given numeric document ID."""
    return doc_id_numeric % num_shards
For this practical, we'll assume doc_id_numeric is an integer. If you have string IDs, you would first hash them to an integer, as sketched below.
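One minimal way to do this, assuming string IDs, is to take a few bytes of a stable cryptographic hash. Note that Python's built-in hash() is randomized per process and would route the same ID to different shards across runs, so it is unsuitable here.

import hashlib

def get_shard_index_for_string(doc_id_str, num_shards):
    """Maps a string document ID to a shard via a stable hash."""
    # SHA-256 is deterministic across processes, unlike built-in hash()
    digest = hashlib.sha256(doc_id_str.encode('utf-8')).digest()
    doc_id_numeric = int.from_bytes(digest[:8], 'big')
    return doc_id_numeric % num_shards

print(get_shard_index_for_string("doc-000123", N_SHARDS))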
Now, let's simulate ingesting some data. We'll generate random embeddings and assign them document IDs. Each embedding, along with its ID, will be routed to the appropriate shard.
NUM_DOCUMENTS = 10000
np.random.seed(42) # for reproducibility
# Generate dummy document embeddings and IDs
# In a real system, these embeddings come from your embedding model
all_embeddings = np.random.rand(NUM_DOCUMENTS, D_EMBEDDING).astype('float32')
# Assign sequential numeric IDs for simplicity
all_doc_ids = np.arange(NUM_DOCUMENTS)
# Ingest data into shards
for i in range(NUM_DOCUMENTS):
    doc_id = all_doc_ids[i]
    embedding = all_embeddings[i:i+1]  # FAISS expects a 2D array
    shard_idx = get_shard_index(doc_id, N_SHARDS)
    shards[shard_idx].add(embedding)
    shard_doc_ids[shard_idx].append(doc_id)  # Store mapping from FAISS index to original ID

# Verify shard populations
for i, shard in enumerate(shards):
    print(f"Shard {i} contains {shard.ntotal} embeddings.")
At this point, our NUM_DOCUMENTS embeddings are distributed across N_SHARDS FAISS indexes. Each shard is smaller and can be managed independently.
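Independent management is one of the practical benefits: each shard can be persisted, loaded, or rebuilt on its own. As a minimal sketch (the file paths are illustrative), FAISS indexes can be written to and read back from disk individually:

import os

os.makedirs("shards", exist_ok=True)
for i, shard in enumerate(shards):
    # Each shard is serialized on its own; in a distributed system this
    # would let you place, replicate, or rebuild shards independently.
    faiss.write_index(shard, f"shards/shard_{i}.faiss")

# Reload a single shard without touching the others
restored_shard = faiss.read_index("shards/shard_0.faiss")
print(f"Restored shard 0 with {restored_shard.ntotal} embeddings.")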
When a query arrives, it must be dispatched to all shards, as any shard could potentially contain relevant vectors. This is a common pattern known as "scatter-gather."
# Simulate a query embedding
query_embedding = np.random.rand(1, D_EMBEDDING).astype('float32')
all_shard_distances = []
all_shard_original_ids = []
# 1. Scatter query to all shards & 2. Gather results
for i, shard_index_instance in enumerate(shards):
    # Perform search on the current shard.
    # We ask each shard for its local top K_NEIGHBORS. For exact indexes
    # like IndexFlatL2 this is sufficient to recover the global top K after
    # merging; approximate indexes often over-fetch (see discussion below).
    distances, faiss_indices = shard_index_instance.search(query_embedding, K_NEIGHBORS)

    # 3. Map to Original IDs
    # faiss_indices contains indices relative to THAT shard.
    # -1 indicates fewer than K_NEIGHBORS vectors were found in that shard.
    for j in range(distances.shape[1]):  # Iterate through neighbors found for the query
        if faiss_indices[0, j] != -1:  # If a valid neighbor was found
            original_doc_id = shard_doc_ids[i][faiss_indices[0, j]]
            all_shard_distances.append(distances[0, j])
            all_shard_original_ids.append(original_doc_id)
# 4. Aggregate & Re-rank
if all_shard_distances:
    # Combine distances and original IDs, sorting by ascending L2 distance
    results = sorted(zip(all_shard_distances, all_shard_original_ids))

    # Get the global top-K results
    final_top_k_results = results[:K_NEIGHBORS]

    print(f"\nTop {K_NEIGHBORS} results from sharded index:")
    for dist, doc_id in final_top_k_results:
        print(f"  Doc ID: {doc_id}, Distance: {dist:.4f}")
else:
    print("\nNo results found across any shards.")
This code simulates the fundamental operations: sharding data during ingestion and performing a scatter-gather query followed by result aggregation.
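In this simulation the shards are searched one after another, but the latency benefit of sharding comes from querying them in parallel. As a minimal sketch of that pattern, a thread pool can issue the per-shard searches concurrently; this assumes the shards are read-only during querying, and relies on FAISS typically releasing the GIL during search so threads can overlap work.

from concurrent.futures import ThreadPoolExecutor

def search_shard(args):
    """Searches a single shard and maps hits back to original doc IDs."""
    shard_idx, shard = args
    distances, faiss_indices = shard.search(query_embedding, K_NEIGHBORS)
    hits = []
    for j in range(distances.shape[1]):
        if faiss_indices[0, j] != -1:
            hits.append((distances[0, j], shard_doc_ids[shard_idx][faiss_indices[0, j]]))
    return hits

# Scatter the query to all shards in parallel threads, then gather
with ThreadPoolExecutor(max_workers=N_SHARDS) as executor:
    per_shard_hits = list(executor.map(search_shard, enumerate(shards)))

parallel_results = sorted(hit for hits in per_shard_hits for hit in hits)[:K_NEIGHBORS]
print(f"\nParallel scatter-gather top {K_NEIGHBORS}:")
for dist, doc_id in parallel_results:
    print(f"  Doc ID: {doc_id}, Distance: {dist:.4f}")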
The process described above can be visualized as a pipeline: a query router distributes the search to all shards, and an aggregator combines the partial results to produce the final ranked list.
While sharding improves scalability, it introduces its own set of considerations for an expert practitioner:
Number of Shards (N_SHARDS): Choosing the right number of shards is important. Too few, and you don't get enough parallelism. Too many, and the overhead of managing shards and aggregating results can increase latency, especially if each shard returns very few results for a typical query. The right number usually depends on the total data size, query volume, and hardware resources.

k from each Shard: In the example, we queried for K_NEIGHBORS from each shard. With an exact index such as IndexFlatL2, retrieving each shard's local top K_NEIGHBORS is sufficient: the true global top K is guaranteed to be among the merged candidates. With approximate indexes (IVF-based ones, for example), a shard's local top K can miss true neighbors, so it is common to over-fetch (e.g., K_NEIGHBORS + buffer_size items per shard) before aggregation to improve recall, as sketched below. The exact amount depends on the data distribution and the recall you need.
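A minimal sketch of this over-fetching pattern, assuming an illustrative buffer_size: each shard is asked for more candidates than the final result needs, and the merge step truncates to the global top K.

BUFFER_SIZE = 10  # illustrative over-fetch margin, not a tuned value
k_per_shard = K_NEIGHBORS + BUFFER_SIZE

candidates = []
for i, shard in enumerate(shards):
    # Over-fetch from each shard to compensate for approximate-search misses
    distances, faiss_indices = shard.search(query_embedding, k_per_shard)
    for j in range(distances.shape[1]):
        if faiss_indices[0, j] != -1:
            candidates.append((distances[0, j], shard_doc_ids[i][faiss_indices[0, j]]))

# Merge all candidates and keep only the global top K
global_top_k = sorted(candidates)[:K_NEIGHBORS]
print(f"Global top {K_NEIGHBORS} after fetching {k_per_shard} per shard:")
for dist, doc_id in global_top_k:
    print(f"  Doc ID: {doc_id}, Distance: {dist:.4f}")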
This hands-on practical demonstrated the core principles of sharding a vector index. You've seen how to distribute data across multiple logical (or, in a real system, physical) shards and how to implement a scatter-gather query pattern with result aggregation.
The sharding technique here is a foundation for building high-throughput, low-latency retrieval components in large-scale RAG systems. While we used FAISS for local simulation, these concepts apply directly when working with distributed vector databases like Milvus, Weaviate, Pinecone, or when building custom solutions on top of Kubernetes and frameworks like Jina or Ray. Understanding these mechanics is key for designing and troubleshooting the performance of your distributed retrieval pipelines. As you progress, you'll combine sharding with other techniques like replication, sophisticated indexing structures within shards (e.g., IVFADC in FAISS), and advanced re-ranking to build truly resilient and performant systems.