Alright, let's put theory into practice. In this section, we'll build and train a Graph Neural Network specifically designed for node classification on a heterogeneous graph. We'll leverage the concepts discussed earlier, particularly focusing on handling multiple node and edge types. We will use the DBLP computer science bibliography dataset and implement a Heterogeneous Attention Network (HAN) model using PyTorch Geometric (PyG). This practical exercise assumes you have PyTorch and PyG installed and are comfortable with their basic usage.
Our goal is to predict the research area (e.g., Database, Data Mining, AI) for authors in the DBLP dataset based on their publication history, co-authorship network, and the terms associated with their papers.
First, let's load the DBLP dataset provided by PyG. This dataset naturally represents a heterogeneous graph.
import torch
import torch.nn.functional as F
from torch_geometric.datasets import DBLP
from torch_geometric.nn import HANConv, Linear
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
# Load the dataset
dataset = DBLP(root='./data/DBLP')
data = dataset[0]
# Some PyG versions ship the DBLP venue/conference node type without features.
# Give any such node type a constant feature so every type has an input tensor
# (HANConv expects an entry in x_dict for every node type that appears in an edge).
for node_type in data.node_types:
    if 'x' not in data[node_type]:
        data[node_type].x = torch.ones(data[node_type].num_nodes, 1)

# Normalize node features and make every relation bidirectional.
# Note: ToUndirected() adds reverse edge types (e.g., ('paper', 'rev_to', 'author'))
# for relations that only exist in one direction, so messages can flow both ways.
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToUndirected(),
])
data = transform(data)
print("DBLP Dataset Overview:")
print(data)
print("\nNode Types:", data.node_types)
print("Edge Types:", data.edge_types)
# Example: Accessing author features and labels
print("\nAuthor Node Features Shape:", data['author'].x.shape)
print("Author Node Labels Shape:", data['author'].y.shape)
print("Number of Classes:", dataset.num_classes)
You'll notice the HeteroData object nicely organizes features, labels, and edge indices by their respective types. For DBLP, we typically have node types like 'author', 'paper', 'term', and 'venue', and edge types representing relationships like 'author' writes 'paper', 'paper' uses 'term', 'paper' is published at a 'venue', and so on. The specific structure might vary slightly depending on the PyG version and dataset preprocessing; for example, the venue node type may be called 'conference'. Our task focuses on classifying the 'author' nodes.
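As a quick sketch of this per-type access, you can index the HeteroData object by a node type or by an edge-type tuple. The relation name 'to' below is an assumption; check data.edge_types for the exact names in your dataset version.
# Per-type access on the HeteroData object (the relation name 'to' is an assumption).
paper_features = data['paper'].x                                # paper feature matrix
author_writes_paper = data['author', 'to', 'paper'].edge_index  # shape [2, num_edges]
print(paper_features.shape, author_writes_paper.shape)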
Understanding the relationships between different node types is important. We can visualize the schema of our heterogeneous graph.
The schema of the DBLP heterogeneous graph, showing different node types (author, paper, term, venue) and the relationships (edge types) between them.
Remember that ToUndirected() in PyG often creates reverse edge types (named with a 'rev_' prefix) for relations stored in only one direction, so the schema you see may contain more edge types than the raw dataset.
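A quick way to check which (if any) reverse relations were added is to scan the edge types for that prefix:
# List any reverse relations introduced by ToUndirected(); the list may be empty
# if the dataset already stores both directions explicitly.
rev_edge_types = [edge_type for edge_type in data.edge_types if edge_type[1].startswith('rev_')]
print(rev_edge_types)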
The Heterogeneous Attention Network (HAN) architecture is well-suited for heterogeneous graphs. It uses attention mechanisms at two levels:
- Node-level attention: for a given meta-path (semantic relation), it learns how much each neighbor should contribute when aggregating messages into a node.
- Semantic-level attention: it learns how important each meta-path is and combines the per-meta-path embeddings using those learned weights.
A meta-path is a sequence of node types connected by edge types, like Author → Paper → Author (APA) or Author → Paper → Venue → Paper → Author (APVPA).
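To make the semantic-level step concrete, here is a toy sketch of how per-meta-path embeddings are combined with softmax weights. The tensors and scores below are illustrative placeholders, not HANConv's internal code.
import torch

num_nodes, dim = 4, 8
# Illustrative embeddings produced by two meta-paths (e.g., APA and APVPA).
z_apa = torch.randn(num_nodes, dim)
z_apvpa = torch.randn(num_nodes, dim)

z_per_path = torch.stack([z_apa, z_apvpa])       # [num_meta_paths, num_nodes, dim]
importance = torch.tensor([1.2, 0.3])            # learned per-meta-path scores (made up here)
beta = torch.softmax(importance, dim=0)          # semantic-level attention weights
z_final = (beta.view(-1, 1, 1) * z_per_path).sum(dim=0)  # weighted combination per node
print(beta, z_final.shape)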
PyG's HANConv layer simplifies implementing this. Rather than asking you to enumerate meta-paths or name a target node type, it takes the graph's metadata (its node and edge types) and treats each edge type as a semantic relation, applying node-level attention within each relation and semantic-level attention across relations. If you want longer meta-paths such as APA to be considered explicitly, you can materialize them as extra edge types beforehand.
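For example, here is a hedged sketch using PyG's AddMetaPaths transform. The relation names ('author', 'to', 'paper') and ('paper', 'to', 'author') are assumptions, so confirm them against data.edge_types before applying it.
# Optional: materialize an Author-Paper-Author (APA) meta-path as a new edge type.
# We only construct the transform here; uncomment the last line to apply it.
apa = [('author', 'to', 'paper'), ('paper', 'to', 'author')]  # assumed relation names
metapath_transform = T.AddMetaPaths(metapaths=[apa])
# data = metapath_transform(data)  # should add an ('author', 'metapath_0', 'author') relation
With that aside, let's define the HAN model itself.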
class HAN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, metadata, heads=8):
        """
        Initializes the HAN model.

        Args:
            in_channels (int or dict): Size of the input features. Pass -1 (or a dict of
                -1 values) to let HANConv lazily infer the feature size of each node type
                from the first forward pass.
            hidden_channels (int): Size of the hidden embeddings.
            out_channels (int): Number of output classes.
            metadata (tuple): Tuple of node types and edge types, obtained from
                data.metadata().
            heads (int, optional): Number of attention heads. Defaults to 8.
        """
        super().__init__()
        # HANConv applies node-level attention within each relation and
        # semantic-level attention across relations, based on the metadata.
        self.conv1 = HANConv(in_channels, hidden_channels,
                             metadata=metadata, heads=heads, dropout=0.6)
        self.conv2 = HANConv(hidden_channels, out_channels,
                             metadata=metadata, heads=1, dropout=0.6)  # single head for the output layer

    def forward(self, x_dict, edge_index_dict):
        """
        Forward pass of the HAN model.

        Args:
            x_dict (dict): Dictionary mapping node types to their feature tensors.
            edge_index_dict (dict): Dictionary mapping edge types to their edge index tensors.

        Returns:
            torch.Tensor: Output logits for the target node type ('author').
        """
        # HANConv returns a dictionary with an updated embedding for every node type
        # that receives messages under the given relations.
        x_dict = self.conv1(x_dict, edge_index_dict)
        # Apply a non-linearity between the two layers (skip node types without output).
        x_dict = {key: F.elu(x) for key, x in x_dict.items() if x is not None}
        x_dict = self.conv2(x_dict, edge_index_dict)
        # We only need the output for the 'author' node type for our classification task.
        return x_dict['author']
# Prepare model arguments
metadata = data.metadata()
hidden_channels = 128
num_classes = dataset.num_classes
# Instantiate the model
# PyG's HANConv can infer input channels if set to -1
model = HAN(in_channels=-1, hidden_channels=hidden_channels,
out_channels=num_classes, metadata=metadata, heads=8)
print("\nHAN Model Architecture:")
print(model)
# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
data = data.to(device) # Move data object to the same device
We define a two-layer HAN model. HANConv conveniently takes the feature dictionary (x_dict) and edge index dictionary (edge_index_dict) from the HeteroData object as input. Internally it computes attention-weighted representations per relation and combines them with semantic-level attention, guided by the provided metadata. Note that we pass in_channels=-1 so that HANConv infers the input feature dimension of each node type automatically on the first forward pass. The output we care about is the embedding for the 'author' nodes, which we'll use for classification.
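One practical note, based on how lazy modules behave in recent PyTorch/PyG versions: with in_channels=-1, the HANConv weights are not materialized until the first forward pass. It is therefore common to run one dummy forward pass before building the optimizer, which also lets us sanity-check the output shape.
# Initialize the lazily-created parameters with a single forward pass (no gradients needed).
with torch.no_grad():
    out = model(data.x_dict, data.edge_index_dict)
print(out.shape)  # expected: [num_author_nodes, num_classes]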
Now, let's set up the standard PyTorch training components. We'll use the Adam optimizer and Cross-Entropy loss. We need masks to select the training, validation, and test nodes (specifically for the 'author' type). These masks are often provided with the dataset.
# Check if standard masks are available, otherwise create random splits
if 'train_mask' not in data['author']:
    print("\nGenerating random masks for author nodes...")
    num_authors = data['author'].num_nodes
    indices = torch.randperm(num_authors)
    train_split = int(0.6 * num_authors)
    val_split = int(0.8 * num_authors)
    data['author'].train_mask = torch.zeros(num_authors, dtype=torch.bool)
    data['author'].train_mask[indices[:train_split]] = True
    data['author'].val_mask = torch.zeros(num_authors, dtype=torch.bool)
    data['author'].val_mask[indices[train_split:val_split]] = True
    data['author'].test_mask = torch.zeros(num_authors, dtype=torch.bool)
    data['author'].test_mask[indices[val_split:]] = True

# Ensure masks are on the correct device
data['author'].train_mask = data['author'].train_mask.to(device)
data['author'].val_mask = data['author'].val_mask.to(device)
data['author'].test_mask = data['author'].test_mask.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)
criterion = torch.nn.CrossEntropyLoss()
def train():
    model.train()
    optimizer.zero_grad()
    # Pass the entire feature and edge-index dictionaries
    out = model(data.x_dict, data.edge_index_dict)
    # Compute the loss only on training nodes of the 'author' type
    mask = data['author'].train_mask
    loss = criterion(out[mask], data['author'].y[mask])
    loss.backward()
    optimizer.step()
    return float(loss)
@torch.no_grad()
def test():
    model.eval()
    # Pass the entire feature and edge-index dictionaries
    pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)
    accs = []
    # Calculate accuracy on the train, validation, and test splits of the 'author' nodes
    for split in ['train_mask', 'val_mask', 'test_mask']:
        mask = data['author'][split]
        acc = (pred[mask] == data['author'].y[mask]).sum() / mask.sum()
        accs.append(float(acc))
    return accs
print("\nStarting Training...")
for epoch in range(1, 101):
    loss = train()
    train_acc, val_acc, test_acc = test()
    if epoch % 10 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
              f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')
print("Training Finished.")
# Final Test Accuracy
final_train_acc, final_val_acc, final_test_acc = test()
print(f"\nFinal Performance:\n"
f" Train Accuracy: {final_train_acc:.4f}\n"
f" Validation Accuracy: {final_val_acc:.4f}\n"
f" Test Accuracy: {final_test_acc:.4f}")
In the training loop, we pass the complete x_dict and edge_index_dict to the model. The loss and accuracy calculations are performed only on the relevant 'author' nodes using the provided masks.
Execute the code. You should see the training loss decrease and the accuracies generally increase over the epochs. The final test accuracy indicates how well the HAN model generalizes to unseen author nodes. Performance depends on factors such as the chosen hyperparameters (hidden dimensions, learning rate, number of heads, dropout), the quality of the node features, and the expressiveness of the relations that HANConv attends over, which are determined by the graph's metadata.
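As a starting point for tuning, here is a minimal sketch of a validation-based grid search over two of those hyperparameters. It reuses the train() and test() helpers defined above, and the value grids are just example choices.
# Hedged sketch: tiny grid search over hidden size and attention heads,
# selecting the configuration with the best validation accuracy.
# Note: this rebinds the global `model` and `optimizer`, replacing the ones trained above.
best = {'val_acc': 0.0}
for hidden in [64, 128]:
    for heads in [4, 8]:
        model = HAN(in_channels=-1, hidden_channels=hidden,
                    out_channels=num_classes, metadata=metadata, heads=heads).to(device)
        with torch.no_grad():  # initialize lazy parameters before building the optimizer
            model(data.x_dict, data.edge_index_dict)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)
        for epoch in range(1, 51):
            train()  # train() and test() read the global model/optimizer defined above
        _, val_acc, test_acc = test()
        if val_acc > best['val_acc']:
            best = {'hidden': hidden, 'heads': heads,
                    'val_acc': val_acc, 'test_acc': test_acc}
print(best)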
This practical example demonstrated how to:
- Load a heterogeneous benchmark graph (DBLP) and inspect how HeteroData organizes node types, edge types, features, and labels.
- Build a two-layer heterogeneous GNN for node classification using PyG's HANConv.
- Train and evaluate the model on a single target node type ('author') using type-specific masks.

You can experiment further by:
- Tuning hyperparameters (hidden_channels, lr, heads, dropout).
- Replacing HANConv with HeteroConv using specific aggregation functions (e.g., GATConv, GCNConv) per relation type; a minimal sketch follows below.

Handling heterogeneous graphs is a common requirement in real-world applications, and architectures like HAN provide an effective way to model these complex relational structures.
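For that last experiment, here is a minimal, hedged sketch of the HeteroConv route, with one GATConv per relation and a linear classifier on the 'author' embeddings. It assumes the same data, metadata, and training setup as above and is a starting point rather than a tuned implementation.
from torch_geometric.nn import HeteroConv, GATConv, Linear

class HeteroGNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, metadata):
        super().__init__()
        # One GATConv per relation; HeteroConv sums the messages arriving at each node type.
        self.conv1 = HeteroConv({
            edge_type: GATConv((-1, -1), hidden_channels, add_self_loops=False)
            for edge_type in metadata[1]
        }, aggr='sum')
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.elu(x) for key, x in x_dict.items()}
        return self.lin(x_dict['author'])

# Usage mirrors the HAN model, e.g.:
# hetero_model = HeteroGNN(128, num_classes, data.metadata()).to(device)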