Even with the most careful planning, bugs are an inevitable part of software development, and machine learning code is no exception. PyTorch's dynamic nature, often hailed for its flexibility, can sometimes present unique debugging challenges compared to TensorFlow's graph-based execution, especially for those accustomed to TensorFlow 1.x. However, this dynamism also means you can leverage standard Python debugging tools more directly. This section outlines common issues you might encounter when developing PyTorch models and provides practical strategies and tools to identify and resolve them.
Understanding common error patterns can significantly speed up the debugging process. Here are some frequently encountered issues:
Tensor shape mismatches are perhaps the most common runtime errors in any tensor library.

Symptoms: RuntimeError messages like "mat1 and mat2 shapes cannot be multiplied", "size mismatch, m1: [A x B], m2: [C x D]", or errors related to broadcasting.

Common causes: incorrect in_features or out_features in nn.Linear layers, or incorrect reshaping and flattening of tensors (e.g., a misuse of tensor.view(batch_size, -1)).

Debugging strategies: add print(f"Tensor X shape: {x.shape}") before and after operations you suspect; use import pdb; pdb.set_trace() or your IDE's debugger to pause execution and inspect tensor.shape attributes at various points; and run the model on a dummy input of a known shape:

# Assuming model is your nn.Module instance
# and you suspect an issue around a specific input size
dummy_input = torch.randn(1, 3, 224, 224) # Example for an image model
try:
    output = model(dummy_input)
    print(f"Dummy output shape: {output.shape}")
except Exception as e:
    print(f"Error with dummy input: {e}")
Numerical instability can quickly derail training.

Symptoms: the loss becomes NaN (Not a Number) or inf (infinity), or model weights become NaN/inf. Gradients might explode to very large values or vanish to zero.

Common causes: taking the logarithm of a non-positive value (torch.log(x) where x <= 0), or division by zero or by a very small number.

Debugging strategies: anomaly detection is a powerful tool. Wrap your training step (forward and backward pass) in the torch.autograd.detect_anomaly() context manager, or enable it globally with torch.autograd.set_detect_anomaly(True):
# Option 1: enable anomaly detection globally at the beginning of your training script
# torch.autograd.set_detect_anomaly(True)

# Option 2 (preferred, limits the overhead to the wrapped code): wrap just the training step
for data, target in train_loader:
    optimizer.zero_grad()
    with torch.autograd.detect_anomaly():
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
    optimizer.step()
This will print a stack trace pointing to the operation that first produced a NaN or inf in the backward pass. It adds overhead, so use it only for debugging.

Print suspect values: if you suspect a specific operation such as log(x) is causing issues, print x right before the log operation.

Clip gradients to keep them from exploding:

# After loss.backward() and before optimizer.step()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
Inspect gradient statistics: check param.grad for all parameters after loss.backward():

for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"Parameter: {name}, Grad mean: {param.grad.mean()}, Grad std: {param.grad.std()}")
    else:
        print(f"Parameter: {name}, Grad is None")
Another frequent problem is a model that does not learn at all.

Symptoms: the loss stagnates, or accuracy remains at random chance levels.

Common causes: forgetting optimizer.step() or optimizer.zero_grad(), or placing optimizer.zero_grad() incorrectly (e.g., before loss.backward() in some exotic scenarios, though usually it belongs at the start of the loop); parameters frozen with requires_grad=False when they should be trainable; calling .detach() on a tensor that should be part of the computation graph for gradient calculation (a minimal illustration follows below); a mismatch between the loss function and the model's outputs (e.g., feeding raw logits to nn.BCELoss); or an incorrect forward pass that executes without Python errors but produces mathematically incorrect results.
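As a small illustration of the .detach() pitfall mentioned above, the following self-contained sketch (with made-up tensors) shows how detaching silently stops gradients from reaching a parameter:

import torch

w = torch.randn(3, requires_grad=True)
x = torch.randn(3)

y_ok = (w * x).sum()
y_ok.backward()
print(w.grad)                      # populated: w is part of the graph

w.grad = None
y_broken = (w * x).detach().sum()  # detach() cuts the graph at this point
print(y_broken.requires_grad)      # False: calling backward() here would raise an error,
                                   # and no gradient would ever reach w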
Debugging strategies: try to overfit a single batch. A correctly wired model and training loop should be able to drive the loss close to zero on one batch:

# Get a single batch
data_iter = iter(train_loader)
sample_data, sample_targets = next(data_iter)
# Train for many epochs on this single batch
for epoch in range(100):  # Or more
    optimizer.zero_grad()
    output = model(sample_data)
    loss = criterion(output, sample_targets)
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")
Also verify the placement of optimizer.step() and optimizer.zero_grad() in the loop, and check param.requires_grad for every parameter:

for name, param in model.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")
Confirm that gradients are not None for layers you expect to train. Small but non-zero gradients are okay; None or zero gradients for all weights are problematic. Finally, check the loss function's expectations: for nn.CrossEntropyLoss, the model output should be raw logits, while for nn.BCELoss the output should be passed through torch.sigmoid and the targets should be 0 or 1.
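A minimal sketch of these expectations, using small made-up tensors (the batch size and class count are arbitrary):

import torch
import torch.nn as nn

# Multi-class: nn.CrossEntropyLoss expects raw logits and integer class indices.
logits = torch.randn(4, 3)                     # no softmax applied
targets = torch.tensor([0, 2, 1, 2])
ce_loss = nn.CrossEntropyLoss()(logits, targets)

# Binary: nn.BCELoss expects probabilities in [0, 1] and float targets of 0/1.
raw_scores = torch.randn(4, 1)
probs = torch.sigmoid(raw_scores)              # convert raw scores to probabilities first
binary_targets = torch.tensor([[1.0], [0.0], [1.0], [0.0]])
bce_loss = nn.BCELoss()(probs, binary_targets)

# nn.BCEWithLogitsLoss fuses the sigmoid and the loss and is more numerically stable.
bce_logits_loss = nn.BCEWithLogitsLoss()(raw_scores, binary_targets)
print(ce_loss.item(), bce_loss.item(), bce_logits_loss.item())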
CUDA errors indicate problems with GPU usage.

Symptoms: RuntimeError: CUDA out of memory, RuntimeError: CUDA error: an illegal memory access was encountered, or device assertion errors.

A frequent cause of out-of-memory errors is keeping the computation graph alive by storing loss tensors across iterations instead of plain Python numbers:

# Bad: Accumulates computation graph
# all_losses_gpu = []
# for ...:
# loss = criterion(output, target)
# all_losses_gpu.append(loss) # loss is still on GPU with graph
# Good: Stores only Python float
all_losses_scalar = []
for i in range(num_iterations):  # Pseudocode
    # ... forward pass ...
    loss = criterion(output, target)
    # ... backward pass, optimizer step ...
    all_losses_scalar.append(loss.item())  # .item() gets a Python number, detached from the graph
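When hunting memory issues, it can also help to log how much GPU memory is in use between iterations; a small sketch using PyTorch's standard memory queries:

import torch

# Report memory held by live tensors vs. memory reserved by the caching allocator.
if torch.cuda.is_available():
    allocated_mb = torch.cuda.memory_allocated() / 1024**2
    reserved_mb = torch.cuda.memory_reserved() / 1024**2
    print(f"GPU memory allocated: {allocated_mb:.1f} MiB, reserved: {reserved_mb:.1f} MiB")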
Other mitigation strategies:

torch.cuda.empty_cache(): this can free up unused cached memory on the GPU. However, it doesn't free actively used memory. Use it sparingly, as it can slow down execution.

Move tensors that don't need to stay on the GPU back to the CPU: tensor_cpu = tensor_gpu.cpu().

del tensor_gpu: explicitly delete large tensors once they are no longer needed. Combined with torch.cuda.empty_cache(), this can sometimes help.

Gradient accumulation: accumulate gradients over several smaller batches before calling optimizer.step(). This allows for a larger effective batch size without increasing memory per step (see the sketch below).

Activation checkpointing: torch.utils.checkpoint trades compute for memory by recomputing activations during the backward pass instead of storing them.
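A minimal gradient accumulation sketch; train_loader, model, criterion, and optimizer are assumed to exist, and accumulation_steps is an illustrative value:

accumulation_steps = 4  # effective batch size = loader batch size * accumulation_steps

optimizer.zero_grad()
for step, (data, target) in enumerate(train_loader):
    output = model(data)
    loss = criterion(output, target) / accumulation_steps  # scale so accumulated gradients average out
    loss.backward()                                        # gradients add up in param.grad
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()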
PyTorch's Pythonic nature provides access to a range of helpful debugging tools.
The humble print() statement is often the quickest way to inspect tensor shapes, dtypes, devices, or intermediate values:
print(f"Layer X output: {output.shape}, {output.dtype}, {output.device}, mean: {output.mean().item()}")
Anomaly detection (torch.autograd.set_detect_anomaly(True), or the torch.autograd.detect_anomaly() context manager): as mentioned earlier, this is invaluable for tracing NaN or Inf errors back to their origin in the backward pass:
with torch.autograd.detect_anomaly():
    loss.backward()
Python debuggers (pdb or IDE-integrated debuggers): PyTorch code is Python code. You can insert import pdb; pdb.set_trace() anywhere to drop into the Python debugger. From there, you can inspect variables, step through code, and execute commands. Most IDEs (such as VS Code and PyCharm) offer sophisticated graphical debuggers that work seamlessly with PyTorch.
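For example, a breakpoint can be dropped into a hypothetical training step right where a suspicious value appears:

def training_step(model, data, target, criterion):
    output = model(data)
    # Pause here to inspect output.shape, output.min(), output.max(), etc.
    import pdb; pdb.set_trace()
    loss = criterion(output, target)
    return loss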
Hooks allow you to attach functions to nn.Module instances or Tensors to inspect (or modify) activations and gradients without altering the module's forward method.

Forward hooks (register_forward_hook): run after a module's forward pass. Useful for inspecting activations or feature maps.
def print_activation_shape(module, input_tensor, output_tensor):
    print(f"Module: {module.__class__.__name__}")
    print(f"  Input shape: {input_tensor[0].shape if isinstance(input_tensor, tuple) else input_tensor.shape}")
    print(f"  Output shape: {output_tensor.shape}")
# Register hook on a specific layer
model.conv1.register_forward_hook(print_activation_shape)
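Hook registration returns a handle, so a debugging hook can be detached once it has served its purpose; a short sketch reusing the same (assumed) model.conv1 layer:

# Keep the handle so the hook can be removed after inspection.
handle = model.conv1.register_forward_hook(print_activation_shape)
# ... run a forward pass and read the printed shapes ...
handle.remove()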
Tensor hooks (register_hook): run when the gradient with respect to a tensor is computed. Useful for inspecting gradients of specific tensors.
def print_grad(grad):
    print(f"Gradient shape: {grad.shape}, mean: {grad.mean()}")
# Assume 'x' is an input tensor that requires gradients
# x = torch.randn(10, 20, requires_grad=True)
# y = model(x)
# y.register_hook(print_grad) # Hook on y, will print grad of d(loss)/dy
# loss.backward()
Note: module.register_full_backward_hook and the older module.register_backward_hook provide access to gradients w.r.t. module inputs and outputs.
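A brief sketch of a full backward hook, again assuming the model.conv1 layer from the forward-hook example:

def print_grad_output_norm(module, grad_input, grad_output):
    # grad_input and grad_output are tuples of gradients w.r.t. the module's inputs and outputs.
    print(f"Module: {module.__class__.__name__}, grad_output norm: {grad_output[0].norm().item():.4f}")

model.conv1.register_full_backward_hook(print_grad_output_norm)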
Inspecting gradients directly: after loss.backward(), inspect the .grad attribute of parameters and intermediate tensors that have requires_grad=True.
# After loss.backward()
for name, param in model.named_parameters():
    if param.grad is None:
        print(f"WARNING: No gradient for {name}")
    elif torch.all(param.grad == 0):
        print(f"WARNING: Gradient for {name} is all zeros")
If .grad is None, it means that parameter was not part of the computation graph leading to the loss, or its requires_grad was False. If it's all zeros, it might indicate a problem like dying ReLUs or saturated activations.
For complex models, understanding the computation graph can be helpful. Libraries like torchviz can generate diagrams of the graph:
# pip install torchviz
from torchviz import make_dot
# Assuming 'loss' is the final output of your graph
# and model is your nn.Module
graph_viz = make_dot(loss, params=dict(model.named_parameters()))
graph_viz.render("computation_graph", format="png") # Saves a PNG file
This can help identify detached parts of the graph or incorrect connections. Below is a simplified representation of such a graph:

Figure: data flow through a simple model, from input to loss, and the subsequent gradient propagation during the backward pass.
A structured approach is often more effective than random trial and error. Make runs reproducible by fixing random seeds (e.g., torch.manual_seed(0); a short seeding sketch follows), and keep your work under version control so you can run git diff to see what changed or revert to a known good state.
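A minimal seeding sketch; the NumPy and Python random calls are assumptions about what else a typical training script uses:

import random
import numpy as np
import torch

def set_seed(seed: int = 0) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(0)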
If you're coming from TensorFlow, here are a few points to keep in mind:
No tf.Session: you don't need a session to run operations. Tensors are evaluated immediately, so print() statements for tensor values work as expected without needing tf.print or evaluating within a session.

Custom training loops: in Keras, model.fit() abstracts many details. When you write a custom training loop in PyTorch, you have more control but also more responsibility. Common mistakes include forgetting optimizer.zero_grad(), loss.backward(), or optimizer.step(). The detailed control, however, means you can insert debugging logic anywhere in the loop.

Explicit device placement: tensors and models must be moved to the target device explicitly with .to(device). Errors about tensors being on different devices (Expected all tensors to be on the same device) are common but usually easy to fix by ensuring all inputs to an operation are on the target device.
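A short sketch of explicit device placement with a small, hypothetical model:

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2)).to(device)

x = torch.randn(4, 8)
# Forgetting x.to(device) triggers the "Expected all tensors to be on the same device"
# error when a GPU is available; moving the input fixes it.
output = model(x.to(device))
print(output.device)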
Debugging is an acquired skill, blending systematic investigation with intuition built from experience. By understanding common PyTorch issues and using its debugging tools, you can efficiently identify and resolve problems in your models. Remember that the PyTorch forums and documentation are excellent resources when you encounter particularly challenging bugs.