While neighborhood and graph sampling methods, as discussed previously, offer ways to train GNNs on large graphs by operating on mini-batches, they often involve significant preprocessing per batch or introduce variance due to the sampling process. An alternative strategy aims to simplify the batch construction and reduce computational overhead by partitioning the graph itself: subgraph and clustering methods, exemplified by Cluster-GCN.
The core idea behind Cluster-GCN is elegantly simple: if we can partition the graph's nodes into a smaller number of clusters such that nodes within a cluster are densely connected, maybe we can train the GNN on batches formed by these clusters. Instead of sampling individual nodes or edges, we sample entire clusters (or groups of clusters) and perform GNN computations on the subgraphs induced by the nodes within the sampled cluster(s).
Cluster-GCN involves two main stages:
Graph Clustering (Preprocessing): Before training begins, the large graph $G=(V,E)$ is partitioned into $c$ clusters $V_1, V_2, \dots, V_c$ such that $V = V_1 \cup V_2 \cup \dots \cup V_c$ and $V_i \cap V_j = \emptyset$ for $i \neq j$. This partitioning is typically done with established graph clustering algorithms such as METIS or Louvain, which aim to minimize the number of edges between clusters (the edge cut) while keeping the clusters relatively balanced in size. The goal is to create partitions in which most of the graph's connectivity is preserved within the clusters.
Stochastic Batch Training: During training, instead of sampling nodes, we sample one or more clusters to form a mini-batch. Let $V_b = V_{t_1} \cup \dots \cup V_{t_k}$ be the set of nodes belonging to the $k$ clusters selected for batch $b$. We then construct the subgraph $G_b = (V_b, E_b)$ induced by the nodes in $V_b$: $E_b$ contains only the edges from the original edge set $E$ that connect two nodes both present in $V_b$. The GNN's forward and backward passes for this batch are computed using only this subgraph $G_b$ and its corresponding adjacency matrix $A_b$. A minimal sketch of both stages follows below.
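To make these two stages concrete, here is a minimal NumPy/SciPy sketch. It uses a random graph and a random cluster assignment purely so the code runs end to end; in practice the partition would come from METIS or Louvain as described above.

# Minimal sketch of Cluster-GCN batch construction with SciPy/NumPy.
# The partition is random here only so the example is self-contained;
# a real pipeline would obtain it from METIS or Louvain.
import numpy as np
import scipy.sparse as sp

num_nodes, num_clusters, clusters_per_batch = 1000, 20, 2
rng = np.random.default_rng(0)

# A sparse, symmetric adjacency matrix standing in for the full graph.
A = sp.random(num_nodes, num_nodes, density=0.01, random_state=0, format='csr')
A = ((A + A.T) > 0).astype(np.float32)

# Stage 1: partition nodes into clusters (placeholder for METIS/Louvain output).
partition = rng.integers(0, num_clusters, size=num_nodes)

# Stage 2: sample k clusters and build the induced subgraph for the batch.
chosen = rng.choice(num_clusters, size=clusters_per_batch, replace=False)
batch_nodes = np.where(np.isin(partition, chosen))[0]
A_b = A[batch_nodes][:, batch_nodes]  # keeps only edges with both endpoints in the batch

print(A_b.shape, A_b.nnz, "adjacency entries retained inside the batch")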
Figure: a graph partitioned into three clusters. In Cluster-GCN, a mini-batch might consist of the nodes and intra-cluster edges from a single cluster; the inter-cluster edges are ignored during GNN computation within that batch.
The efficiency gain comes from the fact that GNN computations (like the graph convolution $A'HW$, where $A'$ is the normalized adjacency matrix) are performed on the much smaller adjacency matrix $A_b$ of the subgraph $G_b$, rather than the full graph's adjacency matrix $A$. If the clusters are reasonably small and the number of clusters $c$ is manageable, each training step becomes significantly faster and requires less memory compared to full-graph training.
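To put rough, purely illustrative numbers on this: if a graph with 1,000,000 nodes is split into 1,000 roughly balanced clusters, a single-cluster batch contains about 1,000 nodes, so each layer's propagation touches on the order of 0.1% of the rows that a full-graph update would, with a correspondingly smaller activation memory footprint.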
The construction of the batch itself is also simplified. Once the initial clustering is done, creating a batch involves merely selecting a pre-defined cluster (or a few clusters) and fetching the corresponding node features and subgraph structure. This contrasts with neighborhood sampling, which requires potentially expensive neighborhood lookups for each node in the batch during training.
The main simplification, and thus the source of approximation in Cluster-GCN, is the omission of edges connecting nodes in the current batch ($V_b$) to nodes outside the batch ($V \setminus V_b$). When the GNN performs message passing within $G_b$, it only aggregates information from neighbors inside $G_b$; information from neighbors residing in other clusters is lost for that specific batch update.
How detrimental is this? It depends heavily on the quality of the clustering. If the clustering algorithm successfully groups nodes such that most connections are internal to the clusters (i.e., the graph has a strong community structure and the algorithm finds it), then relatively few edges are dropped in each batch, and the approximation might be minor. If the graph lacks clear community structure or the clustering is poor, many edges might be ignored per batch, potentially harming the GNN's ability to learn global patterns.
The original Cluster-GCN paper proposed a strategy to mitigate this: instead of using just one cluster per batch, combine a small number of randomly chosen clusters ($k > 1$) into a single batch $V_b = V_{t_1} \cup \dots \cup V_{t_k}$. This increases the chance that nodes near the boundary of one cluster can still receive messages from neighbors that happen to fall into another cluster included in the same batch, thereby incorporating more inter-cluster edges into the computation for that step.
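The effect of batching multiple clusters can be checked directly. The sketch below (again with a random graph and a random partition, purely for illustration) measures the fraction of batch-incident adjacency entries that survive the induced-subgraph restriction; with a random partition this fraction is roughly k/c, while a good METIS partition would retain substantially more.

# Sketch: how many edges survive the induced-subgraph restriction for k clusters?
import numpy as np
import scipy.sparse as sp

num_nodes, num_clusters = 1000, 20
rng = np.random.default_rng(0)

A = sp.random(num_nodes, num_nodes, density=0.01, random_state=0, format='csr')
A = ((A + A.T) > 0).astype(np.float32)
partition = rng.integers(0, num_clusters, size=num_nodes)  # stand-in for METIS output

def retained_edge_fraction(k):
    # Fraction of adjacency entries in the batch rows whose other endpoint
    # is also inside the batch.
    chosen = rng.choice(num_clusters, size=k, replace=False)
    nodes = np.where(np.isin(partition, chosen))[0]
    incident = A[nodes].nnz             # entries in the rows of batch nodes
    internal = A[nodes][:, nodes].nnz   # entries with both endpoints in the batch
    return internal / incident

print("k=1:", retained_edge_fraction(1))
print("k=4:", retained_edge_fraction(4))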
Let's consider a standard GCN layer update for the full graph:
$$H^{(l+1)} = \sigma\!\left(\hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2}H^{(l)}W^{(l)}\right)$$
where $\hat{A} = A + I$ is the adjacency matrix with self-loops, $\hat{D}$ is the corresponding diagonal degree matrix, $H^{(l)}$ is the node feature matrix at layer $l$, and $W^{(l)}$ is the layer's weight matrix.
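In NumPy, and assuming a small random graph, random weights, and ReLU as the nonlinearity, this update can be sketched as:

# Dense NumPy sketch of one full-graph GCN layer: H' = sigma(D^-1/2 (A+I) D^-1/2 H W).
import numpy as np

rng = np.random.default_rng(0)
n, d_in, d_out = 6, 4, 3

A = (rng.random((n, n)) < 0.3).astype(float)
A = np.triu(A, 1); A = A + A.T             # symmetric adjacency, no self-loops yet
H = rng.normal(size=(n, d_in))             # layer-l node features
W = rng.normal(size=(d_in, d_out))         # layer-l weights

A_hat = A + np.eye(n)                      # add self-loops
D_inv_sqrt = np.diag(1.0 / np.sqrt(A_hat.sum(axis=1)))
H_next = np.maximum(D_inv_sqrt @ A_hat @ D_inv_sqrt @ H @ W, 0)  # ReLU as sigma
print(H_next.shape)  # (6, 3)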
In Cluster-GCN, for a batch $b$ defined by nodes $V_b$, the update is performed on the subgraph $G_b = (V_b, E_b)$. Let $A_b$ be the adjacency matrix of this subgraph (containing only edges within $V_b$). The update for nodes in the batch becomes:
$$H_{V_b}^{(l+1)} = \sigma\!\left(\hat{D}_b^{-1/2}\hat{A}_b\hat{D}_b^{-1/2}H_{V_b}^{(l)}W^{(l)}\right)$$
Here, $\hat{A}_b = A_b + I_{V_b}$ is the subgraph adjacency matrix with self-loops added for nodes in $V_b$, $\hat{D}_b$ is its degree matrix, and $H_{V_b}^{(l)}$ represents the rows of $H^{(l)}$ corresponding to nodes in $V_b$. Notice that the computation only involves the features $H_{V_b}^{(l)}$ and the structure $\hat{A}_b$, significantly reducing the computational cost if $|V_b| \ll |V|$.
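The per-batch version differs only in that the adjacency, degree normalization, and features are restricted to the batch nodes. A sketch under the same toy assumptions (random graph and weights, ReLU, a hypothetical batch of four nodes):

# NumPy sketch of the Cluster-GCN per-batch update on an induced subgraph.
import numpy as np

rng = np.random.default_rng(0)
n, d_in, d_out = 8, 4, 3

A = (rng.random((n, n)) < 0.4).astype(float)
A = np.triu(A, 1); A = A + A.T
H = rng.normal(size=(n, d_in))
W = rng.normal(size=(d_in, d_out))

batch_nodes = np.array([0, 1, 2, 3])           # one hypothetical cluster
A_b = A[np.ix_(batch_nodes, batch_nodes)]      # induced subgraph adjacency
H_b = H[batch_nodes]

A_b_hat = A_b + np.eye(len(batch_nodes))       # self-loops for batch nodes only
D_b_inv_sqrt = np.diag(1.0 / np.sqrt(A_b_hat.sum(axis=1)))
H_b_next = np.maximum(D_b_inv_sqrt @ A_b_hat @ D_b_inv_sqrt @ H_b @ W, 0)

# Rows for boundary nodes differ from the full-graph update, because their
# edges to nodes outside batch_nodes were dropped.
print(H_b_next.shape)  # (4, 3)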
Advantages:
- Each training step operates on a small induced subgraph, so memory use and per-step computation are far lower than full-graph training.
- Batch construction is cheap after the one-time clustering: a batch is simply a pre-computed cluster (or group of clusters) with its node features and intra-cluster edges.
- Message passing within each subgraph is exact, avoiding the per-node sampling variance of node- or layer-wise sampling methods.
Disadvantages:
- Inter-cluster edges are dropped within each batch, so message passing at cluster boundaries is only approximate.
- Results depend heavily on the quality of the partitioning; graphs without clear community structure lose many edges per batch.
- A clustering algorithm such as METIS must be run as a preprocessing step before training can begin.
Comparison to Sampling Methods:
- Neighborhood sampling performs potentially expensive neighbor lookups for every batch during training, whereas Cluster-GCN pays a one-time partitioning cost and then reuses the precomputed clusters.
- Sampling methods introduce variance because each node aggregates over a random subset of its neighbors; Cluster-GCN computes exact aggregations within the subgraph but systematically ignores inter-cluster edges.
- Combining several clusters per batch gives a simple knob for trading memory against how many inter-cluster edges are recovered.
Libraries like PyTorch Geometric (PyG) offer convenient data loaders for Cluster-GCN. For example, torch_geometric.loader.ClusterData takes a graph data object and performs the clustering (via a METIS binding), while torch_geometric.loader.ClusterLoader generates batches corresponding to clusters or groups of clusters. This abstracts away much of the complexity of partitioning and subgraph construction.
# Example using PyTorch Geometric
import torch
import torch.nn.functional as F
from torch_geometric.loader import ClusterData, ClusterLoader
from torch_geometric.datasets import Planetoid  # Or your large graph dataset

dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

# Perform clustering (one-time preprocessing)
cluster_data = ClusterData(data, num_parts=100, recursive=False,
                           save_dir=dataset.processed_dir)

# Create loader that yields batches of clusters
train_loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True)

# Training loop
model = GCN(...)  # Your GNN model
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for batch in train_loader:
    optimizer.zero_grad()
    # The 'batch' object is already the induced subgraph for the selected clusters
    output = model(batch.x, batch.edge_index)
    # Calculate loss only for training nodes in the current batch
    loss = F.nll_loss(output[batch.train_mask], batch.y[batch.train_mask])
    loss.backward()
    optimizer.step()
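After training, a common choice on graphs that still fit in memory is to run inference on the full graph, so that no inter-cluster edges are dropped at evaluation time. Continuing with the model and data objects assumed in the snippet above:

# Full-graph inference after cluster-based training (assumes the graph fits in memory)
model.eval()
with torch.no_grad():
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=-1)
    test_acc = (pred[data.test_mask] == data.y[data.test_mask]).float().mean().item()
print(f"Test accuracy: {test_acc:.4f}")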
In summary, Cluster-GCN provides an effective strategy for scaling GNN training by leveraging graph clustering. It trades off the exactness of message passing at cluster boundaries for significant gains in computational speed and memory efficiency, making it a valuable technique for tackling truly large-scale graph learning problems, especially those with inherent community structures.