At the heart of PyTorch's automatic differentiation capability lies the computational graph. This isn't a static structure you define upfront; instead, PyTorch constructs it dynamically as you perform operations on tensors. Think of it as a directed acyclic graph (DAG) where nodes represent either tensors or operations, and edges represent the flow of data and functional dependencies.
Understanding this graph is fundamental because it's precisely what the autograd
engine traverses during the backward pass to compute gradients using the chain rule. Every operation involving tensors that track gradients contributes to building this graph structure behind the scenes.
Frameworks like TensorFlow 1.x or Theano utilized static computational graphs. In these systems, you would first define the entire graph structure, compile it, and then execute it, potentially multiple times, with different input data. This "define-and-run" approach allows for significant graph-level optimizations before execution.
PyTorch, conversely, employs a dynamic computational graph approach, often termed "define-by-run". The graph is built implicitly, operation by operation, as your Python code executes. If you have a loop or conditional statement (like an if
block) in your model's forward pass, the graph structure can actually change from one iteration to the next based on the execution path taken.
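As a minimal sketch (the toy function, branch condition, and input shape below are invented for illustration), the operations that end up recorded depend on which branch actually executes:
import torch

def forward(x):
    # Whichever branch runs is what gets recorded for this call;
    # the untaken branch leaves no trace in the graph
    if x.sum() > 0:
        return (x * 2).mean()
    return (x ** 2).sum()

x = torch.randn(3, requires_grad=True)
out = forward(x)  # the graph for this call is built here, as the code runs
out.backward()    # gradients flow along whichever path was actually taken
print(x.grad)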
Advantages of Dynamic Graphs:
Flexibility: Because the graph mirrors whatever Python code actually runs, models with data-dependent control flow (loops, conditionals) are straightforward to express.
Debugging: You can use standard Python tools (such as pdb or print statements) directly within your model's execution flow to inspect intermediate values or graph connectivity.
Trade-offs:
While highly flexible, the define-by-run nature might present challenges for certain whole-graph optimizations that are simpler in static graph environments. However, PyTorch compensates for this with tools like TorchScript (covered in Chapter 4) which allow for graph capture and optimization.
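As a brief preview of that idea (the full treatment waits until Chapter 4, and this toy function is invented for illustration), torch.jit.script can capture a function, including its control flow, into a graph that can then be inspected and optimized:
import torch

@torch.jit.script
def scaled(x: torch.Tensor) -> torch.Tensor:
    # The if/else is captured into the TorchScript graph itself
    if x.sum() > 0:
        return x * 2
    return x - 1

print(scaled.graph)  # the captured graph representation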
The grad_fn Attribute
How does PyTorch actually track the operations to build this graph? When you perform an operation on a tensor that has requires_grad=True, the resulting output tensor automatically gains a reference to the function that created it. This reference is stored in the output tensor's grad_fn attribute.
Let's illustrate with a simple example:
import torch
# Input tensor requiring gradients
a = torch.tensor([2.0, 3.0], requires_grad=True)
# Operation 1: Multiply by 3
b = a * 3
# Operation 2: Calculate the mean
c = b.mean()
# Inspect the grad_fn attributes
print(f"Tensor a: requires_grad={a.requires_grad}, grad_fn={a.grad_fn}")
# Expected output: Tensor a: requires_grad=True, grad_fn=None
print(f"Tensor b: requires_grad={b.requires_grad}, grad_fn={b.grad_fn}")
# Expected output: Tensor b: requires_grad=True, grad_fn=<MulBackward0 object at 0x...>
print(f"Tensor c: requires_grad={c.requires_grad}, grad_fn={c.grad_fn}")
# Expected output: Tensor c: requires_grad=True, grad_fn=<MeanBackward0 object at 0x...>
Notice the following:
a is a leaf node in the graph. It was created directly by the user, not as the result of an operation tracked by autograd. Therefore, its grad_fn is None.
b resulted from multiplying a by 3. Its grad_fn points to MulBackward0, representing the multiplication operation. This object holds references back to the inputs of the multiplication (tensor a and the scalar 3) and knows how to compute the gradient with respect to a.
c resulted from the mean operation on b. Its grad_fn points to MeanBackward0, which knows how to compute the gradient with respect to its input, b.
These grad_fn references form a linked list tracing backwards from the output tensor (c) through the operations (MeanBackward0, MulBackward0) to the input leaf tensor (a). This linked structure is the backward computational graph that autograd uses.
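You can inspect this linked structure directly. Each grad_fn exposes a next_functions attribute, a tuple of (function, input index) pairs pointing to the nodes that follow it in the backward graph; the exact object addresses shown in the comments below are illustrative, not guaranteed output.
# Continuing the example above: follow the grad_fn chain by hand
print(c.grad_fn)                 # <MeanBackward0 object at 0x...>
print(c.grad_fn.next_functions)  # ((<MulBackward0 object at 0x...>, 0),)

# The multiplication node points back toward the leaf tensor a.
# Leaf tensors show up as AccumulateGrad nodes, which deposit the
# computed gradient into a.grad during the backward pass.
mul_node = c.grad_fn.next_functions[0][0]
print(mul_node.next_functions)   # includes (<AccumulateGrad object at 0x...>, 0)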
While PyTorch doesn't provide a built-in, real-time graph visualizer like TensorBoard's graph view for static graphs, we can conceptualize the graph built by the previous example. The forward pass creates tensors and associates grad_fn
objects. The backward pass (c.backward()
) traverses this structure in reverse.
Representation of the computational graph for c = (a * 3).mean(). Rectangles are tensors, ellipses are operations. Edges show data flow. grad_fn links created tensors to their generating operations, forming the backward path.
When you call .backward()
on a scalar tensor (like c
in our example, or typically a loss value), the autograd engine starts traversing the graph backwards from that tensor.
1. The engine looks at that tensor's grad_fn (MeanBackward0 for c).
2. It computes the gradient of the output (c) with respect to its inputs (b).
3. These gradients are then passed to the next grad_fn objects. So, the gradient computed for b is passed to MulBackward0.
4. MulBackward0 computes the gradient with respect to its input (a).
5. Since a is a leaf node (grad_fn is None) and has requires_grad=True, the computed gradient is accumulated in a.grad.
This process continues until all paths reach leaf nodes or tensors that do not require gradients. The computational graph provides the roadmap for this chain rule application.
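Continuing the running example, the whole traversal is triggered by a single call and the accumulated result appears on the leaf tensor; since c = mean(3 * a) over two elements, each entry of the gradient is 3 / 2 = 1.5.
# Trigger the backward traversal from the scalar output c
c.backward()

# The chain rule result lands on the leaf tensor:
# d(mean(3 * a)) / d(a_i) = 3 / 2 = 1.5 for each element
print(a.grad)
# Expected output: tensor([1.5000, 1.5000])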
Understanding the computational graph is not just theoretical. It informs how you structure your models, debug gradient issues (e.g., None
gradients often mean a part of the graph was disconnected or didn't require gradients), and implement custom operations with their own backward passes, as we will see later in this chapter. It's the invisible machinery enabling PyTorch's automatic differentiation.
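As a small, self-contained illustration of that debugging point (the tensor names x and w here are invented for the example), detach() severs a tensor from the graph, so its grad stays None after the backward pass:
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
w = torch.tensor([3.0, 4.0], requires_grad=True)

# detach() returns a tensor that shares data with x but is cut off
# from the graph, so no backward path leads from out back to x
out = (x.detach() * w).sum()
out.backward()

print(w.grad)  # tensor([1., 2.]) -- gradient of sum(x * w) with respect to w is x
print(x.grad)  # None -- the detach() disconnected x, so nothing was accumulated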