While PyTorch offers fundamental building blocks like torch.nn.Linear or torch.nn.Conv2d, and containers like torch.nn.Sequential, real-world applications often benefit from encapsulating more complex or specialized logic into reusable components. Extending torch.nn.Module is the standard PyTorch mechanism for creating these custom layers or network segments. This approach promotes modularity, code organization, and reusability, making it easier to manage intricate model architectures. It allows you to define not just the layers but also the specific forward computation logic, including control flow, interactions between sub-components, and integration with custom operations discussed elsewhere in this chapter.
At its heart, a custom module is a Python class inheriting from torch.nn.Module. The two most important methods you will typically override are:

- __init__(self, ...): The constructor. This is where you define and initialize the module's components:
  - Submodules: instances of nn.Module classes (including standard PyTorch layers or other custom modules).
  - Learnable parameters defined directly as torch.nn.Parameter instances.
  - Non-learnable state (buffers) registered via self.register_buffer().
  Always call super().__init__() at the beginning of your __init__ method. This ensures the base nn.Module class initializes correctly, setting up internal structures needed for parameter tracking, device movement, and state saving.
- forward(self, ...): This method defines the computation performed by the module. It takes input tensors (and potentially other arguments) and returns output tensors. You use the submodules, parameters, and buffers defined in __init__ within this method to implement the desired logic. PyTorch's dynamic computation graph is built based on the operations performed within forward.
Here's a skeletal structure:
import torch
import torch.nn as nn
import torch.nn.functional as F
class MyCustomModule(nn.Module):
    def __init__(self, input_features, output_features, hidden_units):
        super().__init__()  # Essential first step

        # Define submodules (layers)
        self.layer1 = nn.Linear(input_features, hidden_units)
        self.activation = nn.ReLU()
        self.layer2 = nn.Linear(hidden_units, output_features)

        # Define learnable parameters directly (if needed)
        # Example: A learnable scaling factor
        self.scale = nn.Parameter(torch.randn(1))

        # Define non-learnable state (buffers)
        # Example: A counter for forward passes (demonstration only)
        self.register_buffer('forward_count', torch.zeros(1, dtype=torch.long))

    def forward(self, x):
        # Define the computation flow using initialized components
        x = self.layer1(x)
        x = self.activation(x)
        x = self.layer2(x)

        # Use the custom parameter
        x = x * self.scale

        # Update the buffer (ensure device compatibility if needed)
        # Note: Direct modification like this might not be typical
        # in standard training loops but illustrates buffer usage.
        self.forward_count += 1

        return x
# Example Usage:
input_dim = 64
output_dim = 10
hidden_dim = 128
model = MyCustomModule(input_dim, output_dim, hidden_dim)
print(model)
# Test forward pass
dummy_input = torch.randn(4, input_dim) # Batch size 4
output = model(dummy_input)
print("Output shape:", output.shape)
print("Forward count:", model.forward_count)
# Parameters and buffers are tracked
for name, param in model.named_parameters():
    print(f"Parameter: {name}, Shape: {param.shape}")

for name, buf in model.named_buffers():
    print(f"Buffer: {name}, Value: {buf}")
Creating effective custom modules involves more than just inheriting from nn.Module. Consider these practices:

- Define components in __init__: Use __init__ primarily to define the components (submodules, parameters, buffers). Avoid performing significant computation here. All components defined as attributes that are nn.Module instances or nn.Parameter instances are automatically registered. This means they appear in model.parameters(), their state is saved in model.state_dict(), and methods like model.to(device) correctly move them.
- Use register_buffer for state: For tensors that are part of the module's state but should not be updated by the optimizer (like running statistics or fixed constants), use self.register_buffer('buffer_name', tensor). Buffers are correctly handled by state_dict and device placement methods (.to(), .cuda(), .cpu()), unlike plain Python attributes holding tensors.
- forward defines computation: The forward method encapsulates the module's runtime logic. It can contain any valid Python code, including conditional statements (if/else) and loops (for), enabling dynamic computational behavior. Ensure tensors required for gradient computation are created or manipulated within forward or passed as arguments.
- Composition: Containers like nn.Sequential, nn.ModuleList, and nn.ModuleDict work seamlessly with custom modules; a short composition sketch follows the figure caption below.

Figure: A conceptual view of composing a network using a custom nn.Module (MyCustomBlock) alongside standard PyTorch layers. The custom block encapsulates internal layers and logic.
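The sketch below is one way to compose the MyCustomModule defined earlier with standard layers inside nn.Sequential; the surrounding layer sizes (32, 64, 10, 3) and the batch size are illustrative assumptions, not part of the original example.

import torch
import torch.nn as nn

# Assumes MyCustomModule from the skeletal example above is in scope.
# The wrapper dimensions (32 -> 64 and 10 -> 3) are illustrative choices.
composite = nn.Sequential(
    nn.Linear(32, 64),
    nn.ReLU(),
    MyCustomModule(input_features=64, output_features=10, hidden_units=128),
    nn.Linear(10, 3),
)

x = torch.randn(8, 32)  # batch of 8 samples
y = composite(x)
print(y.shape)  # torch.Size([8, 3])

# The custom module's parameters and buffers are registered automatically,
# so they appear in the composite's state_dict and move with .to(device).
print([name for name in composite.state_dict() if name.startswith('2.')])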
Let's implement a basic scaled dot-product attention mechanism, a fundamental component in Transformers, as a custom module. This demonstrates defining parameters (implicitly within nn.Linear) and implementing specific mathematical operations in forward.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SimpleScaledDotProductAttention(nn.Module):
    """Computes simple scaled dot-product attention."""

    def __init__(self, d_model, d_k, dropout_p=0.1):
        """
        Args:
            d_model (int): Dimension of the input embeddings.
            d_k (int): Dimension of keys and queries (often d_model // num_heads).
            dropout_p (float): Dropout probability.
        """
        super().__init__()
        self.d_k = d_k
        # Linear layers to project inputs to Q, K, V spaces
        self.query_proj = nn.Linear(d_model, d_k)
        self.key_proj = nn.Linear(d_model, d_k)
        self.value_proj = nn.Linear(d_model, d_k)  # Often d_v = d_k
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query (torch.Tensor): Query tensor, shape (Batch, Seq_len_q, d_model).
            key (torch.Tensor): Key tensor, shape (Batch, Seq_len_k, d_model).
            value (torch.Tensor): Value tensor, shape (Batch, Seq_len_v, d_model).
                Typically Seq_len_k == Seq_len_v.
            mask (torch.Tensor, optional): Mask tensor to prevent attention to
                certain positions (e.g., padding).
                Shape (Batch, Seq_len_q, Seq_len_k).
                Positions with value 0 are masked out; nonzero positions are attended.

        Returns:
            torch.Tensor: Output tensor after attention, shape (Batch, Seq_len_q, d_k).
            torch.Tensor: Attention weights, shape (Batch, Seq_len_q, Seq_len_k).
        """
        # 1. Project inputs
        Q = self.query_proj(query)  # (B, Seq_q, d_k)
        K = self.key_proj(key)      # (B, Seq_k, d_k)
        V = self.value_proj(value)  # (B, Seq_v, d_k)

        # 2. Calculate attention scores (QK^T / sqrt(d_k))
        # K.transpose(-2, -1) results in shape (B, d_k, Seq_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # scores shape: (B, Seq_q, Seq_k)

        # 3. Apply mask (if provided)
        if mask is not None:
            # Ensure mask has compatible dimensions; it might need unsqueezing.
            # Example: if mask is (B, Seq_k), add a dimension for Seq_q broadcasting:
            # mask = mask.unsqueeze(1)  # -> (B, 1, Seq_k)
            scores = scores.masked_fill(mask == 0, float('-inf'))  # Common convention: 0 means masked

        # 4. Apply softmax to get attention weights
        attn_weights = F.softmax(scores, dim=-1)  # (B, Seq_q, Seq_k)

        # 5. Apply dropout to attention weights
        attn_weights = self.dropout(attn_weights)

        # 6. Compute weighted sum of Values
        # (B, Seq_q, Seq_k) @ (B, Seq_v, d_k) -> (B, Seq_q, d_k)
        # Assumes Seq_k == Seq_v
        output = torch.matmul(attn_weights, V)

        return output, attn_weights
# Example Usage:
batch_size = 4
seq_len = 10
embed_dim = 128
key_dim = 64
attention_module = SimpleScaledDotProductAttention(d_model=embed_dim, d_k=key_dim)
# Create dummy inputs (usually the same tensor for self-attention)
q_input = torch.randn(batch_size, seq_len, embed_dim)
k_input = torch.randn(batch_size, seq_len, embed_dim)
v_input = torch.randn(batch_size, seq_len, embed_dim)
output, weights = attention_module(q_input, k_input, v_input)
print("Attention Output Shape:", output.shape) # Expected: (4, 10, 64)
print("Attention Weights Shape:", weights.shape) # Expected: (4, 10, 10)
This example encapsulates the attention logic within a single module, making it easy to integrate into larger models like a Transformer encoder or decoder layer.

A few further points are worth keeping in mind:

- Handling variable-sized inputs: Utilities such as torch.nn.utils.rnn.pack_padded_sequence or adaptive pooling layers can also be relevant depending on the application.
- Hooks: nn.Module provides a hook mechanism (register_forward_hook, register_backward_hook, register_forward_pre_hook) that allows you to execute custom code before or after the forward pass, or during the backward pass, without modifying the module's core forward code. Hooks are useful for debugging, visualization, or implementing certain normalization techniques; a small sketch follows this list.
- Custom backends: An nn.Module's forward method is the natural place to call specialized C++ or CUDA extensions (covered in other sections of this chapter) or custom autograd.Function instances when performance or specific gradient calculations necessitate them. The module structure neatly encapsulates the interaction between standard PyTorch components and these custom backends.
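As a minimal sketch of the hook mechanism mentioned above, the snippet below registers a forward hook on the query projection of the attention module defined earlier to record its output shape; the shape_logger function and captured_shapes list are illustrative helpers, not part of the original text.

# A hook on the query projection records its output shape on each call,
# without modifying SimpleScaledDotProductAttention.forward.
captured_shapes = []

def shape_logger(module, inputs, output):
    # Forward hooks receive the module, its positional inputs, and its output.
    captured_shapes.append(tuple(output.shape))

handle = attention_module.query_proj.register_forward_hook(shape_logger)

_ = attention_module(q_input, k_input, v_input)
print("Captured shapes:", captured_shapes)  # e.g., [(4, 10, 64)]

handle.remove()  # Detach the hook when it is no longer needed.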
By mastering the extension of torch.nn.Module, you gain the flexibility to implement virtually any network architecture or component, structuring your code in a clean, reusable, and maintainable manner, which is indispensable for tackling advanced deep learning projects.