Once you've defined your neural network architecture by subclassing torch.nn.Module, as discussed in previous sections, the next step is often to interact with its components. You might want to inspect its structure, examine the initial (or learned) weights, modify parameters for fine-tuning, or even swap out entire layers. If you're coming from TensorFlow and Keras, you're likely familiar with tools like model.summary(), accessing model.layers, and using methods like get_weights() or set_weights(). PyTorch offers a more direct, Python-native way to achieve these tasks.
PyTorch models, being standard Python classes, can be inspected in several intuitive ways. The most straightforward method is to simply print the model instance. This provides a hierarchical view of the modules and sub-modules within your network.
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.classifier = nn.Linear(10 * 12 * 12, 50)  # Assuming 28x28 input: (28 - 5 + 1) / 2 = 12
        self.output = nn.Linear(50, 10)

    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, 10 * 12 * 12)  # Flatten
        x = torch.relu(self.classifier(x))
        x = self.output(x)
        return x

model = SimpleNet()
print(model)
This output shows the layers you've defined, their types, and any arguments passed during their instantiation. It's similar to Keras's model.summary() in that it shows the architecture, though model.summary() also includes parameter counts and output shapes per layer, which you'd compute separately in PyTorch if needed.
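If you need those counts, they are straightforward to compute from the parameter tensors yourself (third-party packages such as torchinfo offer a fuller model.summary() analogue). A minimal sketch, where count_parameters is just an illustrative helper name:

def count_parameters(m, trainable_only=False):
    # Sum the element counts of every parameter tensor in the module
    return sum(p.numel() for p in m.parameters()
               if p.requires_grad or not trainable_only)

print(f"Total parameters: {count_parameters(model)}")
print(f"Trainable parameters: {count_parameters(model, trainable_only=True)}")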
For more programmatic access to the layers (which are themselves nn.Module instances), PyTorch provides several iterators:

model.children(): This yields an iterator over the direct child modules of the model. For SimpleNet above, model.children() would yield the nn.Sequential block assigned to self.features, the nn.Linear layer self.classifier, and the nn.Linear layer self.output.

model.modules(): This yields an iterator over all modules in the network in a recursive fashion, including the model itself as the first item, then its children, then their children, and so on.

model.named_children() and model.named_modules(): These are often more useful, as they yield tuples of (name, module), where name is the attribute name you assigned in __init__ (e.g., 'features', 'classifier') or an index for layers within nn.Sequential.

print("Children of the model:")
for name, module in model.named_children():
    print(f"Name: {name}, Type: {type(module)}")

print("\nAll modules in the model:")
for name, module in model.named_modules():
    print(f"Path: {name}, Type: {type(module)}")
This detailed introspection is valuable for debugging, understanding complex architectures, or selectively applying operations to certain parts of your network.
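For instance, combining named_modules() with an isinstance check lets you target only layers of a particular type. A small sketch (picking out Conv2d layers is just an example):

conv_layers = {
    name: module
    for name, module in model.named_modules()
    if isinstance(module, nn.Conv2d)  # keep only the convolutional layers
}
print(list(conv_layers.keys()))  # ['features.0']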
Diagram: hierarchy of modules in the SimpleNet example, with notes illustrating how named_parameters() and named_modules() reference elements within this structure.
In PyTorch, the learnable parameters of a model (weights and biases) are instances of torch.Tensor that have their requires_grad attribute set to True. torch.nn.Module provides convenient ways to access these parameters:

model.parameters(): Returns an iterator over all parameters of the model.

model.named_parameters(): Returns an iterator yielding tuples of (name, parameter). The name is a string indicating the path to the parameter, such as features.0.weight (the weight of the first layer in the features Sequential block) or classifier.bias.

Let's examine the parameters of our SimpleNet:
print("\nModel Parameters (name, shape, requires_grad):")
for name, param in model.named_parameters():
print(f"Name: {name}, Shape: {param.shape}, Requires Grad: {param.requires_grad}")
# Access a specific parameter
# For example, the weights of the 'classifier' layer
classifier_weights = model.classifier.weight
print(f"\nClassifier layer weights shape: {classifier_weights.shape}")
print(f"Classifier layer weights data (first 5 values of first row):\n {classifier_weights.data[0, :5]}")
This is different from Keras, where layer.get_weights() returns a list of NumPy arrays. In PyTorch, you get the torch.Tensor objects directly. The .data attribute of a parameter tensor gives you direct access to the underlying data, bypassing the gradient tracking system for direct manipulation if needed.
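If you do want NumPy arrays in the style of get_weights(), you can convert the tensors yourself. A short sketch, assuming the model's parameters live on (or are moved to) the CPU:

weights_as_numpy = {
    name: param.detach().cpu().numpy()  # detach from autograd, move to CPU, convert to NumPy
    for name, param in model.named_parameters()
}
print(weights_as_numpy["classifier.bias"].shape)  # (50,)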
Another important concept is the state_dict. A module's state_dict is a Python dictionary that maps each parameter (and persistent buffer) name to its tensor. For parameters, the keys are the parameter names (e.g., features.0.weight), and the values are the tensors themselves. The state_dict is fundamental for saving and loading model checkpoints, a topic we'll cover in detail later.
# Get the state dictionary
state_dict = model.state_dict()
print("\nKeys in model.state_dict():")
for key in state_dict.keys():
    print(key)

# Example: Accessing the bias of the 'output' layer from state_dict
output_bias_from_state_dict = state_dict['output.bias']
print(f"\nOutput layer bias from state_dict: {output_bias_from_state_dict}")
There are several scenarios where you might want to modify model parameters: custom initialization, loading pretrained weights into a part of your model (transfer learning), or freezing layers during training.
You can modify parameters in-place using their .data attribute. This is useful for operations like setting all biases to zero or applying a custom initialization scheme after model creation.
# Example: Set all biases in the 'classifier' layer to zero
print(f"\nClassifier bias before: {model.classifier.bias.data[:5]}")
with torch.no_grad():  # Important for in-place modifications of parameters
    model.classifier.bias.data.fill_(0)
print(f"Classifier bias after: {model.classifier.bias.data[:5]}")
The with torch.no_grad(): context manager is important here. It temporarily disables gradient tracking, which is what you want when manually changing the values of parameters that have requires_grad=True. Modifying .data directly is an older pattern; the more idiomatic way to avoid gradient side effects is to operate on the parameter itself inside a torch.no_grad() block.
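Following the same pattern, a custom initialization scheme can be applied after model creation using the helpers in torch.nn.init. A minimal sketch (the choice of Xavier initialization for the linear layers is just an example):

with torch.no_grad():
    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)  # re-initialize weights in place
            nn.init.zeros_(module.bias)             # zero out biases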
Freezing parameters (requires_grad)

For transfer learning, a common technique is to freeze the weights of pretrained layers and only train the newly added parts of the model. In PyTorch, this is achieved by setting the requires_grad attribute of the parameters you want to freeze to False.
# Freeze all parameters in the 'features' block
print(f"\nBefore freezing, features.0.weight.requires_grad: {model.features[0].weight.requires_grad}")
for param in model.features.parameters():
    param.requires_grad = False
print(f"After freezing, features.0.weight.requires_grad: {model.features[0].weight.requires_grad}")

# Verify only classifier and output layers' parameters require gradients now
print("\nParameters requiring gradients after freezing 'features':")
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)
This is analogous to setting layer.trainable = False in Keras. When you pass these parameters to an optimizer, only those with requires_grad=True will have their gradients computed and updated during backpropagation.
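You can also pass only the trainable parameters to the optimizer explicitly, which keeps frozen tensors out of its state entirely. A short sketch (the learning rate is arbitrary):

trainable_params = [p for p in model.parameters() if p.requires_grad]  # skip frozen parameters
optimizer = torch.optim.Adam(trainable_params, lr=1e-3)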
Since layers in PyTorch models are attributes of the parent module (or elements in a container like nn.Sequential), you can access and modify them using standard Python attribute access or indexing. As seen before, if a layer is defined as self.mylayer = nn.Linear(...) in __init__, you can access it via model.mylayer. For layers within an nn.Sequential container, you can use integer indexing:
# Access the first convolutional layer in the 'features' block
first_conv_layer = model.features[0]
print(f"\nFirst conv layer in features: {first_conv_layer}")
Replacing a layer is as simple as assigning a new module to the attribute. This is particularly powerful and showcases PyTorch's flexibility.
# Example: Replace the 'output' layer with a new one having different output features
print(f"\nOriginal output layer: {model.output}")
old_output_layer_weights_shape = model.output.weight.shape
# Let's say we want to change to 20 output classes
model.output = nn.Linear(model.output.in_features, 20)
# New weights will be randomly initialized by default for the new layer
print(f"New output layer: {model.output}")
print(f"Old output layer weights shape: {old_output_layer_weights_shape}")
print(f"New output layer weights shape: {model.output.weight.shape}")
# Ensure the new layer's parameters are trainable (they are by default)
print(f"New output layer weight requires_grad: {model.output.weight.requires_grad}")
This dynamic modification is a significant advantage of PyTorch. With Keras's Sequential API, replacing a layer often means rebuilding the model from that point onward or creating a new model. While Keras's Functional API offers more flexibility, PyTorch's approach feels very natural for Python developers.
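The same assignment works for layers inside an nn.Sequential container, using the index in place of an attribute name. A small sketch (the new channel count of 16 is arbitrary):

# Swap the first convolution in 'features' for one with more output channels.
# Note: any downstream layer that depends on its output shape (here, self.classifier)
# would also need to be updated to match.
model.features[0] = nn.Conv2d(1, 16, kernel_size=5)
print(model.features[0])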
When modifying the model architecture, such as replacing a layer, remember that the optimizer might need to be updated if it was already initialized. Specifically, if you replace a layer, the identities of its parameter tensors change. Optimizers store references to the parameters they are optimizing, so an optimizer created before the replacement will still hold references to the old layer's parameters. It's common practice to create the optimizer after all model modifications are complete, or to re-initialize it with the new set of model.parameters().
# Re-initialize the optimizer if the model structure changed after it was created,
# so it tracks the new layer's parameters rather than the replaced ones:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
Being able to easily access, inspect, and modify both parameters and entire layers provides a fine-grained level of control over your PyTorch models. This is invaluable for advanced techniques like surgical fine-tuning, model surgery for research, or adapting pretrained models to new tasks. This contrasts with TensorFlow Keras where, while possible, such manipulations often feel less direct and might require a deeper understanding of the underlying graph construction, especially in TF1.x. In PyTorch, it's just Python objects and attributes.