Implement Prototypical Networks by leveraging the powerful representations learned by large foundation models. This metric-based meta-learning method uses pre-computed embeddings from a pre-trained model as input features to construct class prototypes and classify query instances in a few-shot setting. Such an approach bypasses end-to-end meta-training of the entire network, making it computationally efficient and often surprisingly effective.

## Prerequisites and Setup

Before proceeding, ensure you have the following components ready:

- **Foundation Model Embeddings**: Access to a pre-trained foundation model capable of generating meaningful embeddings for your data modality (e.g., ViT for images, BERT for text). You can either load the model itself or, more efficiently, pre-compute and store embeddings for your few-shot dataset. For this exercise, we assume the embeddings are fixed (i.e., the foundation model's weights are frozen).
- **Few-Shot Dataset**: A dataset structured into N-way K-shot classification tasks. Each task should consist of:
  - A support set $S = \{(x_1, y_1), \dots, (x_{N \times K}, y_{N \times K})\}$, containing $K$ labeled examples for each of the $N$ classes.
  - A query set $Q = \{(x^*_1, y^*_1), \dots, (x^*_{N \times Q'}, y^*_{N \times Q'})\}$, containing $Q'$ new examples for each of the $N$ classes, used for evaluation.
- **Environment**: A Python environment with standard scientific computing libraries (`numpy`, plus `torch` or `tensorflow`). We'll use PyTorch-like syntax for illustration.

## Core Implementation Steps

The core idea is to map both support and query examples into the embedding space provided by the foundation model and then perform nearest-prototype classification within that space.

1. **Embedding Extraction**: Pass the support set examples $x_i \in S$ and query set examples $x^*_j \in Q$ through the (frozen) foundation model $f_\phi(\cdot)$ to obtain their respective embeddings:

   - Support embeddings: $e_i = f_\phi(x_i)$ for all $(x_i, y_i) \in S$.
   - Query embeddings: $e^*_j = f_\phi(x^*_j)$ for all $(x^*_j, y^*_j) \in Q$.

Since the foundation model is large, batch processing and potentially caching these embeddings is important for efficiency, especially during meta-testing or if the same examples appear in multiple tasks.

```python
# Function to get embeddings
import torch

@torch.no_grad()  # Ensure gradients are not computed
def get_embeddings(foundation_model, data_loader, device):
    foundation_model.eval()  # Set model to evaluation mode
    all_embeddings = []
    all_labels = []
    for inputs, labels in data_loader:
        inputs = inputs.to(device)
        # Assuming foundation_model outputs embeddings directly
        # or after a projection layer
        embeddings = foundation_model(inputs)
        all_embeddings.append(embeddings.cpu())
        all_labels.append(labels.cpu())
    return torch.cat(all_embeddings), torch.cat(all_labels)

# Example usage within a task loop:
# support_loader = ...  # DataLoader for support set (N*K samples)
# query_loader = ...    # DataLoader for query set (N*Q' samples)
# support_embeddings, support_labels = get_embeddings(model, support_loader, device)
# query_embeddings, query_labels = get_embeddings(model, query_loader, device)
```
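If the same examples recur across many tasks, caching pays off quickly. The following is a minimal, hypothetical caching wrapper around `get_embeddings`; the cache path and dictionary layout are illustrative choices, not a fixed API:

```python
# Hypothetical caching sketch: pre-compute embeddings once, reuse across tasks.
import os
import torch

def get_or_compute_embeddings(foundation_model, data_loader, device,
                              cache_path="embeddings_cache.pt"):  # illustrative path
    # Reuse cached embeddings if present; otherwise compute and persist them.
    if os.path.exists(cache_path):
        cached = torch.load(cache_path)
        return cached["embeddings"], cached["labels"]
    embeddings, labels = get_embeddings(foundation_model, data_loader, device)
    torch.save({"embeddings": embeddings, "labels": labels}, cache_path)
    return embeddings, labels
```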
2. **Prototype Calculation**: For each class $k$ present in the support set, calculate its prototype $c_k$ by averaging the embeddings of its $K$ support examples:

$$ c_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\phi(x_i) = \frac{1}{K} \sum_{i: y_i = k} e_i $$

This results in $N$ prototype vectors, each representing the central point of a class in the embedding space.

```python
# Calculate prototypes
def calculate_prototypes(support_embeddings, support_labels, classes):
    # classes: a list or tensor containing the unique class labels for the task
    num_classes = len(classes)
    embedding_dim = support_embeddings.size(1)
    prototypes = torch.zeros(num_classes, embedding_dim,
                             device=support_embeddings.device)
    for i, k in enumerate(classes):
        # Select embeddings belonging to class k
        class_mask = (support_labels == k)
        class_embeddings = support_embeddings[class_mask]
        # Calculate the mean embedding
        prototypes[i] = class_embeddings.mean(dim=0)
    return prototypes
```

3. **Query Classification**: Classify each query embedding $e^*_j$ by finding the prototype $c_k$ that is closest according to a chosen distance metric $d(\cdot, \cdot)$. The predicted class $\hat{y}^*_j$ for query example $x^*_j$ is:

$$ \hat{y}^*_j = \underset{k \in \{1, \dots, N\}}{\operatorname{argmin}} \; d(e^*_j, c_k) $$

Common choices for $d$ include the squared Euclidean distance or cosine distance. Squared Euclidean distance is often preferred for its computational simplicity (avoiding square roots) and its equivalence to Euclidean distance for argmin operations.

```python
# Classify query points
def classify_queries(query_embeddings, prototypes):
    # Squared Euclidean distance: ||a - b||^2
    # `torch.cdist` computes pairwise distances between the two sets.
    # query_embeddings shape: [num_query, embedding_dim]
    # prototypes shape:       [num_classes, embedding_dim]
    # dists shape:            [num_query, num_classes]
    dists = torch.cdist(query_embeddings, prototypes) ** 2
    # Find the index (class) with the minimum distance for each query
    # predictions shape: [num_query]
    predictions = torch.argmin(dists, dim=1)
    return predictions  # Indices relative to the 'prototypes' tensor
```

Remember to map the predicted indices back to the original class labels if necessary.

## Foundation Model Embeddings

- **Dimensionality**: Foundation model embeddings are typically high-dimensional (e.g., 768, 1024, or more). While Prototypical Networks handle this naturally, be mindful of the computational cost of distance calculations.
- **Normalization**: It is often beneficial to L2-normalize both the support and query embeddings before calculating prototypes and distances. This makes the classification rely on the angle (cosine similarity) rather than the magnitude, which can be advantageous in high dimensions where magnitudes may vary significantly. If using cosine similarity, normalization is inherent. For Euclidean distance, normalization turns it into a function of cosine similarity: $||u - v||^2_2 = 2 - 2\,u \cdot v$ for normalized vectors $u, v$.
- **Metric Choice**: While Euclidean distance is standard, cosine distance (or similarity) may be more appropriate if the angular separation between class clusters is more informative than their positions in the embedding space, especially after normalization. Experimentation is important.
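As a quick sanity check of the identity above, here is a minimal sketch; the helper name `l2_normalize` and the 768-dimensional random vectors are illustrative, not part of any fixed API:

```python
import torch
import torch.nn.functional as F

def l2_normalize(embeddings, eps=1e-8):
    # Scale each embedding to unit L2 norm; classification then depends
    # only on direction (cosine geometry), not magnitude.
    return F.normalize(embeddings, p=2, dim=1, eps=eps)

# For unit vectors u, v: ||u - v||^2 = 2 - 2 * (u . v), so ranking prototypes
# by squared Euclidean distance matches ranking by cosine similarity.
u = l2_normalize(torch.randn(1, 768))  # 768 dims chosen arbitrarily
v = l2_normalize(torch.randn(1, 768))
sq_euclidean = ((u - v) ** 2).sum()
cosine_sim = (u * v).sum()
assert torch.allclose(sq_euclidean, 2 - 2 * cosine_sim, atol=1e-5)
```

If you adopt normalization, note that the mean of unit vectors is generally not itself unit-length, so whether to re-normalize prototypes after averaging is an additional design choice worth testing.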
## Visualizing the Process

The following diagram illustrates the core workflow for a single few-shot task:

```dot
digraph G {
    rankdir=LR;
    node [shape=box, style=rounded, fontname="sans-serif", fontsize=10];
    edge [fontname="sans-serif", fontsize=9];

    subgraph cluster_support {
        label = "Support Set"; style=dashed; color="#adb5bd";
        s1 [label="Support Ex 1 (Class A)"];
        s2 [label="Support Ex 2 (Class A)"];
        s3 [label="Support Ex K (Class A)"];
        s4 [label="Support Ex 1 (Class B)"];
        s5 [label="Support Ex K (Class B)"];
        sn [label="... (N Classes)"];
    }

    subgraph cluster_query {
        label = "Query Set"; style=dashed; color="#adb5bd";
        q1 [label="Query Ex 1"];
        q2 [label="Query Ex M"];
    }

    subgraph cluster_embed {
        label = "Embedding Space (f_phi)"; style=filled; fillcolor="#e9ecef"; color="#adb5bd";
        node [shape=ellipse, style=filled, fillcolor="#ffffff"];
        se1 [label="e_1"]; se2 [label="e_2"]; se3 [label="e_K (A)"];
        se4 [label="e_1 (B)"]; se5 [label="e_K (B)"]; sen [label="..."];
        qe1 [label="e*_1"]; qe2 [label="e*_M"];
    }

    subgraph cluster_proto {
        label = "Prototype Calculation"; style=filled; fillcolor="#dee2e6"; color="#adb5bd";
        node [shape=diamond, style=filled, fillcolor="#a5d8ff"];
        pa [label="Prototype A\n(Avg of A's e)"];
        pb [label="Prototype B\n(Avg of B's e)"];
        pn [label="..."];
    }

    subgraph cluster_classify {
        label = "Classification"; style=filled; fillcolor="#e9ecef"; color="#adb5bd";
        node [shape=invhouse, style=filled, fillcolor="#ffec99"];
        class1 [label="Classify e*_1\n(min dist to Proto)"];
        class2 [label="Classify e*_M\n(min dist to Proto)"];
    }

    // Edges
    s1 -> se1; s2 -> se2; s3 -> se3; s4 -> se4; s5 -> se5; sn -> sen;
    {s1, s2, s3, s4, s5, sn} -> foundation_model [label="Apply f_phi", style=invis];  // Grouping edge
    q1 -> qe1; q2 -> qe2;
    {q1, q2} -> foundation_model_q [label="Apply f_phi", style=invis];  // Grouping edge

    {se1, se2, se3} -> pa [lhead=cluster_proto, arrowhead=normal, style=dashed, color="#495057"];
    {se4, se5} -> pb [lhead=cluster_proto, arrowhead=normal, style=dashed, color="#495057"];
    sen -> pn [lhead=cluster_proto, arrowhead=normal, style=dashed, color="#495057"];

    qe1 -> class1 [label="Dist(e*_1, Proto A)\nDist(e*_1, Proto B)\n...", color="#1c7ed6"];
    qe2 -> class2 [label="Dist(e*_M, Proto A)\nDist(e*_M, Proto B)\n...", color="#1c7ed6"];

    pa -> class1 [style=dotted, color="#1c7ed6"];
    pb -> class1 [style=dotted, color="#1c7ed6"];
    pn -> class1 [style=dotted, color="#1c7ed6"];
    pa -> class2 [style=dotted, color="#1c7ed6"];
    pb -> class2 [style=dotted, color="#1c7ed6"];
    pn -> class2 [style=dotted, color="#1c7ed6"];

    // Invisible nodes for layout help
    foundation_model [style=invis];
    foundation_model_q [style=invis];

    {rank=same; s1; q1;}
    {rank=same; foundation_model; foundation_model_q;}
    {rank=same; se1; qe1;}
    {rank=same; pa;}
    {rank=same; class1;}
}
```

*Flow diagram illustrating the Prototypical Network process using pre-computed foundation model embeddings for a single N-way K-shot task. Support and query examples are embedded, prototypes are computed from support embeddings, and query embeddings are classified based on proximity to prototypes.*

## Evaluation and Next Steps

Evaluate performance by calculating the accuracy (or other relevant metrics) of the predictions made on the query sets across multiple held-out tasks (meta-testing).
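To make that concrete, here is a sketch of what a meta-testing loop might look like using the functions defined above; `sample_task` is a hypothetical episode sampler standing in for however your benchmark generates tasks:

```python
import torch

def evaluate_meta_test(foundation_model, sample_task, num_tasks, device):
    """Mean query accuracy (with a 95% CI) over held-out N-way K-shot tasks."""
    accuracies = []
    for _ in range(num_tasks):
        # `sample_task()` is a hypothetical helper returning support/query
        # DataLoaders plus the task's class labels; adapt to your episode API.
        support_loader, query_loader, classes = sample_task()
        support_emb, support_labels = get_embeddings(foundation_model, support_loader, device)
        query_emb, query_labels = get_embeddings(foundation_model, query_loader, device)
        prototypes = calculate_prototypes(support_emb, support_labels, classes)
        pred_idx = classify_queries(query_emb, prototypes)
        # Map prototype indices back to the task's original class labels.
        predictions = torch.as_tensor(classes)[pred_idx]
        accuracies.append((predictions == query_labels).float().mean().item())
    acc = torch.tensor(accuracies)
    ci95 = 1.96 * acc.std().item() / (num_tasks ** 0.5)
    return acc.mean().item(), ci95
```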
This practical implementation demonstrates the effectiveness of combining powerful pre-trained representations with the simple classification logic of Prototypical Networks. While we used fixed embeddings here, more advanced approaches might involve meta-learning the embedding function itself or fine-tuning parts of the foundation model during the meta-training phase, topics explored in other parts of this course. Comparing the results obtained with different distance metrics or normalization strategies can provide further insight into the geometry of the foundation model's embedding space for few-shot learning.
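One simple way to run such a comparison, reusing the sketches above (this assumes `calculate_prototypes`, `classify_queries`, and the illustrative `l2_normalize` helper), is to classify the same task under both geometries and measure how often the two predictions agree:

```python
# Compare raw-Euclidean vs cosine (normalized) geometry on one task's embeddings.
def compare_geometries(support_emb, support_labels, query_emb, classes):
    # Raw embeddings: squared-Euclidean nearest prototype.
    proto_raw = calculate_prototypes(support_emb, support_labels, classes)
    pred_raw = classify_queries(query_emb, proto_raw)
    # L2-normalized embeddings: equivalent to ranking by cosine similarity.
    proto_cos = calculate_prototypes(l2_normalize(support_emb), support_labels, classes)
    pred_cos = classify_queries(l2_normalize(query_emb), proto_cos)
    agreement = (pred_raw == pred_cos).float().mean().item()
    return pred_raw, pred_cos, agreement
```

A low agreement rate signals that magnitude information in the embedding space is influencing predictions, which makes the normalization choice worth evaluating carefully on held-out tasks.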