Okay, you've set up your tensors, marked the ones you want gradients for with `requires_grad=True`, and performed operations that PyTorch has dutifully tracked in its dynamic computation graph. Now, how do you actually get those gradients calculated? This is where the `backward()` method comes into play.
The `backward()` method is the engine that drives automatic differentiation in PyTorch. When called on a tensor, typically the final scalar loss value of your model, it initiates the computation of gradients throughout the computation graph using the chain rule. It calculates the gradient of the tensor it's called on with respect to all the "leaf" tensors in the graph that have `requires_grad=True` (these are often your model's parameters or the initial inputs you need gradients for).
You almost always call `backward()` on a scalar tensor, which is usually the result of your loss function computation. For example, if `loss` contains the single numerical value representing your model's error for a batch:
```python
import torch

# Example setup (imagine these are results from a model)
x = torch.tensor(2.0, requires_grad=True)
w = torch.tensor(3.0, requires_grad=True)
b = torch.tensor(1.0, requires_grad=True)

# Perform some operations (building the graph)
y = w * x + b  # y = 3*2 + 1 = 7
loss = y * y   # loss = 7*7 = 49 (a scalar)

# Before the backward pass, gradients are None
print(f"Gradient for x before backward: {x.grad}")
print(f"Gradient for w before backward: {w.grad}")
print(f"Gradient for b before backward: {b.grad}")

# Compute gradients
loss.backward()

# After the backward pass, gradients are populated
print(f"Gradient for x after backward: {x.grad}")  # d(loss)/dx = 2*y*(dy/dx) = 2*y*w = 2*7*3 = 42
print(f"Gradient for w after backward: {w.grad}")  # d(loss)/dw = 2*y*(dy/dw) = 2*y*x = 2*7*2 = 28
print(f"Gradient for b after backward: {b.grad}")  # d(loss)/db = 2*y*(dy/db) = 2*y*1 = 2*7*1 = 14
```
Output:

```
Gradient for x before backward: None
Gradient for w before backward: None
Gradient for b before backward: None
Gradient for x after backward: 42.0
Gradient for w after backward: 28.0
Gradient for b after backward: 14.0
```
As you can see, calling `loss.backward()` calculated the gradients ∂loss/∂x, ∂loss/∂w, and ∂loss/∂b and stored them in the respective `.grad` attributes of the `x`, `w`, and `b` tensors.
### Why Call `.backward()` on a Scalar?

Autograd is designed to compute vector-Jacobian products (VJPs). When you call `backward()` on a scalar tensor L, it's implicitly equivalent to calling `backward()` with a starting gradient of 1.0. This allows PyTorch to compute the gradients ∂L/∂p for all parameters p efficiently, using the chain rule to propagate backward from the scalar loss.
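To make that implicit starting gradient concrete, here is a minimal sketch that passes it explicitly; the tensor names `p` and `L` are illustrative and not part of the example above.

```python
# Calling backward() on a scalar is equivalent to supplying an explicit
# starting gradient of 1.0. Both forms populate the same .grad values.
p = torch.tensor(3.0, requires_grad=True)
L = p ** 2

L.backward(gradient=torch.tensor(1.0))  # same result as plain L.backward()
print(p.grad)  # tensor(6.) since dL/dp = 2*p = 6
```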
If you try to call `.backward()` on a non-scalar tensor (a tensor with more than one element), PyTorch doesn't know how to implicitly weight the gradients of each element with respect to the final (unseen) scalar loss. You'll get a runtime error asking for a `gradient` argument:
```python
# Continuing the previous example: build a new, non-scalar tensor from x
y = x * torch.tensor([1.0, 2.0, 3.0])  # y now has three elements, so it is not a scalar

try:
    y.backward()  # This raises an error: no implicit starting gradient for non-scalar outputs
except RuntimeError as e:
    print(f"Error calling backward() on non-scalar: {e}")

# To make it work, provide a gradient tensor matching y's shape.
# It represents the gradient of some final loss w.r.t. y and is
# mostly needed in more advanced scenarios.
grad_tensor = torch.ones_like(y)
y.backward(gradient=grad_tensor)

# Note: gradients accumulate in .grad! x.grad was 42.0 from loss.backward(),
# and this call adds dy/dx summed against the ones vector: 1 + 2 + 3 = 6.
print(f"Gradient for x after y.backward(gradient=...): {x.grad}")  # 42 + 6 = 48
```
Output:

```
Error calling backward() on non-scalar: grad can be implicitly created only for scalar outputs
Gradient for x after y.backward(gradient=...): 48.0
```
In most standard training loops, you'll compute a single scalar loss value representing the error for a batch or sample, and you'll call `loss.backward()` directly on that scalar without needing to provide the `gradient` argument.
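For context, here is a minimal sketch of where that call sits in a typical training step; the linear model, MSE loss, and SGD optimizer below are illustrative placeholders rather than part of the example above.

```python
# A minimal sketch of a single training step (illustrative model and optimizer).
model = torch.nn.Linear(3, 1)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

inputs = torch.randn(8, 3)   # a batch of 8 samples with 3 features each
targets = torch.randn(8, 1)

optimizer.zero_grad()                    # clear gradients from the previous step
loss = loss_fn(model(inputs), targets)   # a single scalar loss for the batch
loss.backward()                          # populate .grad on every model parameter
optimizer.step()                         # update parameters using those gradients
```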
Conceptually, `loss.backward()` triggers a traversal back through the graph of operations that created `loss`.
*A simplified computation graph showing inputs `x`, `w`, `b`, intermediate result `y`, and final scalar `loss`. The dashed red arrows illustrate the path taken during `loss.backward()` to compute gradients with respect to `x`, `w`, and `b`.*
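If you want to peek at this chain of operations yourself, you can follow the `grad_fn` attribute that autograd attaches to non-leaf tensors. The sketch below rebuilds the small graph with fresh names so it stands alone; the exact node names are internal details and may vary between versions.

```python
# Inspect the backward nodes that backward() would traverse.
x2 = torch.tensor(2.0, requires_grad=True)
w2 = torch.tensor(3.0, requires_grad=True)
b2 = torch.tensor(1.0, requires_grad=True)

y2 = w2 * x2 + b2
loss2 = y2 * y2

print(loss2.grad_fn)  # e.g. <MulBackward0 ...>, created by y2 * y2
print(y2.grad_fn)     # e.g. <AddBackward0 ...>, created by w2 * x2 + b2
```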
By default, PyTorch frees the intermediate buffers of the computation graph after `backward()` is called to save memory. This means that if you need to call `backward()` multiple times on the same part of the graph (less common, often needed for advanced techniques or debugging), you would need to pass `retain_graph=True` to the first `backward()` call. However, for standard training, you construct a graph, compute the loss, call `backward()`, update the weights, and then repeat the process for the next batch, which builds a new graph.
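As a quick illustration of that flag (a sketch with throwaway tensor names):

```python
# Calling backward() twice on the same graph: the first call must retain the
# intermediate buffers, otherwise the second call raises a RuntimeError.
a = torch.tensor(2.0, requires_grad=True)
out = (a * a) * 3  # out = 12, d(out)/da = 6*a = 12

out.backward(retain_graph=True)  # first pass: a.grad = 12.0
out.backward()                   # second pass: gradients accumulate, a.grad = 24.0
print(a.grad)                    # tensor(24.)
```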
Understanding `backward()` is fundamental to training models in PyTorch. It's the mechanism that connects your model's output and loss function back to the parameters that need adjustment. In the next sections, we'll see how these computed gradients are accessed and used by optimizers.