While Prototypical Networks compute class prototypes and use a fixed distance metric in the embedding space, Relation Networks (RNs) introduce a different approach. Proposed by Sung et al. (2018), RNs learn a deep non-linear metric function directly, aiming to determine the similarity or "relation" between query examples and support examples. Instead of assuming a simple metric like Euclidean distance is sufficient, RNs posit that a dedicated neural network can better capture the complex relationships needed for few-shot classification, especially when dealing with intricate visual or semantic features.
The core idea is to train a network that explicitly outputs a scalar similarity score between pairs of feature representations. A typical Relation Network consists of two main components:
Embedding Module ($f_\phi$): This module, often a convolutional neural network (CNN) for vision tasks or a transformer encoder for sequence data, maps raw inputs (both support set examples $x_s$ and query examples $x_q$) into feature embeddings.
$$\text{embedding}_s = f_\phi(x_s), \qquad \text{embedding}_q = f_\phi(x_q)$$
When adapting foundation models, $f_\phi$ might be the pre-trained model itself (potentially frozen) or a part of it.
Relation Module ($g_\psi$): This module takes pairs of embeddings (typically a query embedding combined with support embeddings) and outputs a scalar relation score between 0 and 1, indicating the degree of similarity relevant to the classification task. A common practice for a C-way, K-shot task is to first aggregate the support embeddings for each class $c_i$ (e.g., by summing or averaging) to get a class representation $s_i$:
$$s_i = \frac{1}{K}\sum_{k=1}^{K} f_\phi\!\left(x_s^{(i,k)}\right)$$
Then, the relation module processes the concatenation of the query embedding $q = f_\phi(x_q)$ and each class representation $s_i$:
$$r_i = g_\psi\!\left(\text{Combine}(s_i, q)\right)$$
The Combine function is often simple concatenation. The relation module $g_\psi$ itself is usually a smaller neural network, like a few convolutional layers followed by fully connected layers, or just a multi-layer perceptron (MLP). It's designed to learn a task-specific similarity function.
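To make these two modules concrete, the following PyTorch sketch assembles an embedding module, class-representation averaging, and an MLP relation module for a single episode. The names `RelationModule` and `relation_scores` and the MLP layout are illustrative assumptions, not the exact architecture of Sung et al. (2018), which used small convolutional blocks for $g_\psi$ on image features.

```python
import torch
import torch.nn as nn

class RelationModule(nn.Module):
    """g_psi: maps a combined (class representation, query) pair to a score in [0, 1]."""
    def __init__(self, embed_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * embed_dim, hidden_dim),  # Combine = concatenation of s_i and q
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),                          # relation score r_i in [0, 1]
        )

    def forward(self, class_repr: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
        combined = torch.cat([class_repr, query], dim=-1)
        return self.net(combined).squeeze(-1)

def relation_scores(f_phi: nn.Module, g_psi: RelationModule,
                    support: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
    """Relation scores r_i for every (query, class) pair in a C-way, K-shot episode.
    support: [C, K, ...] raw support inputs; query: [Q, ...] raw query inputs."""
    C, K = support.shape[:2]
    s = f_phi(support.flatten(0, 1)).view(C, K, -1).mean(dim=1)  # class representations s_i: [C, D]
    q = f_phi(query)                                             # query embeddings: [Q, D]
    s_exp = s.unsqueeze(0).expand(q.size(0), -1, -1)             # [Q, C, D]
    q_exp = q.unsqueeze(1).expand(-1, C, -1)                     # [Q, C, D]
    return g_psi(s_exp, q_exp)                                   # [Q, C] scores
```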
Flow of information in a Relation Network for one class $c_i$: support and query images are embedded, support embeddings are aggregated, combined with the query embedding, and fed into the relation module to produce a similarity score. This process is repeated for all classes in the support set.
RNs are trained episodically, similar to other meta-learning methods. In each episode, a task (e.g., a C-way, K-shot classification problem) is sampled. The network processes the support set and query examples for that task. The objective function typically aims to push the relation score $r_i$ towards 1 if the query example $x_q$ belongs to class $c_i$, and towards 0 otherwise. A common loss function is Mean Squared Error (MSE):
$$\mathcal{L} = \sum_{\text{tasks } T} \;\sum_{(x_q, y_q) \in \text{Queries}(T)} \;\sum_{i=1}^{C} \left(r_i - \mathbf{1}[y_q = c_i]\right)^2$$
Here, $r_i$ is the predicted relation score for class $c_i$ given the query $x_q$, and $\mathbf{1}[y_q = c_i]$ is an indicator function that is 1 if the true label $y_q$ matches class $c_i$, and 0 otherwise. This loss is backpropagated through both the relation module $g_\psi$ and the embedding module $f_\phi$, allowing both the feature representation and the similarity function to be learned jointly.
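Written per episode, this loss compares the `[Q, C]` matrix of relation scores against one-hot targets that implement the indicator $\mathbf{1}[y_q = c_i]$. The snippet below is a minimal sketch assuming the `relation_scores` helper from the earlier example; in practice the outer sum over tasks is realized by iterating over sampled episodes.

```python
import torch
import torch.nn.functional as F

def episode_loss(scores: torch.Tensor, query_labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """MSE between relation scores and one-hot targets for a single episode.
    scores: [Q, C] relation scores r_i; query_labels: [Q] integer class indices."""
    targets = F.one_hot(query_labels, num_classes).float()  # 1 where y_q == c_i, else 0
    return F.mse_loss(scores, targets)

# One meta-training step (the optimizer covers parameters of both f_phi and g_psi):
# scores = relation_scores(f_phi, g_psi, support, query)
# loss = episode_loss(scores, query_labels, num_classes=C)
# loss.backward(); optimizer.step(); optimizer.zero_grad()
```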
Relation Networks differ significantly from Prototypical Networks: rather than comparing query embeddings to class prototypes with a fixed distance such as Euclidean distance, they learn the comparison function $g_\psi$ itself, so the notion of similarity is shaped by the training tasks instead of being assumed in advance.
When applying Relation Networks in the context of large foundation models, the embedding module $f_\phi$ is naturally replaced by the foundation model.
The high dimensionality of embeddings from foundation models (e.g., 768, 1024, or more dimensions) needs consideration when designing the relation module $g_\psi$. A simple MLP might struggle with the curse of dimensionality or become computationally burdensome. Techniques like dimensionality reduction (if feasible without losing critical information) or carefully structured relation modules (e.g., using attention mechanisms or factorization) might be necessary.
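One way to handle this, sketched below under the assumption of a frozen foundation-model encoder, is to insert a learned linear projection in front of the relation MLP so that pairwise comparisons happen in a lower-dimensional space. The `encoder` placeholder and the specific dimensions are hypothetical.

```python
import torch
import torch.nn as nn

class ProjectedRelationHead(nn.Module):
    """Projects high-dimensional embeddings down, then scores (class, query) pairs."""
    def __init__(self, encoder_dim: int = 1024, proj_dim: int = 128, hidden_dim: int = 256):
        super().__init__()
        self.proj = nn.Linear(encoder_dim, proj_dim)   # learned dimensionality reduction
        self.relation = nn.Sequential(
            nn.Linear(2 * proj_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, class_repr: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
        s, q = self.proj(class_repr), self.proj(query)
        return self.relation(torch.cat([s, q], dim=-1)).squeeze(-1)

# Usage sketch with a frozen encoder playing the role of f_phi (hypothetical `encoder`):
# with torch.no_grad():
#     support_emb = encoder(support_inputs)   # [C*K, encoder_dim]
#     query_emb = encoder(query_inputs)       # [Q, encoder_dim]
# Only the projection and the relation MLP receive gradients during meta-training.
```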
Strengths:
- The similarity function is learned rather than fixed, so it can capture non-linear relationships between query and support examples that a predefined metric such as Euclidean distance would miss.
- The embedding module and the relation module are trained jointly, so the representation is shaped for the way it will ultimately be compared.
Limitations:
- The relation module adds parameters and computation on top of the embedding module, and every query must be compared against every class representation.
- Its effectiveness depends on the design of $g_\psi$; with high-dimensional foundation-model embeddings, a naive relation module can become expensive or difficult to train.
Relation Networks provide a powerful alternative within metric-based meta-learning, shifting complexity from solely learning good embeddings to jointly learning embeddings and a flexible comparison mechanism. Their effectiveness with foundation models hinges on efficiently leveraging the pre-trained representations while managing the computational cost of the learned relation module.