Okay, let's translate the theory of Graph Attention Networks (GAT) into a practical implementation. Earlier in this chapter, we discussed how GAT layers use self-attention to assign different importance weights to nodes within a neighborhood during aggregation, allowing the model to focus on the most relevant neighbors for each node. We saw the core components: a shared linear transformation, a learnable attention mechanism a, and a softmax normalization of the resulting scores.
Now we will build a GAT layer. We'll use PyTorch and the PyTorch Geometric (PyG) library, leveraging its MessagePassing base class, which simplifies the implementation of message passing GNNs by handling the propagation logic. This approach strikes a balance between understanding the fundamental operations and using efficient library tools.
Ensure you have PyTorch and PyTorch Geometric installed. You should be comfortable with PyTorch's nn.Module and basic PyG concepts like Data objects and the MessagePassing interface.
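If the MessagePassing interface is only vaguely familiar, the toy layer below is a minimal sketch of the pattern the rest of this section relies on: declare an aggregation scheme in the constructor, call propagate() in forward(), and define what each edge carries in message(). The SumNeighbors name is purely illustrative and not part of PyG.
import torch
from torch_geometric.nn import MessagePassing

class SumNeighbors(MessagePassing):
    """Toy layer: each node receives the sum of its neighbors' features."""
    def __init__(self):
        super().__init__(aggr='add')  # sum incoming messages per target node

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        # propagate() drives the message -> aggregate -> update pipeline.
        return self.propagate(edge_index, x=x)

    def message(self, x_j: torch.Tensor) -> torch.Tensor:
        # x_j holds the source-node features for every edge, shape [E, F].
        return x_j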
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, softmax
from torch_geometric.typing import Adj, OptTensor, PairTensor
Recall the steps for computing the output features h'_i for a node i in a single GAT layer:
1. Apply a shared linear transformation W to the features of every node.
2. For each edge (j, i), compute an unnormalized attention score e_ij from the transformed features of i and j, followed by a LeakyReLU activation.
3. Normalize the scores over each node's neighborhood with a softmax to obtain coefficients alpha_ij.
4. Aggregate the neighbors' transformed features as a sum weighted by alpha_ij.
5. With multiple heads, concatenate or average the per-head results and add an optional bias.
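For reference, the per-head computation can be written compactly (using the notation from earlier in the chapter, with N(i) denoting node i's neighborhood):
$$ e_{ij} = \mathrm{LeakyReLU}\left(\mathbf{a}^{\top}\left[\mathbf{W}\mathbf{h}_i \,\Vert\, \mathbf{W}\mathbf{h}_j\right]\right) $$
$$ \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)\cup\{i\}} \exp(e_{ik})} $$
$$ \mathbf{h}'_i = \sigma\left(\sum_{j \in \mathcal{N}(i)\cup\{i\}} \alpha_{ij}\,\mathbf{W}\mathbf{h}_j\right) $$
In the implementation below, the vector a is split into two halves, att_dst (applied to the target features W h_i) and att_src (applied to the source features W h_j), since the dot product with the concatenation decomposes into the sum of those two terms.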
We'll implement a GAT layer supporting multi-head attention using the MessagePassing base class. This class handles the aggregation process (step 4) based on the messages computed in the message function.
class GATLayer(MessagePassing):
    """Implementation of a single GAT layer with multi-head attention."""

    def __init__(self, in_features: int, out_features: int, heads: int = 1,
                 concat: bool = True, negative_slope: float = 0.2,
                 dropout: float = 0.0, add_self_loops: bool = True,
                 bias: bool = True, **kwargs):
        # Use 'add' aggregation for the attention-weighted sum.
        kwargs.setdefault('aggr', 'add')
        # node_dim=0: nodes are indexed along the first axis of our
        # [N, heads, head_dim] feature tensors.
        super().__init__(node_dim=0, **kwargs)

        self.in_features = in_features
        self.out_features = out_features
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops

        # Output dimension per head. If concatenating, divide total out_features by heads.
        if concat:
            assert out_features % heads == 0, "out_features must be divisible by heads when concat=True"
            self.head_dim = out_features // heads
        else:
            self.head_dim = out_features

        # Step 1: Linear transformation (W) applied to all nodes.
        # A single linear layer produces the features for all heads at once;
        # the result is later reshaped to [N, heads, head_dim].
        self.lin = nn.Linear(in_features, self.heads * self.head_dim, bias=False)

        # Step 2: Attention mechanism parameters 'a'.
        # Two weight vectors (att_src, att_dst) are applied to the source and
        # target nodes' transformed features; their sum reproduces a^T [Wh_i || Wh_j].
        # Each has shape [1, heads, head_dim].
        self.att_src = nn.Parameter(torch.Tensor(1, heads, self.head_dim))
        self.att_dst = nn.Parameter(torch.Tensor(1, heads, self.head_dim))

        if bias and concat:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        elif bias and not concat:
            self.bias = nn.Parameter(torch.Tensor(self.head_dim))
        else:
            self.register_parameter('bias', None)

        self._alpha = None  # Stores attention weights for optional inspection
        self.reset_parameters()

    def reset_parameters(self):
        # Initialize weights similar to the original GAT paper and PyG's GATConv.
        nn.init.xavier_uniform_(self.lin.weight)
        nn.init.xavier_uniform_(self.att_src)
        nn.init.xavier_uniform_(self.att_dst)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor, edge_index: Adj,
                size: tuple[int, int] | None = None,
                return_attention_weights: bool = False):
        """
        Forward pass of the GAT layer.

        Args:
            x (Tensor or PairTensor): Node features (N, in_features) or ((N, F_in), (M, F_in)).
            edge_index (Adj): Graph connectivity (2, E).
            size (tuple, optional): Size of the bipartite graph (N, M).
            return_attention_weights (bool): If True, also return attention coefficients.
        """
        # Ensure x is a tensor; basic handling for the bipartite (PairTensor) case.
        if isinstance(x, torch.Tensor):
            x_l: OptTensor = x
            x_r: OptTensor = x
        else:
            x_l, x_r = x
        assert x_l is not None

        # Step 1: Apply the linear transformation. Project features for all heads.
        z = self.lin(x_l)                          # Shape: [N, heads * head_dim]
        z = z.view(-1, self.heads, self.head_dim)  # Shape: [N, heads, head_dim]

        # Add self-loops so nodes can attend to themselves (optional but common).
        if self.add_self_loops:
            if isinstance(edge_index, torch.Tensor):
                num_nodes = x_l.size(0)
                if x_r is not None:
                    num_nodes = min(num_nodes, x_r.size(0))
                edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
            # Note: add_self_loops for SparseTensor needs separate handling if used.

        # --- Start Message Passing ---
        # Steps 2 & 3: compute attention coefficients and normalize.
        # Step 4: aggregate features.
        # propagate() orchestrates calls to message(), aggregate(), and update().
        # We pass the transformed features 'z' for both source (j) and target (i) nodes.
        out = self.propagate(edge_index, x=(z, z), size=size,
                             return_attention_weights=return_attention_weights)
        # --- End Message Passing ---

        # Step 5: Apply final transformations (concat/average, bias).
        if self.concat:
            # Reshape from [N, heads, head_dim] to [N, heads * head_dim].
            out = out.view(-1, self.heads * self.head_dim)
        else:
            # Average across heads: [N, heads, head_dim] -> [N, head_dim].
            out = out.mean(dim=1)

        if self.bias is not None:
            out += self.bias

        if return_attention_weights:
            return out, self._alpha
        return out

    def message(self, x_j: torch.Tensor, x_i: torch.Tensor,
                index: torch.Tensor, ptr: OptTensor, size_i: int | None,
                return_attention_weights: bool) -> torch.Tensor:
        """
        Computes the message from node j to node i for each edge (j, i).
        This implements steps 2 & 3 (attention calculation and normalization).

        Args:
            x_j (Tensor): Features of source nodes per edge. Shape: [E, heads, head_dim]
            x_i (Tensor): Features of target nodes per edge. Shape: [E, heads, head_dim]
            index (Tensor): Target node index for each edge (used by softmax). Shape: [E]
            ptr (OptTensor): Optional CSR pointers for sparse softmax.
            size_i (int): Number of target nodes.
            return_attention_weights (bool): Flag passed through from forward().

        Returns:
            Tensor: Messages passed along edges. Shape: [E, heads, head_dim]
        """
        # Step 2: Calculate the unnormalized attention scores e_ij.
        # Compute the contributions of the source (j) and target (i) nodes separately.
        alpha_src = (x_j * self.att_src).sum(dim=-1)  # Shape: [E, heads]
        alpha_dst = (x_i * self.att_dst).sum(dim=-1)  # Shape: [E, heads]
        alpha = alpha_src + alpha_dst                 # Shape: [E, heads]
        alpha = F.leaky_relu(alpha, self.negative_slope)

        # Step 3: Normalize the scores per target node with PyG's sparse softmax.
        alpha = softmax(alpha, index, ptr, size_i)    # Shape: [E, heads]

        # Store attention weights if requested, for later analysis.
        if return_attention_weights:
            self._alpha = alpha

        # Apply dropout to the attention weights (a common regularization).
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # Step 4 (part 1): Weight the source features by attention to form messages.
        # The 'add' aggregation set in __init__ sums these messages per target node.
        return x_j * alpha.unsqueeze(-1)              # Shape: [E, heads, head_dim]

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_features}, '
                f'{self.out_features}, heads={self.heads}, concat={self.concat})')
Figure: Flow of information within a single-head GAT layer during the forward pass. Multi-head attention runs this flow in parallel with independent parameters, followed by concatenation or averaging.
Now let's see how to use this GATLayer. We'll create a simple graph and pass it through the layer.
# Example Usage
# Create some dummy data: 5 nodes, 3 features each
num_nodes = 5
in_channels = 3
out_channels_final = 16 # Desired final output dimension
num_heads = 4 # Number of attention heads
x = torch.randn(num_nodes, in_channels)
# Define edges (source -> target): 0->1, 0->2, 1->2, 1->3, 2->4, 3->4
edge_index = torch.tensor([[0, 0, 1, 1, 2, 3],
[1, 2, 2, 3, 4, 4]], dtype=torch.long)
# Instantiate the layer
# Note: If concat=True, out_features must be divisible by heads.
# Here, out_features=16, heads=4 -> head_dim = 16/4 = 4.
gat_layer = GATLayer(in_features=in_channels,
out_features=out_channels_final,
heads=num_heads,
concat=True, # Concatenate head outputs
dropout=0.1) # Apply dropout during training
# Perform a forward pass
# During inference or evaluation, set model.eval() to disable dropout
gat_layer.train() # Set to training mode for dropout
output_features = gat_layer(x, edge_index)
# Check the output shape
# Expected: [num_nodes, out_channels_final] = [5, 16]
print("Input features shape:", x.shape)
print("Output features shape:", output_features.shape)
# You can also retrieve attention weights
gat_layer.eval() # Disable dropout for inspection
output_features, attention_weights = gat_layer(x, edge_index, return_attention_weights=True)
# attention_weights has shape [E + N, heads] = [11, 4] here:
# one row per original edge (E = 6) plus one self-loop per node (N = 5),
# since add_self_loops=True.
print("Attention weights shape:", attention_weights.shape)
# The rows follow the edge_index processed inside the layer: the original
# edges first, then the appended self-loops.
This implementation provides a single GAT layer. To build a complete GNN model, you would typically stack multiple GATLayer instances, often with activation functions (like ELU or ReLU) and potentially dropout between layers; a minimal sketch appears at the end of this section. The input features for each subsequent layer are the output features of the previous one. Be mindful that the output dimension depends on the concat setting: if the final layer uses concat=True, its out_features is the model's final node embedding dimension, whereas with concat=False (averaging) its out_features defines the final dimension directly.
This practical exercise demonstrates how the theoretical concepts of GAT translate into code using common libraries. By understanding the message and propagate mechanisms within PyG's MessagePassing, you can implement various GNN architectures effectively. Remember that details like initialization, activation functions, dropout rates, and the number of heads are hyperparameters that often require tuning for optimal performance on specific tasks.
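As promised, here is a minimal sketch of how such a stack might look for node classification, using the GATLayer defined above. The GATModel name, hidden size, and dropout rate are arbitrary choices for illustration, not fixed by anything earlier in the section.
class GATModel(nn.Module):
    """Illustrative two-layer GAT built from the GATLayer above."""
    def __init__(self, in_features: int, hidden_features: int, num_classes: int,
                 heads: int = 4, dropout: float = 0.1):
        super().__init__()
        # First layer concatenates heads, so hidden_features must be divisible by heads.
        self.gat1 = GATLayer(in_features, hidden_features, heads=heads,
                             concat=True, dropout=dropout)
        # Final layer averages heads so the output is exactly num_classes wide.
        self.gat2 = GATLayer(hidden_features, num_classes, heads=heads,
                             concat=False, dropout=dropout)
        self.dropout = dropout

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        x = self.gat1(x, edge_index)
        x = F.elu(x)  # nonlinearity between GAT layers
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.gat2(x, edge_index)  # raw logits, one score per class

# Example: 3 input features, 16 hidden units, 2 classes, on the toy graph from above.
model = GATModel(in_features=3, hidden_features=16, num_classes=2)
logits = model(x, edge_index)
print("Logits shape:", logits.shape)  # Expected: [5, 2]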