After you've saved and loaded a PyTorch model, or when you're working with a pre-trained model, a common next step is to inspect its internal structure and the learned parameters. This is important for several reasons: verifying that the model loaded correctly, understanding the architecture before fine-tuning, debugging unexpected behavior, or simply learning how a particular network is constructed. If you're coming from TensorFlow, you'll find that PyTorch offers equally powerful, albeit different, mechanisms for model inspection.
print(model)

The most straightforward way to get an overview of your model's architecture in PyTorch is to simply print the model object. This call walks the modules registered as attributes in your model's constructor (__init__) and prints their structure.
Let's define a simple Convolutional Neural Network (CNN) for illustration. Assume this network is intended for an MNIST-like dataset with 1-channel images of size 28x28.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        # Fully connected layers
        # After two 5x5 convolutions and 2x2 max pooling operations on a 28x28 image:
        # Output size from conv1: 28 - 5 + 1 = 24, so 24x24. After pool1: 12x12.
        # Output size from conv2: 12 - 5 + 1 = 8, so 8x8. After pool2: 4x4.
        # So, flattened size = 20 * 4 * 4 = 320
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)  # Output for 10 classes
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = x.view(-1, 320)  # Flatten the tensor
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = self.log_softmax(x)
        return x

# Instantiate the model
model = SimpleNet()
print(model)
The output will look something like this:
SimpleNet(
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (relu1): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (relu2): ReLU()
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=320, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=10, bias=True)
  (log_softmax): LogSoftmax(dim=1)
)
This output lists each layer defined as an attribute in SimpleNet, along with its type (e.g., Conv2d, Linear, ReLU) and the arguments used to initialize it (e.g., in_channels, out_channels, and kernel_size for Conv2d).
For TensorFlow developers, this is somewhat similar to Keras's model.summary() method. However, the print(model) output in PyTorch is a direct representation of the nn.Module hierarchy. It does not include layer-wise output shapes or a detailed parameter count table the way model.summary() does, though it provides a clear view of the model's components and their configurations.
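If you do want a Keras-style parameter count, it is easy to build one yourself with the named_parameters() iterator covered later in this section. A minimal sketch; the count_parameters helper is our own, not a PyTorch API:

def count_parameters(model: nn.Module) -> None:
    # One row per parameter tensor, then a total, similar in spirit
    # to the parameter counts in Keras's model.summary()
    total = 0
    for name, param in model.named_parameters():
        print(f"{name:20s} {str(tuple(param.shape)):16s} {param.numel():>8,d}")
        total += param.numel()
    print(f"{'Total':37s} {total:>8,d}")

count_parameters(model)  # SimpleNet has 21,840 parameters in total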
For a more programmatic way to access your model's layers, PyTorch offers several iterators.
model.children() and model.named_children()
If you want to iterate only over the direct child modules of your model (i.e., those assigned as attributes in its __init__), you can use model.children() or model.named_children(). The latter also yields the name you assigned to each attribute.
print("Direct children of the model:")
for name, module in model.named_children():
print(f"Name: {name}, Module: {module}")
This will list modules like conv1, relu1, and pool1, which are direct attributes of SimpleNet.
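For our SimpleNet, the output begins:

Direct children of the model:
Name: conv1, Module: Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
Name: relu1, Module: ReLU()
Name: pool1, Module: MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
...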
model.modules() and model.named_modules()
To iterate over all modules in the network recursively (including nested modules if you had, for example, an nn.Sequential block as a child), use model.modules() or model.named_modules().
print("\nAll modules in the model (recursive):")
for name, module in model.named_modules():
# The top-level model itself is also included with an empty name
if name: # Filter out the top-level model itself for cleaner output here
print(f"Name: {name}, Module Type: {type(module).__name__}")
This recursive iteration is useful for accessing every single part of your network, including those within container modules.
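SimpleNet itself contains no nested containers, so named_modules() reports the same layers as named_children() (plus the top-level module under an empty name). To see the difference, consider this small side example; the nested model here is ours, purely for illustration:

# A model with an nn.Sequential block nested inside another
nested = nn.Sequential(
    nn.Conv2d(1, 10, kernel_size=5),
    nn.Sequential(nn.ReLU(), nn.MaxPool2d(2)),
)

# children() stops at the top level: the Conv2d and the inner Sequential
print([type(m).__name__ for m in nested.children()])
# ['Conv2d', 'Sequential']

# modules() recurses into the inner Sequential as well
print([type(m).__name__ for m in nested.modules()])
# ['Sequential', 'Conv2d', 'Sequential', 'ReLU', 'MaxPool2d']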
The following diagram illustrates the architecture of our SimpleNet model, showing how data flows through its constituent layers and operations.

[Diagram: Data flow and layer organization within the SimpleNet model. Blue boxes represent learnable layers or fixed operation modules, green shapes represent functional operations or tensor manipulations, and gray ellipses denote input/output.]
Beyond the structure, you'll often need to inspect the actual learnable parameters (weights and biases) of your model.
model.named_parameters()
The model.named_parameters() iterator is very useful for this. It yields tuples of (parameter_name, parameter_tensor).
print("\nModel Parameters (Name, Size, Requires Grad):")
for name, param in model.named_parameters():
print(f"Parameter: {name}, Size: {param.size()}, Requires Grad: {param.requires_grad}, Dtype: {param.dtype}")
# To see the actual values (be cautious with large tensors):
# print(f" First few values: {param.data.flatten()[:3].tolist()}")
The output will show names like conv1.weight, conv1.bias, fc1.weight, fc1.bias, etc., along with their tensor shapes (param.size()) and whether they track gradients (param.requires_grad). You can access the underlying tensor data via param.data.
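Because every parameter reports requires_grad, this iterator is also a quick way to confirm which parts of a model are frozen, a common check when fine-tuning. A short sketch on a fresh instance; the choice to freeze the convolutional layers is ours, purely for illustration:

ft_model = SimpleNet()

# Freeze the convolutional layers, as you might before fine-tuning
for param in ft_model.conv1.parameters():
    param.requires_grad = False
for param in ft_model.conv2.parameters():
    param.requires_grad = False

# named_parameters() now shows exactly which tensors will be trained
trainable = [name for name, p in ft_model.named_parameters() if p.requires_grad]
frozen = [name for name, p in ft_model.named_parameters() if not p.requires_grad]
print(f"Trainable: {trainable}")  # fc1.weight, fc1.bias, fc2.weight, fc2.bias
print(f"Frozen: {frozen}")        # conv1.weight, conv1.bias, conv2.weight, conv2.bias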
This is comparable to iterating through model.layers in Keras and calling layer.get_weights(), which returns a list of NumPy arrays (typically weights and then biases). In PyTorch, param.data gives you direct access to the torch.Tensor object itself.
model.state_dict()
As discussed in the context of saving and loading models, model.state_dict() returns an ordered dictionary containing all parameters (weights and biases) and persistent buffers (such as the running means in batch normalization). While primarily used for persistence, it is also a convenient way to view all parameter names and their corresponding tensors.
# state_dictionary = model.state_dict()
# for param_name in state_dictionary:
#     print(f"{param_name}\t{state_dictionary[param_name].size()}")
If your layers are defined as attributes of your model class (e.g., self.conv1 = nn.Conv2d(...)), you can access them directly by their attribute names. Once you have a specific layer object, you can inspect its parameters such as weight and bias.
# Accessing the conv1 layer
conv1_layer = model.conv1
print(f"\nconv1 layer: {conv1_layer}")

# Accessing the weight parameter of conv1
conv1_weight = model.conv1.weight
print(f"conv1 weight_tensor_shape: {conv1_weight.size()}")
# print(f"conv1 weight values (first element): {conv1_weight[0,0,0,0].item()}")  # Example for one value

# Accessing the bias parameter of fc1
if model.fc1.bias is not None:
    fc1_bias_shape = model.fc1.bias.size()
    print(f"fc1 bias_tensor_shape: {fc1_bias_shape}")
else:
    print("fc1 layer has no bias parameter.")
This direct access is very powerful for targeted inspection or even modification of specific parts of your model.
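For example, you can overwrite a parameter in place. A minimal sketch; wrapping the change in torch.no_grad() keeps autograd from recording the manual assignment:

# Zero out the bias of fc2 in place (this changes the model's state)
with torch.no_grad():
    model.fc2.bias.zero_()
print(model.fc2.bias)  # a tensor of ten zeros, still with requires_grad=True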
A common use case for inspection is to verify that a model's state has been correctly loaded from a checkpoint. Let's simulate this:
import os

# 1. Save the state_dict of our current model
torch.save(model.state_dict(), "simplenet_checkpoint.pth")

# 2. Create a new instance of the model
model_reloaded = SimpleNet()
# Optional: Check a parameter before loading (it will be randomly initialized)
# print(f"Parameter conv1.weight (first value) BEFORE loading: {model_reloaded.conv1.weight.data[0,0,0,0].item()}")

# 3. Load the saved state_dict
model_reloaded.load_state_dict(torch.load("simplenet_checkpoint.pth"))
model_reloaded.eval()  # Set to evaluation mode if applicable

# 4. Inspect to verify
# Compare a specific parameter from the original and reloaded model
original_conv1_weight_val = model.conv1.weight.data[0, 0, 0, 0].item()
reloaded_conv1_weight_val = model_reloaded.conv1.weight.data[0, 0, 0, 0].item()
print(f"\nOriginal model conv1.weight (first value): {original_conv1_weight_val}")
print(f"Reloaded model conv1.weight (first value): {reloaded_conv1_weight_val}")

if original_conv1_weight_val == reloaded_conv1_weight_val:
    print("Parameter verification successful: The reloaded model's weight matches the original.")
else:
    print("Parameter verification FAILED: Weights do not match.")

# Clean up the created file
os.remove("simplenet_checkpoint.pth")
This workflow ensures that your model persistence mechanism is working as expected and that the loaded model accurately reflects the saved state.
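The single-value comparison above is only a spot check. To compare every tensor, you can walk both state_dicts; a short sketch using torch.equal (substitute torch.allclose if you need tolerance for numerical differences):

# Every key should exist in both state_dicts, with exactly matching tensors
original_state = model.state_dict()
reloaded_state = model_reloaded.state_dict()
all_match = all(torch.equal(original_state[key], reloaded_state[key])
                for key in original_state)
print(f"All parameters and buffers match: {all_match}")  # True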
By using these inspection techniques, you can gain a clear understanding of your PyTorch model's architecture and its learned parameters. This ability is fundamental for effective model development, debugging, and deployment, allowing you to confidently adapt your TensorFlow experience to the PyTorch environment.