With our data prepared, chunked, vectorized, and stored in a vector database (as discussed in Chapter 3), we can now implement the first active component of our RAG pipeline: the retriever. The retriever's responsibility is to take a user's query, understand its semantic meaning, and fetch the most relevant pieces of information from our knowledge base.
This process hinges on the vector embeddings we generated earlier. We'll convert the user's query into a vector using the same embedding model used for our documents. Then, we'll use the vector database to find the document chunk vectors that are "closest" to the query vector in the high-dimensional embedding space, typically using a similarity metric like cosine similarity.
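To make the idea of "closest" concrete, here is a small illustrative sketch (using NumPy and toy 4-dimensional vectors rather than real embeddings, which typically have hundreds of dimensions) of how cosine similarity scores a query vector against a couple of document vectors:

import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    # Dot product of the vectors divided by the product of their magnitudes
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Toy vectors for illustration only
query_vec = np.array([0.1, 0.8, 0.3, 0.0])
doc_vecs = {
    "doc_a": np.array([0.2, 0.7, 0.4, 0.1]),  # points in a similar direction to the query
    "doc_b": np.array([0.9, 0.1, 0.0, 0.3]),  # points in a different direction
}

for name, vec in doc_vecs.items():
    print(name, round(cosine_similarity(query_vec, vec), 3))  # doc_a scores higher

The vector database performs essentially this comparison, but at scale and with efficient indexing, so we never compute the scores by hand.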
First, ensure you have the necessary libraries installed. For this example, we'll use sentence-transformers for generating embeddings and chromadb as our local vector store client. If you haven't already installed them:
pip install sentence-transformers chromadb
We'll assume you have already populated a ChromaDB collection named "documents" with your chunked data and embeddings during the data preparation phase (Chapter 3). You will also need the same embedding model that was used to create those stored embeddings.
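If you are starting fresh and don't have that collection yet, a minimal indexing sketch along the following lines would produce a compatible one; the chunk texts, IDs, and metadata below are purely illustrative stand-ins for the output of your Chapter 3 pipeline.

from sentence_transformers import SentenceTransformer
import chromadb

model = SentenceTransformer('all-MiniLM-L6-v2')
client = chromadb.PersistentClient(path="./chroma_db")
collection = client.get_or_create_collection("documents")

# Illustrative chunks; in practice these come from your chunking step
chunks = [
    "RAG combines retrieval with generation to ground LLM answers in external data.",
    "Chunking splits documents into smaller passages before embedding them.",
]
embeddings = model.encode(chunks).tolist()

collection.add(
    ids=[f"chunk-{i}" for i in range(len(chunks))],
    documents=chunks,
    embeddings=embeddings,
    metadatas=[{"source": "example.txt"} for _ in chunks],
)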
Let's start by importing the required libraries and initializing the embedding model and the ChromaDB client. We need to load the exact same embedding model used during the indexing phase to ensure consistency between query embeddings and document embeddings.
from sentence_transformers import SentenceTransformer
import chromadb

# Load the embedding model (use the same one as used for indexing)
# Example: 'all-MiniLM-L6-v2' is a popular choice for its balance of speed and performance
try:
    embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
    print("Embedding model loaded successfully.")
except Exception as e:
    print(f"Error loading sentence transformer model: {e}")
    # Handle the error appropriately, e.g. exit or fall back to another model
    embedding_model = None

# Initialize the ChromaDB client
# Using PersistentClient to save data to disk
client = chromadb.PersistentClient(path="./chroma_db")  # Specify a directory

# Get the existing collection (assuming it was created in Chapter 3)
try:
    # Ensure the collection exists before trying to get it
    collections = client.list_collections()
    if "documents" in [col.name for col in collections]:
        collection = client.get_collection("documents")
        print(f"Connected to existing collection: 'documents' with {collection.count()} items.")
    else:
        print("Error: Collection 'documents' not found. Please ensure data is indexed.")
        collection = None  # Set collection to None to handle this case later
except Exception as e:
    print(f"Error connecting to ChromaDB collection: {e}")
    collection = None  # Ensure collection is None if connection fails
Here, we initialize SentenceTransformer with a model name. Using chromadb.PersistentClient allows the database state to persist across runs. We then attempt to connect to our pre-existing "documents" collection. It's important to include error handling, as the model might fail to download or the database connection might encounter issues.
Now, let's define a function that encapsulates the retrieval logic. This function will take the user query and the number of documents to retrieve (top_k) as input.
The retrieval process: The user query is converted into a vector, which is then used to search the vector database for the most similar document chunk vectors.
def retrieve_relevant_chunks(query: str, top_k: int = 5):
    """
    Retrieves the top_k most relevant document chunks for a given query.

    Args:
        query (str): The user's query string.
        top_k (int): The number of relevant chunks to retrieve.

    Returns:
        list: A list of retrieved document chunks (or None if an error occurs).
              Each item in the list could be the text content, or a dict
              containing text and metadata. Returns empty list if no results.
    """
    if collection is None:
        print("Error: Vector database collection is not available.")
        return None
    if embedding_model is None:
        print("Error: Embedding model is not loaded.")
        return None

    try:
        # 1. Generate embedding for the input query
        query_embedding = embedding_model.encode(query).tolist()

        # 2. Query the vector database
        # We ask for 'documents' and potentially 'metadatas' and 'distances'
        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
            include=['documents', 'metadatas', 'distances']
        )

        # 3. Process and return the results
        # The 'documents' field contains the text of the retrieved chunks.
        # Results format depends slightly on the client library version,
        # but generally provides lists within a dictionary.
        retrieved_docs = results.get('documents', [[]])[0]
        retrieved_metadatas = results.get('metadatas', [[]])[0]
        retrieved_distances = results.get('distances', [[]])[0]

        # Combine the results for easier use later
        processed_results = []
        for i, doc in enumerate(retrieved_docs):
            processed_results.append({
                "text": doc,
                "metadata": retrieved_metadatas[i] if retrieved_metadatas else {},
                "distance": retrieved_distances[i] if retrieved_distances else None
            })

        print(f"Retrieved {len(processed_results)} chunks.")
        return processed_results

    except Exception as e:
        print(f"Error during retrieval: {e}")
        return None
# Example Usage (assuming collection and model are loaded)
if collection and embedding_model:
    user_query = "What are the benefits of using RAG?"
    retrieved_information = retrieve_relevant_chunks(user_query, top_k=3)

    if retrieved_information:
        print(f"\nTop {len(retrieved_information)} results for query: '{user_query}'")
        for i, info in enumerate(retrieved_information):
            print(f"Rank {i+1} (Distance: {info['distance']:.4f}):")
            # Limit printing length for brevity
            print(f"  Text: {info['text'][:250]}...")
            print(f"  Metadata: {info['metadata']}")
            print("-" * 20)
    else:
        print("Retrieval failed or returned no results.")
In this retrieve_relevant_chunks function:

- We first generate an embedding for the query using embedding_model.encode(). It's essential that this is the same model used to embed the documents stored in ChromaDB. We convert it to a list (.tolist()) as some clients expect that format.
- We call the collection.query() method. We pass the query_embeddings (note: it expects a list of embeddings, even if there's only one query). n_results=top_k specifies how many results we want. include=['documents', 'metadatas', 'distances'] tells ChromaDB to return the original text content associated with the vectors, any metadata stored alongside them (like source document name or page number), and the distance scores (lower usually means more similar for metrics like Euclidean distance, higher for cosine similarity, though Chroma often returns a cosine distance where lower is better).
- We process the raw results into a list of dictionaries, guarding against fields that come back as None or empty lists.
- A try...except block is included to catch potential issues during embedding generation or database querying.

The example usage demonstrates how to call this function and print the retrieved text snippets and their metadata. The distance score gives an indication of how relevant the database considered each chunk to be relative to the query.
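If you prefer to reason in terms of similarity rather than distance, and assuming the collection was created with ChromaDB's cosine space (an assumption, since the metric depends on how the collection was configured during indexing), the returned distance can be converted back into a similarity score:

# Assumes the collection uses cosine distance, e.g. created with
# metadata={"hnsw:space": "cosine"}; in that case distance = 1 - cosine similarity.
def distance_to_similarity(cosine_distance: float) -> float:
    return 1.0 - cosine_distance

results = retrieve_relevant_chunks("What are the benefits of using RAG?", top_k=3)
if results:
    for item in results:
        if item["distance"] is not None:
            print(f"similarity={distance_to_similarity(item['distance']):.3f} | {item['text'][:60]}...")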
This implemented retriever now serves as a bridge between the user's query and the stored knowledge base. The list of relevant text chunks it returns is the crucial context that will be "augmented" onto the original query before being sent to the Large Language Model in the next step.
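As a small preview of that step, here is a minimal sketch of how the retrieved chunks might be stitched into an augmented prompt; the build_augmented_prompt helper and its template are illustrative, not part of any library.

def build_augmented_prompt(query: str, retrieved_chunks: list) -> str:
    # Join the retrieved chunk texts into a single context block
    context = "\n\n".join(chunk["text"] for chunk in retrieved_chunks)
    # A simple illustrative template; prompt construction is refined in the next step
    return (
        "Answer the question using only the context below.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {query}\n"
        "Answer:"
    )

chunks = retrieve_relevant_chunks("What are the benefits of using RAG?", top_k=3)
if chunks:
    print(build_augmented_prompt("What are the benefits of using RAG?", chunks)[:500])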