Metric-based meta-learning methods aim to learn an embedding function fϕ that maps inputs to a representation space where classification for new, few-shot tasks can be performed using simple distance metrics. The effectiveness of algorithms like Prototypical Networks hinges entirely on the quality of this embedding space. Deep metric learning provides the tools and objective functions necessary to train fϕ during the meta-training phase to achieve this desired structure.
The core idea is to explicitly optimize the parameters ϕ of the embedding function fϕ so that embeddings of data points from the same class are pulled closer together, while embeddings of data points from different classes are pushed further apart. This contrasts with standard classification losses (like cross-entropy), which focus primarily on separating class boundaries without necessarily enforcing tight intra-class clustering or large inter-class margins in the embedding space itself.
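To make this concrete, a minimal PyTorch sketch of such an embedding function might look like the following; the backbone, feature_dim, and embed_dim names are illustrative placeholders rather than a prescribed architecture.

```python
import torch
import torch.nn as nn

class EmbeddingNet(nn.Module):
    """Embedding function f_phi: maps inputs into the metric space."""
    def __init__(self, backbone: nn.Module, feature_dim: int, embed_dim: int = 128):
        super().__init__()
        self.backbone = backbone                        # e.g. a pretrained feature extractor
        self.head = nn.Linear(feature_dim, embed_dim)   # projection into the embedding space

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.backbone(x)                     # (batch, feature_dim)
        return self.head(features)                      # (batch, embed_dim)
```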
One fundamental approach is using a contrastive loss. This loss operates on pairs of examples (x1,x2). If the pair belongs to the same class (y=1), the loss encourages their embeddings fϕ(x1) and fϕ(x2) to be close. If they belong to different classes (y=0), the loss pushes their embeddings apart, but only if they are closer than a predefined margin m.
The formulation for a pair is typically:
$$
\mathcal{L}_{\text{contrastive}}(x_1, x_2, y) = y \cdot d\big(f_\phi(x_1), f_\phi(x_2)\big)^2 + (1 - y) \cdot \max\big(0,\; m - d(f_\phi(x_1), f_\phi(x_2))\big)^2
$$

Here, d(⋅,⋅) represents a distance function, often the Euclidean distance, and m > 0 is the margin hyperparameter. The margin ensures that negative pairs only contribute to the loss if their distance is smaller than m, preventing the model from expending effort pushing already well-separated pairs further apart.
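A direct translation of this formulation into code could look like the following PyTorch sketch, assuming the pair embeddings have already been produced by fϕ and using the Euclidean distance.

```python
import torch
import torch.nn.functional as F

def contrastive_loss(z1: torch.Tensor, z2: torch.Tensor, y: torch.Tensor,
                     margin: float = 1.0) -> torch.Tensor:
    """Contrastive loss over a batch of embedding pairs.

    z1, z2: embeddings of shape (batch, embed_dim) produced by f_phi.
    y:      1 for same-class pairs, 0 for different-class pairs, shape (batch,).
    """
    y = y.float()
    d = F.pairwise_distance(z1, z2)                        # Euclidean distance per pair
    positive_term = y * d.pow(2)                           # pull same-class pairs together
    negative_term = (1.0 - y) * F.relu(margin - d).pow(2)  # push apart only pairs closer than the margin
    return (positive_term + negative_term).mean()
```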
While conceptually simple, contrastive loss requires careful sampling of pairs. Randomly sampling pairs often leads to many uninformative pairs (e.g., very dissimilar negative pairs or very similar positive pairs), slowing down convergence. Strategies for mining "hard" or "semi-hard" pairs (pairs that are difficult to classify correctly) are often necessary but add complexity.
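In practice, pairs are usually formed within a labeled mini-batch rather than sampled one at a time. The sketch below (the helper name is illustrative) computes the full pairwise distance matrix and a same-class mask that a mining strategy or the loss above can consume.

```python
import torch

def batch_pairs(embeddings: torch.Tensor, labels: torch.Tensor):
    """Return pairwise distances and a same-class mask for all pairs in a mini-batch."""
    dists = torch.cdist(embeddings, embeddings, p=2)          # (batch, batch) Euclidean distances
    same_class = labels.unsqueeze(0).eq(labels.unsqueeze(1))  # True where labels match
    # Exclude trivial self-pairs on the diagonal before mining or computing the loss.
    eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)
    return dists, same_class & ~eye
```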
Triplet loss addresses some limitations of contrastive loss by considering relative distances within a triplet of examples: an anchor (a), a positive example (p) from the same class as the anchor, and a negative example (n) from a different class. The goal is to ensure the anchor is closer to the positive than it is to the negative, again by at least a margin m.
The loss function is defined as:
$$
\mathcal{L}_{\text{triplet}}(a, p, n) = \max\big(0,\; d(f_\phi(a), f_\phi(p))^2 - d(f_\phi(a), f_\phi(n))^2 + m\big)
$$

The loss is zero if the squared distance to the negative, d(fϕ(a), fϕ(n))², already exceeds the squared distance to the positive, d(fϕ(a), fϕ(p))², by at least the margin m. Otherwise, the loss penalizes the model, pushing the negative example further away and/or pulling the positive example closer to the anchor.
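A batched PyTorch sketch of this loss, using squared Euclidean distances as in the formula above, is shown below; PyTorch's built-in torch.nn.TripletMarginLoss implements the same idea with unsquared distances.

```python
import torch
import torch.nn.functional as F

def triplet_loss(za: torch.Tensor, zp: torch.Tensor, zn: torch.Tensor,
                 margin: float = 1.0) -> torch.Tensor:
    """Triplet loss over batches of anchor, positive, and negative embeddings."""
    d_ap = F.pairwise_distance(za, zp).pow(2)    # squared anchor-positive distance
    d_an = F.pairwise_distance(za, zn).pow(2)    # squared anchor-negative distance
    return F.relu(d_ap - d_an + margin).mean()   # zero once the negative is sufficiently far
```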
Figure: a triplet consists of an anchor, a positive example (same class), and a negative example (different class). The loss encourages the anchor-positive distance to be smaller than the anchor-negative distance by at least the margin m; negative examples violating this margin incur a loss.
Similar to contrastive loss, the performance of triplet loss heavily depends on the strategy used for selecting triplets. Random triplets are often too "easy" (the negative is already far away), resulting in zero loss and slow learning. Effective training therefore typically relies on triplet mining, such as hard or semi-hard negative mining. Online mining, which selects triplets within each mini-batch rather than precomputing them over the full dataset, is common for efficiency, as sketched below.
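A minimal sketch of such online mining is the batch-hard variant below: for each anchor in a labeled mini-batch it selects the farthest in-batch positive and the closest in-batch negative before applying the margin. It assumes every class in the batch has at least two examples and, for simplicity, uses unsquared distances.

```python
import torch

def batch_hard_triplet_loss(embeddings: torch.Tensor, labels: torch.Tensor,
                            margin: float = 1.0) -> torch.Tensor:
    """Batch-hard triplet loss: hardest positive and hardest negative per anchor."""
    dists = torch.cdist(embeddings, embeddings, p=2)       # (batch, batch) pairwise distances
    same = labels.unsqueeze(0).eq(labels.unsqueeze(1))     # True where labels match
    eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)

    pos_mask = same & ~eye                                 # valid positives (exclude self-pairs)
    neg_mask = ~same                                       # valid negatives

    # Hardest positive: farthest same-class example; hardest negative: closest other-class example.
    hardest_pos = dists.masked_fill(~pos_mask, 0.0).max(dim=1).values
    hardest_neg = dists.masked_fill(~neg_mask, float("inf")).min(dim=1).values

    # Assumes each class appears at least twice in the batch so every anchor has a positive.
    return torch.relu(hardest_pos - hardest_neg + margin).mean()
```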
Beyond these two, other losses exist that are designed to utilize more information than simple pairs or triplets, for example the N-pair loss, which contrasts an anchor against several negatives simultaneously, and proxy-based losses, which compare examples to learned class representatives rather than to other samples in the batch.
In the context of meta-learning for foundation models, these deep metric learning losses are typically employed during the meta-training stage. The foundation model (or a part of it) acts as the backbone for the embedding function fϕ. The meta-training dataset, composed of various tasks, is used to optimize ϕ via one of these losses. The objective is to pre-train fϕ such that it produces embeddings that are inherently well-structured for downstream few-shot classification. When presented with a new task (support set) during meta-testing, the embeddings fϕ(x) computed for the support examples can be used directly (e.g., to calculate prototypes) and compared against query embeddings using simple distance calculations.
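As a sketch of this meta-testing step (the function and argument names here are illustrative assumptions, not a fixed API), class prototypes can be computed from the support embeddings and each query assigned to its nearest prototype:

```python
import torch

def classify_by_prototypes(support_emb: torch.Tensor, support_labels: torch.Tensor,
                           query_emb: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Nearest-prototype classification in the learned embedding space."""
    # Prototype = mean embedding of each class's support examples (assumes every class is present).
    prototypes = torch.stack([
        support_emb[support_labels == c].mean(dim=0) for c in range(n_classes)
    ])                                                  # (n_classes, embed_dim)
    # Assign each query to the class of its closest prototype (Euclidean distance).
    dists = torch.cdist(query_emb, prototypes, p=2)     # (n_queries, n_classes)
    return dists.argmin(dim=1)                          # predicted class indices
```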
Adapting these techniques to the high-dimensional embeddings produced by foundation models requires careful consideration, for example of whether to L2-normalize the embeddings, which distance metric to use (Euclidean versus cosine), how to set the margin at that scale, and the memory and compute cost of mining informative pairs or triplets over large batches.
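For instance, a common though not universal practice (offered here as an assumption rather than a requirement of any particular method) is to L2-normalize the embeddings so that distances depend only on direction and margins stay on a comparable scale:

```python
import torch
import torch.nn.functional as F

def l2_normalize(z: torch.Tensor) -> torch.Tensor:
    """Project embeddings onto the unit sphere so distances depend only on direction."""
    return F.normalize(z, p=2, dim=-1)

# With unit-norm embeddings, squared Euclidean distance and cosine similarity are
# directly related: ||u - v||^2 = 2 - 2 * cos(u, v).
```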
By optimizing the embedding function fϕ with objectives like contrastive or triplet loss during meta-training, metric-based meta-learning methods prepare foundation models to produce representations where new classes can be effectively discriminated based on distances, even with very few examples per class.