Supervised learning on graphs requires abundant labeled data, which is expensive or impractical to obtain in domains such as social networks, biological networks, and knowledge graphs. Self-Supervised Learning (SSL) provides an alternative by learning meaningful node or graph representations directly from the unlabeled graph structure and features. The core idea is to design "pretext" tasks that can be solved using the graph data itself, forcing the GNN encoder to capture essential structural and semantic information without relying on external labels. The learned representations can then be transferred to downstream tasks, either through fine-tuning or by using them directly as features.
The Rationale for SSL on Graphs
Graph data possesses rich intrinsic structure and features that can be exploited for self-supervision. Unlike images, where augmentations such as rotation or cropping preserve semantics in well-understood ways, graphs have no canonical augmentations: removing a single edge or node can change what the graph means. Defining meaningful augmentations and pretext tasks therefore requires careful consideration of the graph's properties. The goal is to generate supervisory signals from the graph itself to train a GNN encoder.
Common SSL Strategies for Graphs
Two primary categories dominate SSL on graphs: contrastive methods and predictive methods.
Contrastive Learning
Contrastive learning aims to learn representations by maximizing the agreement between differently augmented "views" of the same graph entity (node, subgraph, or whole graph) while simultaneously minimizing the agreement with views from different entities ("negative samples").
Data Augmentation: Creating different views of the graph is fundamental. Common graph augmentation techniques include:
- Node Dropping: Randomly removing a fraction of nodes.
- Edge Perturbation: Randomly adding or removing edges.
- Attribute Masking: Randomly masking out a fraction of node features.
- Subgraph Sampling: Extracting subgraphs centered around specific nodes.
The choice of augmentation significantly impacts the learned representations, often requiring domain-specific tuning.
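The first three augmentations listed above can be sketched in a few lines of dependency-free Python. This is a minimal illustration, not a reference implementation: the function names, the edge-list representation, and the choice to mask whole feature columns are assumptions made for the example, and real pipelines typically operate on tensors.

```python
import random


def drop_nodes(num_nodes, edges, drop_ratio, rng):
    """Node dropping: remove a random fraction of nodes and their incident edges."""
    num_drop = int(num_nodes * drop_ratio)
    dropped = set(rng.sample(range(num_nodes), num_drop))
    kept_nodes = [v for v in range(num_nodes) if v not in dropped]
    kept_edges = [(u, v) for u, v in edges if u not in dropped and v not in dropped]
    return kept_nodes, kept_edges


def perturb_edges(num_nodes, edges, ratio, rng):
    """Edge perturbation: remove a fraction of edges and try to add as many new ones."""
    num_change = int(len(edges) * ratio)
    removed = set(rng.sample(range(len(edges)), num_change))
    kept = [e for i, e in enumerate(edges) if i not in removed]
    existing = {frozenset(e) for e in edges}  # treat the graph as undirected
    added, attempts = [], 0
    # Cap the attempts so a nearly complete graph cannot loop forever;
    # in that case fewer than num_change edges may be added.
    while len(added) < num_change and attempts < 100 * max(1, num_change):
        attempts += 1
        u, v = rng.sample(range(num_nodes), 2)
        if frozenset((u, v)) not in existing:
            existing.add(frozenset((u, v)))
            added.append((u, v))
    return kept + added


def mask_attributes(features, mask_ratio, rng):
    """Attribute masking (assumed variant): zero out a random subset of feature columns."""
    dim = len(features[0])
    masked_cols = set(rng.sample(range(dim), int(dim * mask_ratio)))
    return [[0.0 if j in masked_cols else x for j, x in enumerate(row)]
            for row in features]


if __name__ == "__main__":
    rng = random.Random(0)
    ring = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0)]
    print(drop_nodes(6, ring, 0.5, rng))
    print(perturb_edges(6, ring, 0.5, rng))
    print(mask_attributes([[1.0, 2.0, 3.0, 4.0]] * 2, 0.5, rng))
```

Calling two of these functions on the same graph with independent random draws yields the two "views" that a contrastive method compares. Each function returns new objects and leaves its inputs unchanged, so the original graph can be reused for every view.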
Contrastive Objective: The GNN encoder, $f_\theta$, maps augmented graph views to embeddings. A popular objective is InfoNCE (based on noise-contrastive estimation), which encourages the similarity (e.g., cosine similarity) between positive pairs (different views of the same entity, $z_i, z_j$) to be high and the similarity between negative pairs (views from different entities, $z_i, z_k$) to be low. Often, a non-linear projection head, $g_\phi$, is applied to the embeddings ($h = g_\phi(z)$) before calculating the contrastive loss:
$$\mathcal{L}_i = -\log \frac{\exp(\mathrm{sim}(h_i, h_j)/\tau)}{\exp(\mathrm{sim}(h_i, h_j)/\tau) + \sum_{k \neq i} \exp(\mathrm{sim}(h_i, h_k)/\tau)}$$
Here, $\tau$ is a temperature hyperparameter scaling the similarities, and the sum runs over the negative samples $h_k$.
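To make the loss concrete, the following sketch evaluates the InfoNCE term for a single anchor using cosine similarity. It is a plain-Python illustration under stated assumptions: the function names and the default temperature of 0.5 are choices made for this example, and vectors are assumed to be non-zero.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity of two non-zero vectors given as lists of floats."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def info_nce_loss(anchor, positive, negatives, temperature=0.5):
    """InfoNCE for one anchor: -log(pos / (pos + sum over negatives))."""
    pos = math.exp(cosine_similarity(anchor, positive) / temperature)
    neg = sum(math.exp(cosine_similarity(anchor, n) / temperature)
              for n in negatives)
    return -math.log(pos / (pos + neg))


if __name__ == "__main__":
    anchor = [1.0, 0.0]
    negatives = [[0.0, 1.0], [-1.0, 0.0]]
    # A positive that points the same way as the anchor gives a small loss.
    print(info_nce_loss(anchor, [1.0, 0.0], negatives))
    # A positive that points the opposite way gives a much larger loss.
    print(info_nce_loss(anchor, [-1.0, 0.0], [[0.0, 1.0], [1.0, 0.0]]))
```

Lowering the temperature sharpens the contrast between similar and dissimilar pairs, which is why it usually needs tuning alongside the augmentations. In practice the loss is averaged over all anchors in a batch and computed with a numerically stable log-sum-exp rather than explicit exponentials.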
Methods: Representative contrastive methods include:
- GraphCL: Applies various augmentations (node dropping, edge perturbation, etc.) at the graph level for graph classification pre-training.
- GRACE (Deep Graph Contrastive Representation Learning): Focuses on node-level contrastive learning by generating two views via augmentations such as edge removal and feature masking, then contrasting node representations within and across views.
- InfoGraph: Contrasts graph-level representations with patch-level (node) representations derived from the same graph to maximize mutual information.
Figure: Flow of contrastive self-supervised learning on graphs. Two augmented views of an anchor graph/node are generated and passed through a shared GNN encoder and projection head. The contrastive loss pulls the resulting representations closer together and pushes them apart from the representations of negative samples.
Predictive Learning
Predictive methods define pretext tasks based on predicting certain properties of the graph or its components.
- Attribute Masking: Similar to masked language modeling in BERT, this involves masking some node features and training the GNN to predict the masked values from the node's neighborhood context and the remaining features. This forces the GNN to learn local structure and feature dependencies.
- Context Prediction: This involves predicting relationships or properties based on local context. For instance, predicting the distance between two nodes or whether two nodes belong to the same sampled subgraph.
- Structural Prediction: Pretext tasks can involve predicting graph structural properties, such as node degrees, clustering coefficients, or even predicting the existence of edges (link prediction used as a pretext task).
These methods often use standard loss functions like Cross-Entropy for classification-based prediction tasks or Mean Squared Error for regression-based tasks.
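The sketch below shows how the supervisory signal for two of these pretext tasks can be built from the graph alone: masked-attribute reconstruction scored with Mean Squared Error, and node-degree labels for structural prediction. All names are hypothetical, and the neighbor-mean predictor is only a stand-in for a trainable GNN so that the example stays dependency-free.

```python
import random


def build_adjacency(num_nodes, edges):
    """Neighbor lists for an undirected graph given as an edge list."""
    neighbors = {v: [] for v in range(num_nodes)}
    for u, v in edges:
        neighbors[u].append(v)
        neighbors[v].append(u)
    return neighbors


def make_masking_task(features, mask_ratio, rng):
    """Zero out the features of random nodes; the hidden values become the targets."""
    num_mask = max(1, int(len(features) * mask_ratio))
    masked_nodes = sorted(rng.sample(range(len(features)), num_mask))
    corrupted = [list(row) for row in features]  # copy, keep the input intact
    for v in masked_nodes:
        corrupted[v] = [0.0] * len(features[v])
    targets = {v: list(features[v]) for v in masked_nodes}
    return corrupted, masked_nodes, targets


def neighbor_mean_predictor(corrupted, neighbors, node):
    """Stand-in for a GNN: predict a node's features as the mean of its neighbors'."""
    nbrs = neighbors[node]
    dim = len(corrupted[node])
    if not nbrs:
        return [0.0] * dim
    return [sum(corrupted[u][j] for u in nbrs) / len(nbrs) for j in range(dim)]


def mse(prediction, target):
    """Mean squared error between two equal-length vectors."""
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)


def degree_targets(num_nodes, edges):
    """Structural pretext labels: the degree of every node."""
    neighbors = build_adjacency(num_nodes, edges)
    return [len(neighbors[v]) for v in range(num_nodes)]
```

With a real encoder, `neighbor_mean_predictor` would be replaced by the GNN plus a small prediction head, and the MSE over the masked nodes would be minimized by gradient descent. The degree labels could be used the same way, with Cross-Entropy if degrees are bucketed into classes or MSE if they are regressed directly.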
Utilizing Pre-trained GNNs
Once the GNN encoder is pre-trained via SSL, it can be adapted for downstream tasks:
- Fine-tuning: The entire pre-trained GNN (or parts of it) is fine-tuned end-to-end with a task-specific head (e.g., a classification layer) on a smaller labeled dataset.
- Feature Extraction: The pre-trained GNN is used as a fixed feature extractor. The generated node or graph embeddings are fed into a separate downstream model (e.g., SVM, MLP).
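The feature-extraction route can be illustrated with a toy example: embeddings are computed once by a frozen encoder, and only a separate downstream model is fit on the few labeled nodes. Everything here is an assumption made for illustration. `frozen_encoder` is a hypothetical stand-in for a pre-trained GNN (a single round of mean aggregation), and a nearest-centroid classifier takes the place of the SVM or MLP mentioned above.

```python
def frozen_encoder(features, neighbors):
    """Hypothetical stand-in for a pre-trained GNN: average each node with its neighbors."""
    embeddings = []
    for v, row in enumerate(features):
        group = [v] + neighbors[v]
        embeddings.append(
            [sum(features[u][j] for u in group) / len(group) for j in range(len(row))]
        )
    return embeddings


def fit_centroids(embeddings, labels):
    """Downstream model: one mean embedding per class, fit on labeled nodes only."""
    sums, counts = {}, {}
    for node, label in labels.items():
        vec = embeddings[node]
        if label not in sums:
            sums[label] = [0.0] * len(vec)
            counts[label] = 0
        sums[label] = [s + x for s, x in zip(sums[label], vec)]
        counts[label] += 1
    return {label: [s / counts[label] for s in total] for label, total in sums.items()}


def predict(embedding, centroids):
    """Assign the class whose centroid is closest in squared Euclidean distance."""
    def sq_dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))
    return min(centroids, key=lambda label: sq_dist(embedding, centroids[label]))
```

The encoder is never updated here, which is what distinguishes feature extraction from fine-tuning: in the fine-tuning setting, the gradients of the downstream loss would also flow into the encoder's weights.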
Advantages and Considerations
Using SSL for GNNs offers several benefits:
- Reduced Label Dependency: Enables learning from extensive amounts of readily available unlabeled graph data.
- Improved Generalization: Representations learned via SSL often capture more fundamental graph properties, potentially leading to better generalization on downstream tasks compared to purely supervised training, especially with limited labels.
- Initialization: Provides a strong weight initialization for subsequent supervised fine-tuning.
However, designing effective SSL strategies requires careful thought:
- Augmentation Design: Graph augmentations must preserve relevant structural or semantic information while creating sufficiently diverse views. Poor augmentations can harm performance.
- Pretext Task Selection: The chosen pretext task should align with the properties needed for the target downstream tasks.
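- Negative Sampling: In contrastive methods, selecting informative negative samples is important for efficient learning. Naive random sampling can be inefficient because many randomly drawn negatives are already easy to distinguish from the anchor and contribute little to the loss.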
- Computational Cost: Pre-training on large graphs can be computationally intensive.
SSL is a rapidly evolving area in graph machine learning, providing powerful tools for representation learning when labeled data is scarce. By leveraging the inherent structure and features of graphs, SSL enables the training of versatile GNN models applicable to a wide range of complex graph analysis tasks discussed in this chapter.