In PyTorch, constructing neural networks revolves around a central concept: the torch.nn.Module class. Think of nn.Module as the foundational blueprint or base class from which all neural network models, layers, and even complex composite structures are built. It provides a standardized way to encapsulate model parameters, helper functions for managing these parameters (like moving them between CPU and GPU), and the logic defining how input data flows through the network.

Whenever you define a custom neural network in PyTorch, you will typically do so by creating a Python class that inherits from nn.Module. This inheritance provides your custom class with a significant amount of built-in functionality essential for deep learning workflows.
nn.Module
At its heart, using nn.Module involves implementing two primary methods within your custom class:
- __init__(self): The constructor. This is where you define and initialize the components of your network, such as layers (convolutional, linear, etc.), activation functions, or even other nn.Module instances (submodules). These components are typically assigned as attributes of the class instance (self).
- forward(self, input_data): This method defines the forward pass of your network. It dictates how the input data (input_data) flows through the layers and components defined in __init__. The forward method takes one or more input tensors and returns one or more output tensors. PyTorch's Autograd system automatically builds the computation graph based on the operations performed within this forward method, enabling automatic differentiation.

Here's a conceptual skeleton of a custom module:
import torch
import torch.nn as nn

class MySimpleNetwork(nn.Module):
    def __init__(self):
        super(MySimpleNetwork, self).__init__()
        # Define layers or components here
        # Example: a linear layer
        self.layer1 = nn.Linear(in_features=10, out_features=5)
        # Example: an activation function instance
        self.activation = nn.ReLU()

    def forward(self, x):
        # Define the flow of data through the components
        x = self.layer1(x)
        x = self.activation(x)
        return x

# Instantiate the network
model = MySimpleNetwork()
print(model)
Executing this would print a representation of the network structure, demonstrating how nn.Module helps organize your components.
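Calling the model instance directly on an input tensor invokes forward() behind the scenes (through nn.Module's __call__ machinery). Here is a minimal sketch, assuming a batch of random input whose feature size matches layer1 above:

# A batch of 3 samples, each with 10 features (matching in_features=10)
input_batch = torch.randn(3, 10)

# Calling the module runs forward(); avoid calling model.forward() directly
output = model(input_batch)
print(output.shape)  # torch.Size([3, 5])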
A critical feature of nn.Module is its ability to automatically register and manage learnable parameters. When you assign an instance of a PyTorch layer (like nn.Linear, nn.Conv2d, etc.) as an attribute in the __init__ method, nn.Module recognizes the internal parameters (weights and biases) of that layer.
These parameters are instances of the torch.nn.Parameter class, a special subclass of torch.Tensor. The key difference is that Parameter objects have requires_grad=True by default and are registered with the parent nn.Module. This registration allows PyTorch to easily collect all learnable parameters of a model, which is essential for passing them to an optimizer during training.
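As a brief sketch of what this registration gives you, reusing MySimpleNetwork from above and assuming stochastic gradient descent as the optimizer:

import torch.optim as optim

model = MySimpleNetwork()

# nn.Module registered layer1's weight and bias automatically
for name, param in model.named_parameters():
    print(name, tuple(param.shape))  # layer1.weight (5, 10), layer1.bias (5,)

# parameters() yields everything the optimizer needs to update
optimizer = optim.SGD(model.parameters(), lr=0.01)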
You can also define your own custom learnable parameters directly using nn.Parameter:
class CustomModuleWithParameter(nn.Module):
    def __init__(self):
        super().__init__()
        # A learnable parameter tensor
        self.my_weight = nn.Parameter(torch.randn(5, 2))
        # A regular tensor attribute (not automatically tracked for optimization)
        self.my_info = torch.tensor([1.0, 2.0])

    def forward(self, x):
        # Example usage
        return torch.matmul(x, self.my_weight)

module = CustomModuleWithParameter()

# Accessing parameters tracked by the module
for name, param in module.named_parameters():
    print(f"Parameter name: {name}, Shape: {param.shape}, Requires grad: {param.requires_grad}")
Notice how my_weight is listed as a parameter, while my_info is not. This automatic tracking simplifies the process of managing potentially thousands or millions of parameters in deep networks.
nn.Module Functionality

Beyond defining structure and managing parameters, nn.Module provides several useful methods inherited by your custom classes:
- parameters(): Returns an iterator over all nn.Parameter objects within the module (including those in submodules). This is typically used to provide the model's parameters to an optimizer.
- named_parameters(): Similar to parameters(), but yields tuples of (parameter name, parameter object). Useful for inspecting or selectively modifying specific parameters.
- children(): Returns an iterator over the immediate child modules (submodules defined as attributes).
- modules(): Returns an iterator over all modules within the network, starting with the module itself and then recursively iterating through all submodules.
- state_dict(): Returns a Python dictionary containing the entire state of the module, primarily mapping each parameter and buffer name to its corresponding tensor. This is fundamental for saving model weights.
- load_state_dict(): Loads a state (usually from a saved file) back into the module, restoring parameters and buffers.
- to(device): Moves all the module's parameters and buffers to the specified device (e.g., 'cuda' for GPU or 'cpu'). This is crucial for hardware acceleration.
- train(): Sets the module and its submodules to training mode. This affects layers like Dropout and BatchNorm, which behave differently during training and evaluation.
- eval(): Sets the module and its submodules to evaluation mode.
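The short sketch below illustrates how several of these methods typically fit together; the device check and the file name are placeholder choices, not part of the example above:

# Move parameters and buffers to the available device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MySimpleNetwork().to(device)

model.train()   # training mode (affects Dropout, BatchNorm, etc.)
# ... training loop would go here ...
model.eval()    # evaluation mode for inference

# Save and later restore the learnable state
torch.save(model.state_dict(), "my_simple_network.pt")
model.load_state_dict(torch.load("my_simple_network.pt"))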
Understanding nn.Module is fundamental because it establishes the standard pattern for defining any neural network architecture in PyTorch. In the following sections, we will use this base class to build networks incorporating various layers, activation functions, and loss functions.