While tools like TensorBoard help monitor training trends and visualizing network graphs provides architectural insight, sometimes you need to dig into the exact state of your program at a specific point in execution. Shape mismatches might occur deep within a model's forward pass, gradients might unexpectedly become `NaN`, or tensor values might diverge inexplicably. For these situations, stepping through your code line by line with a debugger is often the most direct way to find the root cause.

Python's built-in debugger, `pdb`, is a powerful, text-based tool that works seamlessly with PyTorch code. It allows you to pause execution, inspect variables (including tensors), execute code step by step, and understand the program's flow precisely when and where issues arise.
The most common way to start a debugging session with `pdb` is to insert the following two lines directly into your Python script at the location where you want execution to pause:

```python
import pdb
pdb.set_trace()
```

When the Python interpreter encounters `pdb.set_trace()`, it stops execution and drops you into the `pdb` interactive console right in your terminal. The `(Pdb)` prompt indicates you are now in the debugger.
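Since Python 3.7, the built-in `breakpoint()` function is equivalent to `import pdb; pdb.set_trace()` by default, and it can be disabled globally by setting the `PYTHONBREAKPOINT` environment variable to `0`, which is handy for leaving a breakpoint in place while running a quick non-interactive check. A small sketch (the inline script is just for demonstration):

```python
import os
import subprocess
import sys

# A toy script containing a breakpoint() call.
script = "x = 1\nbreakpoint()\nprint('reached the end, x =', x)\n"

# Run it with breakpoints disabled: breakpoint() becomes a no-op,
# so the script completes without pausing.
env = dict(os.environ, PYTHONBREAKPOINT="0")
result = subprocess.run([sys.executable, "-c", script],
                        env=env, capture_output=True, text=True)
print(result.stdout.strip())  # -> reached the end, x = 1
```

Run without the environment variable, the same script would pause at the `(Pdb)` prompt exactly as with `pdb.set_trace()`.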
Where should you place `pdb.set_trace()`?

- **Before a suspect operation:** Place `pdb.set_trace()` on the line immediately before it. This allows you to inspect the inputs to that operation.
- **Inside a model's `forward` method:** To understand how data transforms as it passes through layers, place the trace inside the `forward` method. You can step through layer applications and check tensor shapes and values.
- **Inside a training loop, conditionally:** To pause only on a specific iteration, guard the trace (e.g., `if batch_idx == problematic_index: import pdb; pdb.set_trace()`).
- **After `loss.backward()`:** To inspect gradients after they've been computed but before the optimizer step, place the trace after the `backward()` call.

Alternatively, you can launch your entire script under `pdb` control from the command line, which starts the debugger at the very first line:
```
python -m pdb your_pytorch_script.py
```
This is useful for diagnosing issues that happen very early in the script's execution, like import errors or setup problems.
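Relatedly, `pdb` supports post-mortem debugging: after an exception, `pdb.post_mortem()` opens the debugger inside the frame where the exception was raised, so you can inspect the exact state at the moment of failure. A minimal sketch, with the interactive call commented out so the snippet runs unattended (`buggy` is a hypothetical failing function):

```python
import pdb
import sys

def buggy():
    t = [1, 2, 3]
    return t[10]  # raises IndexError

try:
    buggy()
except IndexError:
    # In a real session, uncommenting the next line drops you into a
    # (Pdb) prompt inside buggy()'s frame, where you can `p t`.
    # pdb.post_mortem()
    exc_type, exc_value, _ = sys.exc_info()
    msg = f"caught {exc_type.__name__}: {exc_value}"
    print(msg)
```

In an interactive session you can also call `pdb.pm()` right after a crash to post-mortem the most recent traceback.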
Once you're at the `(Pdb)` prompt, you can use various commands to control execution and inspect the state. Here are some of the most useful ones:
- `n` (next): Execute the current line and stop at the next line in the current function. If the current line is a function call, `n` executes the entire function and stops after it returns.
- `s` (step): Similar to `n`, but if the current line is a function call, `s` steps into the function, stopping at its first line.
- `c` (continue): Resume normal execution until the next breakpoint (or `pdb.set_trace()` call) is encountered, or until the script finishes or errors out.
- `l` (list): Show the source code around the current line of execution. Use `l .` to list code centered on the current line again.
- `p <expression>` (print): Evaluate the `<expression>` in the current context and print its value. This is arguably the most important command for debugging PyTorch. You can inspect tensors, variables, model parameters, etc.:
  - `p my_tensor.shape`
  - `p my_tensor.dtype`
  - `p my_tensor.device`
  - `p my_tensor` (prints the tensor itself; can be large)
  - `p model.layer1.weight.grad` (after `backward()`)
  - `p loss.item()`
- `a` (args): Print the argument list of the current function.
- `r` (return): Continue execution until the current function returns.
- `b <line_number>` (breakpoint): Set a breakpoint at a specific `<line_number>` in the current file. Execution will pause when it reaches this line. You can also set breakpoints in other files (`b path/to/file.py:<line_number>`) or on methods (`b self.my_method`).
- `cl` or `clear`: Clear all breakpoints. `cl <bp_number>` clears a specific breakpoint.
- `q` (quit): Exit the debugger and terminate the script immediately.
- `h` (help): Display a list of available commands. `h <command>` provides help on a specific command.

Let's see how `pdb` helps with common PyTorch debugging scenarios.
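As a warm-up, the core commands above (`p`, `n`, `c`) can even be exercised non-interactively by feeding `pdb.Pdb` a scripted command sequence; this is a sketch for illustration, not how you would normally debug, but it shows exactly what each command does:

```python
import io
import pdb

def traced():
    x = 41
    debugger.set_trace()  # pauses; execution stops on the next line
    x += 1
    return x

# Scripted session: print x, step one line with n, print x again, continue.
commands = io.StringIO("p x\nn\np x\nc\n")
out = io.StringIO()
debugger = pdb.Pdb(stdin=commands, stdout=out)
debugger.use_rawinput = False  # read commands from our buffer, not the terminal

result = traced()
print("result:", result)
```

The transcript captured in `out` shows `41` before the `n` step and `42` after it, mirroring what you would see typing the same commands at a live `(Pdb)` prompt.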
Scenario 1: Debugging Shape Mismatches in a Model
Imagine you have a simple model and are getting a shape mismatch error during the forward pass.
```python
import torch
import torch.nn as nn
import pdb

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 20)
        self.activation = nn.ReLU()
        # Potential mistake: input size doesn't match layer1's output
        self.layer2 = nn.Linear(25, 5)  # ERROR is likely here (25 != 20)

    def forward(self, x):
        print(f"Initial shape: {x.shape}")
        x = self.layer1(x)
        print(f"After layer1: {x.shape}")
        x = self.activation(x)
        print(f"After activation: {x.shape}")
        # Let's debug before layer2
        pdb.set_trace()
        # This line will likely cause a runtime error
        x = self.layer2(x)
        print(f"After layer2: {x.shape}")
        return x

# Example usage
net = SimpleNet()
# Create a dummy input tensor
input_tensor = torch.randn(32, 10)  # Batch of 32, features=10
output = net(input_tensor)
```
When you run this code, it prints the shapes and then stops at `pdb.set_trace()`:

```
Initial shape: torch.Size([32, 10])
After layer1: torch.Size([32, 20])
After activation: torch.Size([32, 20])
-> x = self.layer2(x)
(Pdb)
```
At the `(Pdb)` prompt, you can inspect:

- `p x.shape`: outputs `torch.Size([32, 20])`.
- `p self.layer2`: shows the definition `Linear(in_features=25, out_features=5, bias=True)`.
- `p self.layer2.in_features`: outputs `25`.

By comparing the input shape `[32, 20]` (specifically the feature dimension `20`) with `layer2.in_features` (`25`), the mismatch becomes obvious. You can then quit (`q`) and fix the `nn.Linear` definition.
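For completeness, here is a hypothetical corrected version (call it `FixedNet`; the name is ours, not from the scenario above) with a quick shape check confirming the fix:

```python
import torch
import torch.nn as nn

class FixedNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 20)
        self.activation = nn.ReLU()
        self.layer2 = nn.Linear(20, 5)  # in_features now matches layer1's output

    def forward(self, x):
        return self.layer2(self.activation(self.layer1(x)))

net = FixedNet()
out = net(torch.randn(32, 10))
print(out.shape)  # -> torch.Size([32, 5])
```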
Scenario 2: Inspecting Gradients
Let's say your loss isn't decreasing, and you suspect vanishing or exploding gradients. You can check them after the backward pass.
```python
# Assume model, data (inputs, targets), loss_fn, optimizer are defined

# Forward pass
outputs = model(inputs)
loss = loss_fn(outputs, targets)

# Backward pass
optimizer.zero_grad()
loss.backward()

# Insert debugger here to inspect gradients
import pdb
pdb.set_trace()

# Optimizer step (would happen after debugging)
# optimizer.step()
```
When execution pauses, you can check gradients for specific parameters:

- `p model.some_layer.weight.grad`: Inspect the gradient tensor for the weight of `some_layer`. Look for `NaN` values, very large values (explosion), or very small values (vanishing).
- `p model.some_layer.weight.grad.abs().mean()`: Calculate the mean absolute gradient to get a sense of its magnitude.
- `p loss.item()`: Remind yourself of the current loss value.

A few habits make `pdb` sessions more productive:

- Place `pdb.set_trace()` as close as possible to where you suspect the problem lies.
- Leverage the `p` command extensively to check tensor shapes, data types, device placement, and actual values.
- Use `n` (next) to move line by line. Use `s` (step) only when you need to dive into a function call you wrote (stepping into PyTorch's internal functions can be verbose).
- Remove `import pdb; pdb.set_trace()` calls before finalizing your code, especially before committing to version control or deploying.
- Remember that `pdb` halts execution. For debugging issues that only appear after many hours of training, interactive debugging might be impractical. In such cases, periodically logging detailed information (tensor shapes, loss values, gradient norms) might be more suitable, perhaps combined with conditional breakpoints or assertion checks.

Using `pdb` effectively takes a little practice, but it's an indispensable tool for understanding the detailed, step-by-step execution of your PyTorch code and resolving many common bugs that aren't immediately obvious from stack traces or high-level metrics.
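The non-interactive alternative mentioned above, logging gradient norms and asserting on `NaN`s instead of pausing, can be sketched like this (the tiny `Linear` model stands in for a real training setup):

```python
import torch

# Stand-in model and a single backward pass; in practice this check
# would run inside your training loop every N steps.
model = torch.nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).sum()
loss.backward()

grad_norms = {}
for name, param in model.named_parameters():
    if param.grad is None:
        continue
    # Fail fast on NaN gradients instead of silently training onward.
    assert not torch.isnan(param.grad).any(), f"NaN gradient in {name}"
    grad_norms[name] = param.grad.norm().item()

print(grad_norms)  # e.g. {'weight': ..., 'bias': ...}
```

Logged over time, these norms make vanishing or exploding gradients visible without ever stopping the run.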
© 2025 ApX Machine Learning