This section provides practical guidance on implementing Prototypical Networks, specifically leveraging the powerful representations learned by large foundation models. Building on the theoretical discussion of metric-based meta-learning, we'll focus on using pre-computed embeddings from a foundation model as the input features for constructing prototypes and classifying query instances in a few-shot setting. This approach bypasses the need for end-to-end meta-training of the entire network, making it computationally efficient and often surprisingly effective.
Before proceeding, ensure you have the required libraries installed: `numpy` and a deep learning framework (`torch` or `tensorflow`). We'll use PyTorch-like syntax for illustration.

The core idea is to map both support and query examples into the embedding space provided by the foundation model and then perform nearest-prototype classification within that space.
Embedding Extraction: Pass the support set examples $x_i \in S$ and query set examples $x_j^* \in Q$ through the (frozen) foundation model $f_\phi(\cdot)$ to obtain their respective embeddings:

$$e_i = f_\phi(x_i), \qquad e_j^* = f_\phi(x_j^*)$$
Because the foundation model is large, batch processing and, where possible, caching these embeddings are important for efficiency, especially during meta-testing or when the same examples appear in multiple tasks.
```python
# Conceptual: Function to get embeddings
import torch

@torch.no_grad()  # Ensure gradients are not computed
def get_embeddings(foundation_model, data_loader, device):
    foundation_model.eval()  # Set model to evaluation mode
    all_embeddings = []
    all_labels = []
    for inputs, labels in data_loader:
        inputs = inputs.to(device)
        # Assuming foundation_model outputs embeddings directly or after a projection layer
        embeddings = foundation_model(inputs)
        all_embeddings.append(embeddings.cpu())
        all_labels.append(labels.cpu())
    return torch.cat(all_embeddings), torch.cat(all_labels)

# Example usage within a task loop:
# support_loader = ...  # DataLoader for support set (N*K samples)
# query_loader = ...    # DataLoader for query set (N*Q' samples)
# support_embeddings, support_labels = get_embeddings(model, support_loader, device)
# query_embeddings, query_labels = get_embeddings(model, query_loader, device)
```
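The caching mentioned above can be as simple as a dictionary keyed by a stable example identifier. The sketch below is one possible approach, not a library API: it assumes every example has a hashable ID, and `EmbeddingCache` is a hypothetical helper. A small `nn.Linear` stands in for the frozen foundation model so the example runs on its own.

```python
import torch
import torch.nn as nn


class EmbeddingCache:
    """Hypothetical helper that memoizes a frozen model's embeddings by example ID."""

    def __init__(self, model, device="cpu"):
        self.model = model.to(device).eval()  # frozen: inference mode only
        self.device = device
        self.store = {}  # example_id -> embedding tensor kept on the CPU

    @torch.no_grad()
    def embed(self, example_ids, inputs):
        """Return embeddings for `inputs` (row i has ID example_ids[i]), reusing cached ones."""
        # Run the model only on rows whose IDs have not been embedded yet.
        missing = [i for i, eid in enumerate(example_ids) if eid not in self.store]
        if missing:
            new_embs = self.model(inputs[missing].to(self.device)).cpu()
            for row, i in enumerate(missing):
                self.store[example_ids[i]] = new_embs[row]
        # Assemble the result in the requested order from the cache.
        return torch.stack([self.store[eid] for eid in example_ids])


torch.manual_seed(0)
toy_model = nn.Linear(16, 8)  # stand-in for the foundation model
cache = EmbeddingCache(toy_model)
x = torch.randn(4, 16)
emb_a = cache.embed(["a", "b", "c", "d"], x)  # embeds all four rows in one batch
emb_b = cache.embed(["c", "a"], x[[2, 0]])    # served entirely from the cache
```

In a real pipeline the IDs might be dataset indices or file paths; the key requirement is that the same example always maps to the same ID.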
Prototype Calculation: For each class $k$ present in the support set, calculate its prototype $c_k$ by averaging the embeddings of its $K$ support examples, where $S_k \subseteq S$ contains the support examples labeled $k$:

$$c_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\phi(x_i) = \frac{1}{K} \sum_{i : y_i = k} e_i$$

This results in $N$ prototype vectors, each representing the central point of a class in the embedding space.
```python
# Conceptual: Calculate prototypes
import torch

def calculate_prototypes(support_embeddings, support_labels, classes):
    # classes: A list or tensor containing the unique class labels for the task
    num_classes = len(classes)
    embedding_dim = support_embeddings.size(1)
    prototypes = torch.zeros(num_classes, embedding_dim, device=support_embeddings.device)
    for i, k in enumerate(classes):
        # Select embeddings belonging to class k
        class_mask = (support_labels == k)
        class_embeddings = support_embeddings[class_mask]
        # Calculate the mean embedding
        prototypes[i] = class_embeddings.mean(dim=0)
    return prototypes
```
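The loop in `calculate_prototypes` is easy to read, but the same class means can be computed with a single matrix product over a one-hot membership matrix. A minimal sketch, assuming `classes` is a 1-D tensor of the task's labels (`calculate_prototypes_vectorized` is a hypothetical name):

```python
import torch


def calculate_prototypes_vectorized(support_embeddings, support_labels, classes):
    """Mean support embedding per class, without a Python loop (hypothetical helper).

    classes: 1-D tensor of unique labels, in the order the prototypes are returned.
    """
    # one_hot[i, n] is 1.0 when support example n belongs to classes[i]; shape [N, num_support].
    one_hot = (support_labels.unsqueeze(0) == classes.unsqueeze(1)).to(support_embeddings.dtype)
    counts = one_hot.sum(dim=1, keepdim=True)  # |S_k| for each class, shape [N, 1]
    # Sum the embeddings of each class, then divide by the class size.
    return (one_hot @ support_embeddings) / counts


# A 2-way 2-shot toy support set with 2-D embeddings and non-contiguous labels.
support_embeddings = torch.tensor([[0.0, 0.0], [2.0, 0.0], [0.0, 4.0], [0.0, 6.0]])
support_labels = torch.tensor([7, 7, 3, 3])
classes = torch.tensor([7, 3])
prototypes = calculate_prototypes_vectorized(support_embeddings, support_labels, classes)
# prototypes holds [[1., 0.], [0., 5.]]
```

Every label in `classes` must appear in the support set; otherwise its count is zero and that prototype becomes NaN.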
Query Classification: Classify each query embedding $e_j^*$ by finding the prototype $c_k$ that is closest according to a chosen distance metric $d(\cdot, \cdot)$. The predicted class $\hat{y}_j^*$ for query example $x_j^*$ is:

$$\hat{y}_j^* = \arg\min_{k \in \{1, \dots, N\}} d(e_j^*, c_k)$$

Common choices for $d$ include the squared Euclidean distance and the cosine distance. Squared Euclidean distance is often preferred because it is cheap to compute (no square root) and, since squaring is monotonic on non-negative values, it selects the same prototype as the Euclidean distance.
```python
# Conceptual: Classify query points
import torch

def classify_queries(query_embeddings, prototypes):
    # Using squared Euclidean distance: ||a - b||^2
    # `torch.cdist` computes pairwise Euclidean distances between the rows of two 2-D tensors.
    # query_embeddings shape: [num_query, embedding_dim]
    # prototypes shape: [num_classes, embedding_dim]
    # dists shape: [num_query, num_classes]
    dists = torch.cdist(query_embeddings, prototypes) ** 2
    # Find the index (class) with the minimum distance for each query
    # predictions shape: [num_query]
    predictions = torch.argmin(dists, dim=1)
    return predictions  # These are indices relative to the 'prototypes' tensor
```
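To compare the metrics discussed above, the sketch below adds a cosine-distance variant of `classify_queries` (the name `classify_queries_cosine` is hypothetical), checks on random data that Euclidean and squared Euclidean distances pick the same prototype, and shows a small hand-picked case where Euclidean and cosine distance disagree.

```python
import torch
import torch.nn.functional as F


def classify_queries_cosine(query_embeddings, prototypes):
    """Nearest prototype under cosine distance d(a, b) = 1 - cos(a, b)."""
    q = F.normalize(query_embeddings, dim=1)  # unit-length queries
    p = F.normalize(prototypes, dim=1)        # unit-length prototypes
    cosine_dist = 1.0 - q @ p.T               # shape [num_query, num_classes]
    return torch.argmin(cosine_dist, dim=1)


# Euclidean vs. squared Euclidean on random data: the argmin does not change.
torch.manual_seed(0)
queries = torch.randn(10, 5)
protos = torch.randn(3, 5)
euclid = torch.cdist(queries, protos)  # ||a - b||
sq_euclid = euclid ** 2                # ||a - b||^2
same = torch.equal(euclid.argmin(dim=1), sq_euclid.argmin(dim=1))
cos_preds = classify_queries_cosine(queries, protos)

# A hand-picked case where the metrics disagree: cosine ignores vector length.
q2 = torch.tensor([[1.0, 0.0]])
p2 = torch.tensor([[10.0, 0.0], [0.0, 0.1]])
euclid_pred = torch.cdist(q2, p2).argmin(dim=1)  # tensor([1]): the short prototype is nearer
cosine_pred = classify_queries_cosine(q2, p2)    # tensor([0]): same direction as prototype 0
```

For unit-length vectors, $\|a - b\|^2 = 2(1 - \cos(a, b))$, so cosine distance is equivalent to squared Euclidean distance once both queries and prototypes are L2-normalized; which choice works better depends on the geometry of the foundation model's embedding space.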
Remember to map the predicted indices back to the original class labels if necessary.
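Because `classify_queries` returns positions in the `prototypes` tensor, indexing the same `classes` tensor that fixed the prototype order recovers the original labels. A minimal sketch with made-up labels, which also computes the query accuracy for one task:

```python
import torch

# Made-up task whose original labels are not 0..N-1.
classes = torch.tensor([12, 5, 40])             # prototype i corresponds to classes[i]
predicted_indices = torch.tensor([2, 0, 0, 1])  # e.g. the output of classify_queries
predicted_labels = classes[predicted_indices]   # tensor([40, 12, 12, 5])

query_labels = torch.tensor([40, 12, 5, 5])
accuracy = (predicted_labels == query_labels).float().mean().item()  # 0.75
```

If `classes` is a Python list, convert it with `torch.tensor(classes)` first so it can be indexed by a tensor.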
The following diagram illustrates the core workflow for a single few-shot task:
Flow diagram illustrating the Prototypical Network process using pre-computed foundation model embeddings for a single N-way K-shot task. Support and query examples are embedded, prototypes are computed from support embeddings, and query embeddings are classified based on proximity to prototypes.
Evaluate the performance by calculating the accuracy (or other relevant metrics) of the predictions made on the query sets across multiple held-out tasks (meta-testing).
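Meta-testing repeats the whole pipeline over many randomly sampled tasks. The sketch below assumes the embeddings for a held-out pool of classes have already been extracted into one tensor; `sample_episode` and `evaluate` are hypothetical helpers, and the synthetic Gaussian clusters are placeholders for real foundation-model embeddings, not results. Reporting the mean accuracy with a 95% confidence interval over tasks is a common convention in few-shot work.

```python
import torch


def sample_episode(embeddings, labels, n_way, k_shot, n_query, generator):
    """Draw one N-way K-shot task from a pool of pre-computed embeddings."""
    unique = torch.unique(labels)
    chosen = unique[torch.randperm(len(unique), generator=generator)[:n_way]]
    s_idx, q_idx = [], []
    for c in chosen:
        idx = torch.nonzero(labels == c).squeeze(1)
        idx = idx[torch.randperm(len(idx), generator=generator)]  # shuffle within class
        s_idx.append(idx[:k_shot])                  # K support examples
        q_idx.append(idx[k_shot:k_shot + n_query])  # disjoint query examples
    s_idx, q_idx = torch.cat(s_idx), torch.cat(q_idx)
    return embeddings[s_idx], labels[s_idx], embeddings[q_idx], labels[q_idx], chosen


def evaluate(embeddings, labels, n_way=5, k_shot=1, n_query=15, n_tasks=200, seed=0):
    """Mean query accuracy over random tasks, plus a 95% confidence half-width."""
    g = torch.Generator().manual_seed(seed)
    accs = []
    for _ in range(n_tasks):
        s_emb, s_lab, q_emb, q_lab, classes = sample_episode(
            embeddings, labels, n_way, k_shot, n_query, g)
        # Prototypes: mean support embedding per class.
        protos = torch.stack([s_emb[s_lab == c].mean(dim=0) for c in classes])
        # Nearest prototype under squared Euclidean distance, mapped back to labels.
        preds = classes[torch.cdist(q_emb, protos).pow(2).argmin(dim=1)]
        accs.append((preds == q_lab).float().mean())
    accs = torch.stack(accs)
    half_width = 1.96 * accs.std() / n_tasks ** 0.5
    return accs.mean().item(), half_width.item()


# Synthetic placeholder "embeddings": 20 classes, 30 points each, around random centers.
torch.manual_seed(0)
centers = torch.randn(20, 32) * 3
labels = torch.arange(20).repeat_interleave(30)
embeddings = centers[labels] + torch.randn(600, 32)
mean_acc, ci = evaluate(embeddings, labels)
print(f"5-way 1-shot accuracy: {mean_acc:.3f} ± {ci:.3f}")
```

Fixing the generator seed makes the sampled tasks reproducible, so different metrics or normalizations can be compared on exactly the same episodes.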
This practical implementation shows how powerful pre-trained representations can be combined with the simple logic of Prototypical Networks. While we used fixed embeddings here, more advanced approaches might involve meta-learning the embedding function itself or fine-tuning parts of the foundation model during the meta-training phase, topics explored in other parts of this course. Comparing the results obtained with different distance metrics or normalization strategies can provide further insights into the geometry of the foundation model's embedding space for few-shot learning.
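One normalization strategy you might compare is centering the embeddings with the support-set mean and then L2-normalizing them. The sketch below is an illustration under assumptions, not a prescribed method: `center_and_l2_normalize` and `predict` are hypothetical helpers, and the random tensors are placeholders for real embeddings, so the agreement rate it computes carries no meaning on its own.

```python
import torch
import torch.nn.functional as F


def center_and_l2_normalize(support_emb, query_emb):
    """Subtract the support-set mean from both sets, then scale each row to unit L2 norm."""
    mean = support_emb.mean(dim=0, keepdim=True)
    return F.normalize(support_emb - mean, dim=1), F.normalize(query_emb - mean, dim=1)


def predict(support_emb, support_lab, query_emb, classes):
    """Nearest prototype under squared Euclidean distance, returned as original labels."""
    protos = torch.stack([support_emb[support_lab == c].mean(dim=0) for c in classes])
    return classes[torch.cdist(query_emb, protos).pow(2).argmin(dim=1)]


# Placeholder 3-way 5-shot task; replace with real foundation-model embeddings.
torch.manual_seed(0)
classes = torch.tensor([0, 1, 2])
support_lab = classes.repeat_interleave(5)
support_emb = torch.randn(15, 16)
query_emb = torch.randn(30, 16)

raw_preds = predict(support_emb, support_lab, query_emb, classes)
s_norm, q_norm = center_and_l2_normalize(support_emb, query_emb)
norm_preds = predict(s_norm, support_lab, q_norm, classes)
# Fraction of queries whose prediction is unchanged by the normalization.
agreement = (raw_preds == norm_preds).float().mean().item()
```

Centering with the support mean uses only information available within the task; if you have embeddings of the meta-training classes, a mean computed on that base data is another option to try.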
© 2025 ApX Machine Learning