Let's put the theory from this chapter into practice. We'll take a standard Graph Neural Network implementation and apply several optimization techniques discussed earlier. The goal isn't just to make the code run faster or use less memory, but to understand how and why these optimizations work, equipping you to apply them to your own complex GNN projects.
We will focus on a common task: semi-supervised node classification. We'll start with a baseline GCN model implemented using PyTorch Geometric (PyG), identify performance characteristics, and then incrementally apply optimizations, measuring the impact at each step.
First, let's define our starting point. We'll use the Cora dataset and a simple two-layer GCN model. Assume you have PyTorch, PyG, and the Cora dataset readily available.
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
import time
import torch.cuda.amp as amp # For Automatic Mixed Precision
# Load the dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)
# Define the baseline GCN model
class BaselineGCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # No dropout in the baseline to keep compute/memory profiling simple
        # x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
model = BaselineGCN(dataset.num_node_features, 16, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# --- Baseline Training Loop ---
def train_baseline():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()
def profile_run(train_func, run_name="Run", epochs=50):
    print(f"--- Profiling: {run_name} ---")
    start_time = time.time()
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats(device)
        start_mem = torch.cuda.max_memory_allocated(device)
    for epoch in range(epochs):
        loss = train_func()
        # In a real scenario, add validation/testing steps
        # print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
    end_time = time.time()
    total_time = end_time - start_time
    if torch.cuda.is_available():
        end_mem = torch.cuda.max_memory_allocated(device)
        peak_mem_increase_mib = (end_mem - start_mem) / 1024**2
        print(f"Peak GPU Memory Increase: {peak_mem_increase_mib:.2f} MiB")
    else:
        peak_mem_increase_mib = 0  # Placeholder if no GPU
        print("GPU not available, memory profiling skipped.")
    print(f"Total Training Time ({epochs} epochs): {total_time:.3f} seconds")
    print(f"Average Time per Epoch: {total_time / epochs:.4f} seconds")
    print("-" * (20 + len(run_name)))
    return total_time / epochs, peak_mem_increase_mib
# Profile the baseline
baseline_avg_time, baseline_peak_mem = profile_run(train_baseline, "Baseline GCN")
This baseline performs full-graph training. For datasets like Cora, this is often feasible. However, as graphs scale, this approach becomes intractable due to memory constraints (loading the entire graph, features, and intermediate activations) and computational cost.
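To get a feel for why memory becomes the bottleneck, a quick back-of-the-envelope estimate helps. The numbers below are hypothetical, chosen only to illustrate the scaling; they ignore weights, gradients, and optimizer state:
# Rough activation-memory estimate for full-graph training (hypothetical sizes).
num_nodes = 10_000_000      # imagine a 10M-node graph instead of Cora's ~2.7k nodes
hidden_dim = 256            # a typical hidden width for larger models
bytes_per_float = 4         # FP32
activations_gib = num_nodes * hidden_dim * bytes_per_float / 1024**3
print(f"~{activations_gib:.1f} GiB of activations for a single hidden layer")  # roughly 9.5 GiB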
For larger graphs, full-batch training is often impossible. The standard approach is mini-batching, where we process smaller subgraphs at each step. PyG's NeighborLoader implements neighbor sampling, as popularized by GraphSAGE: each seed node keeps only a fixed number of neighbors per layer, so with num_neighbors=[10, 5] (used below) a single seed node pulls in at most 1 + 10 + 10*5 = 61 nodes, regardless of how large the full graph is. Let's modify our training to use NeighborLoader.
from torch_geometric.loader import NeighborLoader
# Create a NeighborLoader.
# We sample 2 layers deep: up to 10 neighbors at the first hop and 5 at the second.
# Adjust batch_size and num_neighbors based on your hardware and graph.
# Note: neighbor sampling runs on the CPU; if your PyG version complains about the
# GPU-resident `data`, build the loader from a CPU copy (e.g. dataset[0]) instead.
train_loader = NeighborLoader(
    data,
    num_neighbors=[10, 5],        # number of neighbors to sample for each layer
    batch_size=128,               # process 128 training (seed) nodes per batch
    input_nodes=data.train_mask,  # sample neighborhoods starting from training nodes
    shuffle=True
)
# --- Model for Mini-batching ---
# The architecture itself does not need to change for NeighborLoader: each batch is a
# regular subgraph Data object, so message passing works exactly as before. The only
# difference is that we slice the output down to the seed nodes in the training loop.
class SampledGCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
# Re-initialize model and optimizer
model_sampled = SampledGCN(dataset.num_node_features, 16, dataset.num_classes).to(device)
optimizer_sampled = torch.optim.Adam(model_sampled.parameters(), lr=0.01, weight_decay=5e-4)
# --- Sampled Training Loop ---
def train_sampled():
    model_sampled.train()
    total_loss = 0
    # Process mini-batches
    for batch in train_loader:
        batch = batch.to(device)
        optimizer_sampled.zero_grad()
        # Run message passing over the whole sampled subgraph. Only the first
        # `batch.batch_size` rows of `out` correspond to the seed (training) nodes;
        # the remaining rows are sampled neighbors needed for their receptive field.
        out = model_sampled(batch.x, batch.edge_index)
        loss = F.nll_loss(out[:batch.batch_size], batch.y[:batch.batch_size])
        loss.backward()
        optimizer_sampled.step()
        total_loss += loss.item() * batch.batch_size
    return total_loss / int(data.train_mask.sum())  # average loss over all training nodes
# Profile the sampled version
# Note: Epoch time includes iterating through all mini-batches.
sampled_avg_time, sampled_peak_mem = profile_run(train_sampled, "Neighbor Sampling GCN")
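If you want to see exactly what the loader produces, it helps to inspect a single batch (a quick sketch; attribute names such as n_id follow PyG's NeighborLoader output and may differ slightly across versions):
# Peek at one sampled mini-batch (run once, outside the training loop).
sample_batch = next(iter(train_loader))
print(sample_batch)                    # a Data object describing the sampled subgraph
print(sample_batch.batch_size)         # number of seed (training) nodes in this batch
print(sample_batch.num_nodes, sample_batch.num_edges)  # size of the sampled subgraph
print(sample_batch.n_id[:10])          # global node indices; the seed nodes come first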
Observations: on a graph as small as Cora, per-epoch time may not improve and can even increase, because each epoch now pays for CPU-side sampling, host-to-device transfers, and several optimizer steps instead of one. The benefit is that peak memory per step now depends on the size of the sampled subgraphs rather than on the full graph, which is exactly what makes training feasible on graphs that do not fit in GPU memory.
Modern GPUs have specialized hardware (Tensor Cores) that accelerates computations using lower precision formats like FP16 (half-precision). PyTorch's Automatic Mixed Precision (AMP) utilities allow us to leverage this with minimal code changes, often providing speedups and reducing memory usage.
We'll apply AMP to the mini-batching setup, as it's more representative of where these optimizations yield the largest gains.
# Re-initialize model and optimizer (important for grad scaler)
model_amp = SampledGCN(dataset.num_node_features, 16, dataset.num_classes).to(device)
optimizer_amp = torch.optim.Adam(model_amp.parameters(), lr=0.01, weight_decay=5e-4)
# Create a gradient scaler for loss scaling
scaler = amp.GradScaler(enabled=torch.cuda.is_available())
# --- AMP Training Loop ---
def train_amp():
    model_amp.train()
    total_loss = 0
    for batch in train_loader:  # Reuse the same loader
        batch = batch.to(device)
        optimizer_amp.zero_grad()
        # Operations inside the autocast context run in lower precision where beneficial
        with amp.autocast(enabled=torch.cuda.is_available()):
            out = model_amp(batch.x, batch.edge_index)
            loss = F.nll_loss(out[:batch.batch_size], batch.y[:batch.batch_size])
        # Scale the loss before the backward pass to avoid FP16 gradient underflow
        scaler.scale(loss).backward()
        # Unscale gradients and update model weights
        scaler.step(optimizer_amp)
        # Update the scale factor for the next iteration
        scaler.update()
        total_loss += loss.item() * batch.batch_size
    return total_loss / int(data.train_mask.sum())
# Profile the AMP version
amp_avg_time, amp_peak_mem = profile_run(train_amp, "Sampled GCN + AMP")
Observations: with only 16 hidden units and small sampled subgraphs, the gains from FP16 may be modest here; Tensor Cores pay off most for larger feature dimensions and batch sizes. FP16 also has a much narrower dynamic range than FP32, so small gradient values can underflow to zero during the backward pass; GradScaler helps mitigate this by dynamically scaling the loss. Accuracy should generally be comparable to FP32 training.
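As an aside, on GPUs that support it (roughly Ampere and newer), bfloat16 is a convenient alternative to FP16: it keeps FP32's exponent range, so gradient underflow is much less of a concern and the GradScaler can usually be dropped. A minimal sketch under that assumption:
# bfloat16 variant of the AMP loop (sketch; assumes an Ampere-or-newer GPU).
def train_amp_bf16():
    model_amp.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer_amp.zero_grad()
        with amp.autocast(dtype=torch.bfloat16, enabled=torch.cuda.is_available()):
            out = model_amp(batch.x, batch.edge_index)
            loss = F.nll_loss(out[:batch.batch_size], batch.y[:batch.batch_size])
        loss.backward()          # no GradScaler needed with bfloat16
        optimizer_amp.step()
        total_loss += loss.item() * batch.batch_size
    return total_loss / int(data.train_mask.sum())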
PyTorch 2.0 introduced torch.compile, a feature that can significantly accelerate model execution by converting Python code into optimized graph representations and leveraging backend compilers like TorchInductor. Applying it is often straightforward.
# Re-initialize model and optimizer
model_compiled_base = BaselineGCN(dataset.num_node_features, 16, dataset.num_classes).to(device)
optimizer_compiled_base = torch.optim.Adam(model_compiled_base.parameters(), lr=0.01, weight_decay=5e-4)
# Compile the baseline model
# Use default mode first, explore others like 'reduce-overhead' or 'max-autotune' later
compiled_model = torch.compile(model_compiled_base)
# --- Compiled Baseline Training Loop ---
def train_compiled_baseline():
    compiled_model.train()  # Use the compiled model
    optimizer_compiled_base.zero_grad()
    out = compiled_model(data.x, data.edge_index)  # Pass data to the compiled model
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer_compiled_base.step()
    return loss.item()
# Profile the compiled baseline version
# Note: The first few runs might be slower due to compilation overhead.
# Run for more epochs to amortize this cost.
compiled_avg_time, compiled_peak_mem = profile_run(train_compiled_baseline, "Compiled Baseline GCN", epochs=100) # Longer run
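Because the first call triggers compilation, it can also help to warm the model up before timing, so the one-time compile cost does not distort the per-epoch average (a sketch reusing the functions defined above):
# Warm-up: the first call(s) trigger compilation; run a few untimed iterations first.
for _ in range(3):
    train_compiled_baseline()
if torch.cuda.is_available():
    torch.cuda.synchronize()  # make sure warm-up work has finished before timing
# Re-profile without the compilation cost included.
warm_avg_time, _ = profile_run(train_compiled_baseline, "Compiled Baseline GCN (warm)", epochs=50)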
# ---- Optional: Compile the Sampled+AMP model ----
# model_compiled_amp = SampledGCN(...) # Initialize
# optimizer_compiled_amp = Adam(...)
# compiled_model_amp = torch.compile(model_compiled_amp)
# scaler_compiled = amp.GradScaler(...)
# def train_compiled_amp():
# ... # Use compiled_model_amp, autocast, scaler
# profile_run(train_compiled_amp, "Compiled Sampled GCN + AMP", epochs=100)
Observations: torch.compile can lead to substantial speedups, especially for models dominated by standard PyTorch operations and on newer hardware. The benefit for GNNs depends on how much of the execution time is spent within compilable kernels versus the custom C++/CUDA extensions used by PyG/DGL layers (which might already be highly optimized). You might see less dramatic speedups compared to pure CNNs or Transformers, but improvements are still common.

Let's consolidate our findings. The exact numbers will vary greatly depending on your hardware (CPU, GPU, memory), software versions, and the specific dataset.
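Before plotting anything, it can be handy to print the collected numbers side by side (a small sketch using the values returned by profile_run above):
# Print a simple side-by-side summary of the profiling results collected above.
results = {
    "Baseline":      (baseline_avg_time, baseline_peak_mem),
    "Sampled":       (sampled_avg_time, sampled_peak_mem),
    "Sampled + AMP": (amp_avg_time, amp_peak_mem),
    "Compiled":      (compiled_avg_time, compiled_peak_mem),
}
print(f"{'Configuration':<16}{'Avg epoch (s)':>14}{'Peak mem (MiB)':>16}")
for name, (avg_time, peak_mem) in results.items():
    print(f"{name:<16}{avg_time:>14.4f}{peak_mem:>16.2f}")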
Comparison of average training time per epoch and peak GPU memory increase across the different optimization strategies applied to a GCN model on the Cora dataset. Note the logarithmic scale on the chart's y-axis.
Interpreting Results: treat the exact figures as indicative rather than definitive. On a graph as small as Cora, neighbor sampling tends to lower per-step memory while adding sampling overhead, so it may not win on wall-clock time; AMP usually trims memory and offers modest speedups at this model size; and torch.compile only pays off once its one-time compilation cost has been amortized over enough epochs. The picture changes on large graphs, where full-batch training is no longer an option at all.
This practical demonstrated how to apply common optimization techniques to a GCN implementation using PyG. Mini-batching with neighbor sampling (NeighborLoader) is essential for scaling to large graphs because it bounds the memory footprint of each step. Automatic mixed precision (torch.cuda.amp) reduces memory and often accelerates training on supported hardware. Compilation (torch.compile) can accelerate Python and PyTorch code execution through graph optimization and specialized backends.

Remember that optimization is an iterative process: profile first to find the actual bottleneck, apply one change at a time, re-measure both time and memory, and confirm that model accuracy has not regressed before moving on, as sketched below.
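A minimal full-graph evaluation helper for that last accuracy check might look like this (a sketch; it suits the full-graph models above, while sampled models would typically be evaluated with a NeighborLoader over the validation/test nodes):
# Minimal full-graph evaluation helper (sketch). Works for models whose forward
# signature is forward(x, edge_index), e.g. the baseline model defined above.
@torch.no_grad()
def evaluate(eval_model):
    eval_model.eval()
    out = eval_model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    accs = {}
    for split, mask in [("train", data.train_mask), ("val", data.val_mask), ("test", data.test_mask)]:
        accs[split] = int((pred[mask] == data.y[mask]).sum()) / int(mask.sum())
    return accs

print(evaluate(model))  # e.g. {'train': ..., 'val': ..., 'test': ...}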
By understanding and applying these techniques, you can build GNN models that are not only accurate but also efficient and scalable to handle complex, large-scale graph data.