When working with standard deep learning models, like CNNs for images, creating batches of data is straightforward. You can stack multiple fixed-size images into a single tensor. However, graphs are not as uniform. They have varying numbers of nodes and edges, making it impossible to stack them into a neat, rectangular tensor. A common first thought might be to pad their adjacency matrices to a maximum size, but this would lead to extremely sparse and memory-intensive tensors, which is highly inefficient.
PyTorch Geometric solves this challenge with a clever and efficient strategy: it treats a mini-batch of graphs as a single, large graph with disconnected components. Instead of trying to force graphs into a rigid tensor structure, this approach combines them into a larger graph that can be processed in a single forward pass.
The core idea is to combine a list of small graphs into one large Data object. Let's say we have a batch of N graphs. PyG creates a single graph that contains all nodes and edges from these N graphs. Since there are no edges connecting the individual graphs, they remain as separate components within the larger structure. This allows GNN layers to perform message passing correctly, as information will only flow within the boundaries of each original subgraph.
This process involves three main steps:

1. Node feature concatenation: the node feature matrices of all graphs in the mini-batch are stacked into a single matrix `x`.
2. Edge index offsetting: the `edge_index` tensors are concatenated, with each graph's node indices shifted by the cumulative number of nodes that precede it in the batch.
3. The `batch` attribute: a new attribute, `batch`, is created. This is a column vector that maps each node to its original graph in the mini-batch. For example, all nodes from the first graph get an index of 0, all nodes from the second get an index of 1, and so on.

The following diagram illustrates how two small graphs are combined into a single batched graph.
In this batched graph, the node indices for Graph 2 are shifted by 3 (the number of nodes in Graph 1). The message passing operations will naturally be confined to their original subgraphs because no edges exist between the blue and green node groups.
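To make these mechanics concrete, here is a minimal pure-Python sketch of the collation step. The function name `collate_graphs` and the dict-based graph representation are illustrative only; in PyG this work is done by `Batch.from_data_list` on tensors.

```python
def collate_graphs(graphs):
    """Combine a list of small graphs into one batched graph.

    Each graph is a dict with 'num_nodes' and 'edge_index'
    (a list of [source, target] pairs). This is a sketch of
    what PyG's Batch.from_data_list does internally.
    """
    edge_index = []   # all edges, with node indices shifted
    batch = []        # maps each node to its graph index
    offset = 0        # cumulative node count so far
    for graph_id, g in enumerate(graphs):
        # Shift this graph's node indices by the nodes seen so far
        for src, dst in g["edge_index"]:
            edge_index.append([src + offset, dst + offset])
        # Every node in this graph maps to graph_id
        batch.extend([graph_id] * g["num_nodes"])
        offset += g["num_nodes"]
    return {"edge_index": edge_index, "batch": batch, "num_nodes": offset}

# Graph 1 has 3 nodes, Graph 2 has 2 nodes
g1 = {"num_nodes": 3, "edge_index": [[0, 1], [1, 2]]}
g2 = {"num_nodes": 2, "edge_index": [[0, 1]]}
big = collate_graphs([g1, g2])
print(big["batch"])       # [0, 0, 0, 1, 1]
print(big["edge_index"])  # [[0, 1], [1, 2], [3, 4]] -- Graph 2 shifted by 3
```

Note how Graph 2's edge `[0, 1]` becomes `[3, 4]` after the shift, so no edge ever crosses between the two components.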
DataLoader

Fortunately, you do not need to perform this batching procedure manually. PyTorch Geometric provides a DataLoader class, similar to the one in standard PyTorch, that handles this entire process automatically.
You simply pass a list of Data objects (your dataset) to the DataLoader, and it will yield Data objects that represent entire mini-batches.
```python
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

# Load a dataset of many small graphs
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')

# Create a DataLoader
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Iterate over the batches
for batch in loader:
    # batch is a single Data object representing 32 graphs
    print(batch)
    # > DataBatch(edge_index=[2, 2166], x=[597, 21], y=[32], batch=[597], ptr=[33])

    # batch.num_graphs gives the number of graphs in the batch
    print(f"Number of graphs in batch: {batch.num_graphs}")
    # > Number of graphs in batch: 32
```
The batch attribute within the yielded DataBatch object is crucial for performing graph-level operations. For instance, in a graph classification task, you would use this batch vector to perform a pooling operation (like global_mean_pool) that aggregates node embeddings for each graph in the batch separately.
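To illustrate what such a pooling step computes, here is a plain-Python sketch of mean pooling driven by the batch vector. The helper name `global_mean_pool_sketch` is hypothetical; PyG's `global_mean_pool` performs the same aggregation with an optimized scatter operation on tensors.

```python
def global_mean_pool_sketch(node_embeddings, batch):
    """Average node embeddings per graph, as indicated by `batch`.

    node_embeddings: list of feature vectors (one per node)
    batch: list mapping each node to its graph index
    Returns one pooled vector per graph.
    """
    num_graphs = max(batch) + 1
    dim = len(node_embeddings[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for vec, g in zip(node_embeddings, batch):
        for j, v in enumerate(vec):
            sums[g][j] += v
        counts[g] += 1
    return [[s / counts[g] for s in sums[g]] for g in range(num_graphs)]

# Five nodes (3 in graph 0, 2 in graph 1), 2-dimensional embeddings
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [10.0, 0.0], [20.0, 0.0]]
batch = [0, 0, 0, 1, 1]
print(global_mean_pool_sketch(x, batch))
# [[3.0, 4.0], [15.0, 0.0]]
```

Each graph ends up with a single fixed-size vector, which can then be fed to a classifier head.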
The batching strategy described here is designed for datasets that contain many small-to-medium-sized graphs, such as a collection of molecules. It is not suitable for a single, massive graph like a social network. Processing a graph with millions of nodes in one pass is often infeasible due to memory constraints.
For the single large graph setting, a different technique called neighborhood sampling is used. Instead of batching entire graphs, we create mini-batches by sampling a fixed number of neighbors for a set of root nodes. This is the approach used by models like GraphSAGE, which allows them to train on massive graphs that cannot fit into memory. PyTorch Geometric also provides tools for this, such as the NeighborLoader, but that is a more advanced topic outside the initial introduction.
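As a rough illustration of the sampling idea (not of NeighborLoader's actual API), the following pure-Python sketch draws at most a fixed number of one-hop neighbors for each root node in a mini-batch. The function and argument names are hypothetical.

```python
import random

def sample_neighbors(adjacency, roots, num_neighbors, seed=0):
    """One-hop neighborhood sampling, GraphSAGE-style.

    adjacency: dict mapping node -> list of neighbor nodes
    roots: the mini-batch of root nodes
    num_neighbors: maximum neighbors kept per root
    Returns a dict mapping each root to its sampled neighbors.
    """
    rng = random.Random(seed)
    sampled = {}
    for root in roots:
        neighbors = adjacency.get(root, [])
        if len(neighbors) <= num_neighbors:
            # Keep all neighbors when there are few enough
            sampled[root] = list(neighbors)
        else:
            # Otherwise draw a fixed-size random subset
            sampled[root] = rng.sample(neighbors, num_neighbors)
    return sampled

adjacency = {0: [1, 2, 3, 4], 1: [0], 2: [0, 3]}
sampled = sample_neighbors(adjacency, roots=[0, 2], num_neighbors=2)
# Each root keeps at most 2 neighbors drawn from its adjacency list
```

Because each root's receptive field is capped, memory usage per mini-batch stays bounded no matter how large the full graph is.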
The classes introduced here, torch_geometric.data.Data, torch_geometric.data.Batch, and torch_geometric.loader.DataLoader, are central to efficient graph batching in PyG.