Understanding how gradients flow through your network is fundamental for debugging and optimization. When models behave unexpectedly or training stagnates, examining the gradients and the underlying computational graph often provides valuable clues. As we established earlier, PyTorch builds this graph dynamically as operations are performed on tensors that require gradients. Let's look at techniques to examine these gradients and visualize the graph structure.
After you call loss.backward(), PyTorch computes the gradients of the loss with respect to all tensors in the computational graph that have requires_grad=True and were involved in computing the loss. These gradients are accumulated in the .grad attribute of the respective leaf tensors (typically model parameters or inputs).
import torch
# Example setup
w = torch.randn(5, 3, requires_grad=True)
x = torch.randn(3, 2)
y_true = torch.randn(5, 2)
# Forward pass
y_pred = w @ x
loss = torch.nn.functional.mse_loss(y_pred, y_true)
# Backward pass
loss.backward()
# Inspect the gradient accumulated in w
print("Gradient for w:\n", w.grad)
# Gradients for non-leaf tensors or tensors with requires_grad=False are usually None
print("Gradient for x:", x.grad) # Output: None (requires_grad=False by default)
print("Gradient for y_pred:", y_pred.grad) # Output: None (non-leaf, gradients not retained by default)
Common situations you'll encounter when inspecting .grad:

- None gradients: if a tensor's .grad is None after .backward(), it usually means one of the following:
  - The tensor was not created with requires_grad=True.
  - The tensor was excluded from the graph (for example, computed inside a with torch.no_grad(): block or detached using .detach()).
  - The tensor is a non-leaf (intermediate) tensor; gradients for intermediate tensors are not retained by default. Call tensor.retain_grad() if you need to inspect gradients of intermediate results.
- Vanishing gradients: gradients that shrink toward zero as they propagate backward, leaving early layers nearly unchanged during training.
- Exploding gradients: gradients that become extremely large (or NaN). This leads to unstable training, large weight updates, and often results in NaN values appearing in the loss or weights. Gradient clipping (covered in Chapter 3) is a common mitigation strategy.

You can programmatically check for these issues:
# Check for None gradients (assuming 'model' is your torch.nn.Module)
for name, param in model.named_parameters():
    if param.grad is None:
        print(f"Parameter {name} has no gradient.")

# Check for vanishing/exploding gradients
max_grad_norm = 0.0
min_grad_norm = float('inf')
nan_detected = False
for name, param in model.named_parameters():
    if param.grad is not None:
        if torch.isnan(param.grad).any():
            nan_detected = True
            print(f"NaN gradient detected in parameter: {name}")
            continue  # skip NaN gradients when tracking the norm range
        grad_norm = param.grad.norm().item()
        max_grad_norm = max(max_grad_norm, grad_norm)
        min_grad_norm = min(min_grad_norm, grad_norm)

print(f"Max gradient norm: {max_grad_norm:.4e}")
print(f"Min gradient norm: {min_grad_norm:.4e}")
if nan_detected:
    print("Warning: NaN gradients detected!")
For more detailed analysis during the backward pass, PyTorch provides hooks. Hooks are functions that can be registered to execute when a specific event occurs, such as gradient computation for a tensor or the forward/backward pass of a module.
Tensor Hooks (register_hook)

You can register a hook directly on a tensor. This hook function will be executed when the gradient for that specific tensor is computed. The hook function receives the gradient as its only argument.
import torch
def print_grad_hook(grad):
    print(f"Gradient received: shape={grad.shape}, norm={grad.norm():.4f}")
x = torch.randn(3, 3, requires_grad=True)
y = x.pow(2).sum()
# Register the hook on tensor x
hook_handle = x.register_hook(print_grad_hook)
# Compute gradients
y.backward()
# The hook function (print_grad_hook) is called automatically
# Output will include something like:
# Gradient received: shape=torch.Size([3, 3]), norm=9.5930
# Hooks should be removed when no longer needed to avoid memory leaks
hook_handle.remove()
# You can also modify gradients within a hook, but use with caution:
def scale_grad_hook(grad):
    # Example: Halve the gradient
    return grad * 0.5
# x.register_hook(scale_grad_hook)
# y.backward() # Now the gradient stored in x.grad will be halved
Hooks are powerful for debugging specific parts of your network. You can log gradient statistics, check for NaN
values precisely when they occur, or even modify gradients on the fly (though modifying gradients is generally less common and requires careful consideration).
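As an illustration, here is a sketch of how per-parameter tensor hooks could flag NaN gradients at the exact moment autograd produces them (the helper function and toy model are illustrative, not from the original text):

import torch
import torch.nn as nn

model = nn.Linear(10, 1)

def make_nan_check_hook(name):
    def hook(grad):
        if torch.isnan(grad).any():
            print(f"NaN gradient produced for parameter: {name}")
        # Returning None leaves the gradient unchanged
    return hook

handles = [p.register_hook(make_nan_check_hook(n)) for n, p in model.named_parameters()]

loss = model(torch.randn(4, 10)).sum()
loss.backward()

for handle in handles:
    handle.remove()   # remove the hooks once debugging is finished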
You can also register hooks on torch.nn.Module
instances to inspect inputs and outputs during the forward pass or gradients during the backward pass.
- register_forward_pre_hook(hook): executes before the module's forward method. The hook receives (module, input).
- register_forward_hook(hook): executes after the module's forward method. The hook receives (module, input, output).
- register_full_backward_hook(hook): executes after gradients have been computed with respect to the module's inputs and outputs. The hook receives (module, grad_input, grad_output), where grad_input is a tuple of gradients with respect to the module's inputs and grad_output is a tuple of gradients with respect to the module's outputs.

import torch
import torch.nn as nn
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 5)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

model = SimpleNet()
input_tensor = torch.randn(4, 10, requires_grad=True)

def backward_hook(module, grad_input, grad_output):
    print(f"\nModule: {module.__class__.__name__}")
    print("  grad_input shapes: ", [g.shape if g is not None else None for g in grad_input])
    print("  grad_output shapes:", [g.shape if g is not None else None for g in grad_output])
# Register hook on linear2 layer
hook_handle_bwd = model.linear2.register_full_backward_hook(backward_hook)
# Forward and backward pass
output = model(input_tensor)
target = torch.randn(4, 1)
loss = nn.functional.mse_loss(output, target)
loss.backward()
# Output will show gradient shapes flowing backward through linear2
# Module: Linear
# grad_input shapes:  [torch.Size([4, 5])]   (gradient w.r.t. the layer's input; parameter gradients are not included)
# grad_output shapes: [torch.Size([4, 1])]
hook_handle_bwd.remove() # Clean up
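The forward-hook variants work in the same way. As a small sketch (reusing the SimpleNet model defined above), a forward hook can report activation statistics for a single layer during the forward pass:

def forward_stats_hook(module, inputs, output):
    # Print simple statistics of the layer's output
    zero_fraction = (output == 0).float().mean().item()
    print(f"{module.__class__.__name__}: output mean={output.mean().item():.4f}, "
          f"zeros={zero_fraction:.2%}")

hook_handle_fwd = model.relu.register_forward_hook(forward_stats_hook)

with torch.no_grad():
    model(torch.randn(4, 10))

hook_handle_fwd.remove()   # clean up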
Module hooks are particularly useful for understanding how gradients propagate layer by layer or for diagnosing issues within specific modules of a larger network.
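For example, here is a sketch (again assuming the SimpleNet model from above) that registers a backward hook on every leaf submodule to print the gradient norm flowing out of each layer:

def log_grad_norm(module, grad_input, grad_output):
    if grad_output[0] is not None:
        print(f"{module.__class__.__name__}: grad_output norm = {grad_output[0].norm().item():.4e}")

norm_handles = [
    m.register_full_backward_hook(log_grad_norm)
    for m in model.modules()
    if len(list(m.children())) == 0   # leaf modules only (Linear, ReLU)
]

model(torch.randn(4, 10)).sum().backward()

for handle in norm_handles:
    handle.remove()   # always remove debugging hooks when done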
While hooks let you inspect gradients numerically, visualizing the computational graph provides a structural overview. This helps you understand dependencies between operations and parameters, confirm your model architecture, and spot unexpected connections.
torchviz

A popular third-party library for basic graph visualization is torchviz. It uses the graphviz library to render the graph generated during the backward pass.

You typically call torchviz.make_dot on the output tensor (often the loss) whose gradient computation graph you want to visualize. It returns a graphviz.Digraph object.
# Requires: pip install torchviz graphviz
import torch
import torchviz
# Simple example
a = torch.tensor([2.0], requires_grad=True)
b = torch.tensor([3.0], requires_grad=True)
c = a * b
d = c + a
L = d.mean() # Final scalar output
# Generate the graph visualization object
# params can be used to highlight specific parameters
graph = torchviz.make_dot(L, params={'a': a, 'b': b})
# To view the graph, you can render it to a file or display it in environments like Jupyter
# graph.render("computation_graph", format="png") # Saves graph.png
# display(graph) # In Jupyter environments
# For demonstration, let's print the Graphviz source
# print(graph.source)
A simple computational graph generated by torchviz. Ellipses represent tensors (parameters highlighted), and boxes represent the backward operations (grad_fn). Arrows show the flow of gradients during the backward pass.
torchviz provides a clear, high-level view of the backward graph, which is excellent for understanding dependencies and the flow of gradient computation.
PyTorch has built-in support for TensorBoard, a powerful visualization toolkit from TensorFlow. You can log the computational graph (and many other things like scalars, images, and histograms) using torch.utils.tensorboard.SummaryWriter.
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
# Define a simple model again
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(5, 3)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(3, 1)

    def forward(self, x):
        return self.layer2(self.relu(self.layer1(x)))
model = SimpleNet()
dummy_input = torch.randn(1, 5) # Provide a sample input
# Create a SummaryWriter instance (logs to ./runs/ by default)
writer = SummaryWriter('runs/graph_demo')
# Add the graph to TensorBoard
# The writer requires the model and a sample input tensor
writer.add_graph(model, dummy_input)
writer.close()
# To view the graph:
# 1. Ensure you have tensorboard installed (pip install tensorboard)
# 2. Run `tensorboard --logdir=runs/graph_demo` in your terminal
# 3. Open the URL provided (usually http://localhost:6006/) in your browser
# 4. Navigate to the "Graphs" tab.
TensorBoard provides an interactive graph visualization environment directly in your browser. It often displays a more detailed graph, including module scopes, parameter nodes, and operation nodes. While potentially overwhelming for very large models, its interactivity allows you to expand and collapse parts of the graph, making it easier to navigate complex architectures compared to a static image.
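Beyond the graph itself, the same SummaryWriter can track gradient statistics over the course of training, complementing the one-off checks shown earlier. A short sketch (the toy model, loss, and log directory are illustrative):

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

model = nn.Linear(5, 1)
writer = SummaryWriter('runs/grad_stats_demo')

for step in range(100):
    x = torch.randn(32, 5)
    loss = model(x).pow(2).mean()

    model.zero_grad()
    loss.backward()

    # Log a histogram of each parameter's gradient and the overall gradient norm
    for name, param in model.named_parameters():
        writer.add_histogram(f"grad/{name}", param.grad, global_step=step)
    total_norm = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
    writer.add_scalar("grad/total_norm", total_norm, global_step=step)

writer.close()
# The logged data appears under the "Scalars" and "Histograms" tabs in TensorBoard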
Keep in mind that PyTorch builds its graph dynamically: if your model contains data-dependent control flow (for example, Python if statements affecting which layers are used), the graph might change between iterations or for different inputs.
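A brief sketch (using a hypothetical function, not from the original text) of how data-dependent control flow changes what torchviz.make_dot records:

import torch
import torchviz

def branchy(x):
    if x.sum() > 0:            # Python control flow: only one branch executes
        return (x * 2).sum()
    return (x ** 3).sum()

x = torch.randn(4, requires_grad=True)
out = branchy(x)

# The rendered graph reflects only the branch that actually ran for this input
graph = torchviz.make_dot(out, params={'x': x})
# graph.render("dynamic_graph", format="png")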
Effectively inspecting gradients and visualizing the computational graph are indispensable skills for advanced PyTorch development. They move beyond treating the framework as a black box, enabling deeper understanding, targeted debugging, and informed optimization decisions. The next chapter will build upon this foundation as we explore how to implement complex network architectures.