In the previous sections, you saw how libraries like PyTorch Geometric provide powerful, pre-built layers for constructing Graph Neural Networks (GNNs). While these are immensely useful, understanding how to build a GNN layer from scratch using core PyTorch operations provides deeper insight and the flexibility to implement novel or customized message-passing schemes. This practical exercise guides you through creating a simple, custom GNN layer.
The fundamental idea behind many GNN layers is message passing, where nodes iteratively aggregate information from their neighbors and update their own representations. For each node i, we can break this down into two main steps: first, collect the (typically transformed) feature vectors of the node's neighbors and aggregate them with a permutation-invariant function such as a sum; second, update the node's representation from this aggregated result, usually by applying a non-linearity.
Let's implement a basic layer that performs these steps. We'll define a layer that transforms node features using a learnable weight matrix, aggregates the transformed features from neighbors using a simple sum, and then applies an activation function.
Mathematically, for a node i, the operation can be described as:
$$a_i = \sum_{j \in \mathcal{N}(i) \cup \{i\}} W h_j, \qquad h_i' = \sigma(a_i)$$

Here, $h_j$ represents the feature vector of node $j$, $W$ is a learnable weight matrix, $\mathcal{N}(i)$ is the set of neighbors of node $i$, and $\sigma$ is a non-linear activation function (like ReLU). Note that we include the node itself ($i$) in the aggregation, often referred to as adding a self-loop. This ensures the node's original features are considered in the update.
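As a small worked case (an illustration using the same notation, not an additional definition), suppose node $i$ has exactly two neighbors, $j$ and $k$. Because $W$ is applied linearly to each term, the aggregation can be written in two equivalent ways:

$$a_i = W h_j + W h_k + W h_i = W\,(h_j + h_k + h_i), \qquad h_i' = \sigma(a_i)$$

The implementation below uses the first form: it transforms all node features with $W$ up front and then sums the transformed vectors for each node.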
First, ensure you have PyTorch imported. We'll define our custom layer as a Python class inheriting from torch.nn.Module.
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleGNNLayer(nn.Module):
    """
    A basic Graph Neural Network layer implementing message passing.

    Args:
        in_features (int): Size of each input node feature vector.
        out_features (int): Size of each output node feature vector.
    """
    def __init__(self, in_features, out_features):
        super(SimpleGNNLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Define the learnable weight matrix
        self.linear = nn.Linear(in_features, out_features, bias=False)
        # Initialize weights (optional but often good practice)
        nn.init.xavier_uniform_(self.linear.weight)

    def forward(self, x, edge_index):
        """
        Defines the computation performed at every call.

        Args:
            x (torch.Tensor): Node features tensor with shape [num_nodes, in_features].
            edge_index (torch.Tensor): Graph connectivity in COO format with shape [2, num_edges].
                edge_index[0] = source nodes, edge_index[1] = target nodes.

        Returns:
            torch.Tensor: Updated node features tensor with shape [num_nodes, out_features].
        """
        num_nodes = x.size(0)

        # 1. Add self-loops to the adjacency matrix represented by edge_index
        # Create tensor of node indices [0, 1, ..., num_nodes-1]
        self_loops = torch.arange(0, num_nodes, device=x.device).unsqueeze(0)
        self_loops = self_loops.repeat(2, 1)  # Shape [2, num_nodes]
        # Concatenate original edges with self-loops
        edge_index_with_self_loops = torch.cat([edge_index, self_loops], dim=1)
        # Extract source and target node indices
        row, col = edge_index_with_self_loops

        # 2. Linearly transform node features
        x_transformed = self.linear(x)  # Shape: [num_nodes, out_features]

        # 3. Aggregate features from neighbors (including self)
        # We want to sum features of source nodes (row) for each target node (col)
        # Initialize output tensor with zeros
        aggregated_features = torch.zeros(num_nodes, self.out_features, device=x.device)
        # Use index_add_ for efficient aggregation (scatter sum)
        # Adds elements from x_transformed[row] into aggregated_features at indices specified by col
        # index_add_(dimension, index_tensor, tensor_to_add)
        aggregated_features.index_add_(0, col, x_transformed[row])

        # 4. Apply final activation function (optional)
        # For this example, let's use ReLU
        output_features = F.relu(aggregated_features)

        return output_features

    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_features}, {self.out_features})'
Let's walk through the important parts of this implementation:

- Initialization (__init__): We define a single nn.Linear layer. This layer applies the learnable weight transformation W to the input node features. We set bias=False for simplicity, consistent with some GNN formulations like the basic GCN. Weight initialization using nn.init.xavier_uniform_ helps stabilize training.
- Forward pass (forward): This is where the message passing logic resides.
  - Self-loops: We first add self-loop edges to edge_index. This ensures that when aggregating neighbor features for a node, the node's own transformed features are also included. We create an edge index representing edges from each node to itself and concatenate it with the original edge_index.
  - Linear transformation: We apply the weight matrix (self.linear) to all node features x simultaneously.
  - Aggregation: We sum the transformed features of the source nodes (x_transformed[row]) for each target node (col). index_add_ is a highly efficient way to perform this "scatter-add" operation: called on the tensor to accumulate into (aggregated_features), it takes the dimension along which to index (0 for nodes), the indices to add at (col, the target nodes), and the values to add (x_transformed[row], the transformed features of the source nodes).
  - Activation: A non-linear activation (F.relu) is applied element-wise.

To illustrate the edge_index format and the idea of neighbors, consider a small graph with four nodes in which node 0 is connected to nodes 1 and 2, and nodes 1 and 2 are each connected to node 3. A possible edge_index (representing directed edges for message passing, where each undirected edge yields messages in both directions) is tensor([[0, 0, 1, 2, 1, 2, 3, 3], [1, 2, 0, 0, 3, 3, 1, 2]]). The first row contains source nodes, the second contains target nodes. When aggregating for node 3, we'd look at messages from source nodes 1 and 2.
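To make the scatter-add concrete, here is a minimal standalone sketch, separate from the layer above and using made-up one-dimensional features, of what index_add_ computes for this four-node graph with self-loops included (just as the layer adds them):

import torch

# Already-transformed features for 4 nodes, one feature each (illustrative values)
x_transformed = torch.tensor([[10.], [20.], [30.], [40.]])

# Edges 0-1, 0-2, 1-3, 2-3 in both directions, plus a self-loop for every node
edge_index = torch.tensor([[0, 0, 1, 2, 1, 2, 3, 3, 0, 1, 2, 3],
                           [1, 2, 0, 0, 3, 3, 1, 2, 0, 1, 2, 3]])
row, col = edge_index  # source nodes, target nodes

aggregated = torch.zeros(4, 1)
aggregated.index_add_(0, col, x_transformed[row])
print(aggregated.squeeze())
# tensor([60., 70., 80., 90.]) -- node 3, for example, sums nodes 1, 2, and itself: 20 + 30 + 40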
Now, let's see how to use this SimpleGNNLayer. We need some sample node features and an edge_index.
# Example Usage
# Define graph data
num_nodes = 4
num_features = 8
out_layer_features = 16
# Node features (random)
x = torch.randn(num_nodes, num_features)
# Edge index representing connections (e.g., 0->1, 0->2, 1->3, 2->3 and vice-versa for undirected)
edge_index = torch.tensor([
[0, 0, 1, 2, 1, 2, 3, 3], # Source nodes
[1, 2, 0, 0, 3, 3, 1, 2] # Target nodes
], dtype=torch.long)
# Instantiate the layer
gnn_layer = SimpleGNNLayer(in_features=num_features, out_features=out_layer_features)
print(f"Instantiated layer: {gnn_layer}")
# Pass data through the layer
output_node_features = gnn_layer(x, edge_index)
# Check output shape
print(f"\nInput node features shape: {x.shape}")
print(f"Edge index shape: {edge_index.shape}")
print(f"Output node features shape: {output_node_features.shape}")
# Verify output shape matches expectation: [num_nodes, out_features]
assert output_node_features.shape == (num_nodes, out_layer_features)
print("\nSuccessfully passed data through the custom GNN layer.")
# Display first few output features for node 0
print(f"Output features for node 0 (first 5 dims): {output_node_features[0, :5].detach().numpy()}")
This example demonstrates creating random node features and a sample edge_index, instantiating our SimpleGNNLayer, and performing a forward pass. The output shape [num_nodes, out_features] confirms the layer operates as expected, producing new embeddings for each node based on its neighborhood.
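As an additional sanity check (a small sketch that continues the snippet above, not part of the original example), we can recompute node 3's output by hand. With the self-loop included, node 3 aggregates the transformed features of nodes 1, 2, and itself before the ReLU:

# Manually recompute node 3's output: sum the transformed features of its
# neighbors (1 and 2) and itself (self-loop), then apply ReLU
with torch.no_grad():
    expected_node_3 = F.relu(gnn_layer.linear(x[[1, 2, 3]]).sum(dim=0))

assert torch.allclose(output_node_features[3], expected_node_3, atol=1e-6)
print("Node 3 output matches the manual aggregation over nodes 1, 2, and 3.")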
This simple layer serves as a foundation. You could extend it in numerous ways:

- Different aggregation: Replace index_add_ (sum aggregation) with mean or max aggregation. Mean aggregation often requires knowing the degree of each node; a sketch of this variant follows the closing paragraph below.
- Edge features: Modify the forward pass to accept and utilize edge features, potentially incorporating them into the message calculation before aggregation.
- Bias term: Add a bias to the nn.Linear layer, or add a separate bias after aggregation.

Building custom layers like this is a valuable skill. It allows you to implement cutting-edge GNN architectures directly from research papers or tailor message-passing schemes precisely to your problem's needs, going beyond the standard offerings of libraries when necessary. This same principle of building custom nn.Module components applies when implementing unique mechanisms within Transformers, Normalizing Flows, or other advanced architectures covered in this course.
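For instance, the first extension above could look something like the following sketch (one possible approach, not the only one), which turns the sum into a mean by dividing each node's aggregated features by its incoming degree, counting the self-loop. The class name MeanGNNLayer is simply an illustrative choice:

class MeanGNNLayer(SimpleGNNLayer):
    """Variant of SimpleGNNLayer that averages neighbor features instead of summing them."""
    def forward(self, x, edge_index):
        num_nodes = x.size(0)
        # Add self-loops, exactly as in SimpleGNNLayer
        self_loops = torch.arange(0, num_nodes, device=x.device).unsqueeze(0).repeat(2, 1)
        edge_index_with_self_loops = torch.cat([edge_index, self_loops], dim=1)
        row, col = edge_index_with_self_loops

        # Transform and scatter-sum as before
        x_transformed = self.linear(x)
        aggregated = torch.zeros(num_nodes, self.out_features, device=x.device)
        aggregated.index_add_(0, col, x_transformed[row])

        # Count incoming edges (including the self-loop) for each target node
        degree = torch.zeros(num_nodes, device=x.device)
        degree.index_add_(0, col, torch.ones(col.size(0), device=x.device))

        # Divide the summed messages by the degree to turn the sum into a mean
        return F.relu(aggregated / degree.clamp(min=1).unsqueeze(1))

A MeanGNNLayer(num_features, out_layer_features) instance would then be a drop-in replacement for the layer in the usage example above.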