You've learned how PyTorch's autograd engine dynamically builds a computational graph and traverses it backward to compute gradients such as ∂L/∂w. This is the foundation for training most neural networks. However, certain advanced techniques require computing gradients of gradients, known as higher-order gradients.
Consider a function f(x). Its first derivative is f′(x) = df/dx. The second derivative is f″(x) = d²f/dx², which is simply the derivative of the first derivative. Similarly, we can compute third-order, fourth-order, and higher-order derivatives. In the context of multi-variable functions such as neural network losses L(θ) with parameters θ, we work with partial derivatives. The first-order partial derivatives form the gradient vector ∇L. Higher-order derivatives involve structures like the Hessian matrix (the matrix of second-order partial derivatives, ∇²L) or even higher-order tensors.
PyTorch's autograd engine is capable of handling these computations. While the standard .backward() method is primarily designed for first-order gradients, the functional interface torch.autograd.grad provides the flexibility needed for higher-order differentiation.
Computing higher-order gradients is essential for several advanced applications, including second-order optimization, model analysis, meta-learning algorithms, and physics-informed modeling.
torch.autograd.grad
The primary tool for computing higher-order gradients in PyTorch is torch.autograd.grad. Unlike the tensor.backward() method, which implicitly computes gradients for all leaf nodes requiring gradients and accumulates them into their .grad attributes, torch.autograd.grad is more explicit: you specify exactly which outputs to differentiate and which inputs to differentiate with respect to, and the gradients are returned directly.
Its basic signature looks like this:
torch.autograd.grad(
outputs, # Scalar or Tensor(s) to be differentiated
inputs, # Tensor(s) w.r.t. which the gradient is computed
grad_outputs=None, # Gradient of the loss w.r.t. 'outputs' (for vector-Jacobian product)
retain_graph=None, # If True, graph is kept; otherwise freed.
create_graph=False, # If True, construct graph for the gradient computation itself
allow_unused=False # If True, inputs not used to compute 'outputs' get None gradients instead of raising an error
)
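Before turning to higher-order use, here is a minimal first-order sketch of the call (the tensor names are illustrative). For a non-scalar output, grad_outputs supplies the vector in the vector-Jacobian product; for a scalar output it can be omitted.
import torch
# Minimal first-order usage of torch.autograd.grad (illustrative tensors)
x = torch.randn(3, requires_grad=True)
y = x * 2                          # non-scalar output, shape [3]
v = torch.ones_like(y)             # upstream vector for the vector-Jacobian product
(grad_x,) = torch.autograd.grad(outputs=y, inputs=x, grad_outputs=v)
print(grad_x)                      # tensor([2., 2., 2.])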
The critical parameter for higher-order gradients is create_graph=True. When you compute first-order gradients using torch.autograd.grad with create_graph=True, PyTorch not only calculates the gradients but also builds the graph structure needed to differentiate through this gradient computation later. If create_graph=False (the default), the gradient calculation is treated as a terminal operation; the resulting gradients are just tensors without any history connecting them back to the original parameters through the differentiation process.
Let's look at a simple example. Suppose we have y = x³. We want to compute dy/dx = 3x² and d²y/dx² = 6x.
import torch
# Input tensor requires gradients
x = torch.tensor([2.0], requires_grad=True)
# First computation: y = x^3
y = x**3
print(f"y = {y.item()}")
# Compute first derivative: dy/dx
# Use create_graph=True to allow computing higher-order gradients
grad_y_x = torch.autograd.grad(outputs=y, inputs=x, create_graph=True)[0]
print(f"dy/dx at x={x.item()}: {grad_y_x.item()}") # Should be 3 * (2^2) = 12
# grad_y_x is now a tensor with its own computation graph
print(f"Gradient tensor requires_grad: {grad_y_x.requires_grad}")
# Compute second derivative: d^2y/dx^2 = d/dx (dy/dx)
# We differentiate the *first gradient* (grad_y_x) w.r.t x
# No need for create_graph=True here unless we want third-order gradients
grad2_y_x2 = torch.autograd.grad(outputs=grad_y_x, inputs=x)[0]
print(f"d^2y/dx^2 at x={x.item()}: {grad2_y_x2.item()}") # Should be 6 * 2 = 12
# Check requires_grad status of the second derivative
print(f"Second derivative tensor requires_grad: {grad2_y_x2.requires_grad}")
Notice that grad_y_x has requires_grad=True because we specified create_graph=True during its computation. This allows us to call torch.autograd.grad again with grad_y_x as the output. The final grad2_y_x2 has requires_grad=False because we did not pass create_graph=True in the second call.
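If third-order (or higher) derivatives were needed, the second call would also pass create_graph=True. The short sketch below extends the example above to the third derivative of y = x³, which is the constant 6; the variable names are illustrative.
import torch
# Extend the example to a third derivative by also passing create_graph=True
# in the second call. For y = x^3: dy/dx = 3x^2, d^2y/dx^2 = 6x, d^3y/dx^3 = 6.
x = torch.tensor([2.0], requires_grad=True)
y = x**3
g1 = torch.autograd.grad(y, x, create_graph=True)[0]   # 3 * 2^2 = 12
g2 = torch.autograd.grad(g1, x, create_graph=True)[0]  # 6 * 2 = 12
g3 = torch.autograd.grad(g2, x)[0]                      # constant 6
print(f"d^3y/dx^3 at x={x.item()}: {g3.item()}")        # 6.0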
When create_graph=True is used, the backward pass itself adds nodes to the computational graph. Consider y = x², so dy/dx = 2x. The forward graph is x -> pow(2) -> y.
- Without create_graph (create_graph=False): the gradient (2x) is computed and returned as a new tensor detached from the graph used to compute it.
- With create_graph (create_graph=True): the gradient (2x) is computed, and the operations representing how this gradient was calculated are added to the graph. Conceptually: x -> pow(2) -> y; then grad_y -> MulBackward (using the saved x) -> grad_x. The output grad_x is attached to this extended graph.
The diagram contrasts the result of torch.autograd.grad with create_graph=False (middle) and create_graph=True (right). With create_graph=True, the computed gradient grad_x remains connected to the graph via the gradient computation operation (PowBackward), allowing further differentiation.
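A quick way to see this attachment in code is to inspect the returned gradient's grad_fn. The snippet below is a minimal check; the exact backward node name printed may vary across PyTorch versions.
import torch
# Inspect whether the returned gradient is attached to a graph
x = torch.tensor([3.0], requires_grad=True)
# Default: the gradient is detached from the graph that produced it
g_detached = torch.autograd.grad(x**2, x)[0]
print(g_detached.requires_grad, g_detached.grad_fn)  # False None
# create_graph=True: the gradient stays connected and can be differentiated again
g_attached = torch.autograd.grad(x**2, x, create_graph=True)[0]
print(g_attached.requires_grad, g_attached.grad_fn)  # True <...Backward0 object ...>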
Let's compute the Hessian-vector product (HVP) for a simple function f(w1, w2) = w1² sin(w2). The gradient is ∇f = [∂f/∂w1, ∂f/∂w2] = [2 w1 sin(w2), w1² cos(w2)]. The Hessian is ∇²f = [[∂²f/∂w1², ∂²f/∂w1∂w2], [∂²f/∂w2∂w1, ∂²f/∂w2²]] = [[2 sin(w2), 2 w1 cos(w2)], [2 w1 cos(w2), −w1² sin(w2)]].
We want to compute (∇²f) v for some vector v without explicitly forming ∇²f. We can achieve this using two torch.autograd.grad calls. The key insight is that (∇²f) v = ∇(∇f ⋅ v), where ∇f ⋅ v is the dot product (a scalar).
import torch
w = torch.tensor([1.0, torch.pi / 2.0], requires_grad=True) # w1=1, w2=pi/2
v = torch.tensor([0.5, 1.0]) # An arbitrary vector
# Define the function
f = w[0]**2 * torch.sin(w[1])
# Compute first gradient: grad_f = nabla f
grad_f = torch.autograd.grad(f, w, create_graph=True)[0]
# Expected grad_f: [2*1*sin(pi/2), 1^2*cos(pi/2)] = [2, 0]
print(f"Gradient (nabla f): {grad_f}")
# Compute the dot product: grad_f_dot_v = (nabla f) . v
# This operation needs to be part of the graph for the second differentiation
grad_f_dot_v = torch.dot(grad_f, v)
print(f"Dot product (nabla f . v): {grad_f_dot_v}") # Expected: 2*0.5 + 0*1 = 1.0
# Compute the gradient of the dot product w.r.t w: nabla (nabla f . v)
# This gives the Hessian-vector product (nabla^2 f) v
hvp = torch.autograd.grad(grad_f_dot_v, w)[0]
# Expected Hessian: [[2*sin(pi/2), 2*1*cos(pi/2)], [2*1*cos(pi/2), -1^2*sin(pi/2)]]
# = [[2, 0], [0, -1]]
# Expected HVP: [[2, 0], [0, -1]] @ [0.5, 1.0] = [2*0.5 + 0*1, 0*0.5 + (-1)*1] = [1.0, -1.0]
print(f"Hessian-vector product (nabla^2 f) v: {hvp}")
This technique avoids materializing the potentially huge Hessian matrix, requiring only vector products and gradient computations, which is much more memory-efficient for large models.
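As a sanity check, the same quantity can be obtained with the built-in torch.autograd.functional.hvp helper, which wraps this double-differentiation pattern behind a functional interface. A minimal sketch, reusing the function and values from the example above:
import torch
from torch.autograd.functional import hvp
# Same HVP via the functional helper; it handles requires_grad internally
def f_func(w):
    return w[0]**2 * torch.sin(w[1])
w = torch.tensor([1.0, torch.pi / 2.0])
v = torch.tensor([0.5, 1.0])
value, hvp_result = hvp(f_func, w, v)
print(value)       # tensor(1.) since f(1, pi/2) = 1
print(hvp_result)  # tensor([ 1., -1.]), matching the manual computation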
Keep in mind that calling torch.autograd.grad with create_graph=True essentially doubles the depth of the graph that needs to be traversed in subsequent backward passes, which increases memory use and computation time accordingly.
Understanding how to compute higher-order gradients using torch.autograd.grad and the create_graph=True flag unlocks a range of advanced capabilities in optimization, model analysis, and the implementation of complex algorithms like meta-learning and physics-informed modeling within the PyTorch framework.
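As a closing illustration, here is a hedged sketch of the meta-learning pattern mentioned above: a single MAML-style inner update. The model, data, and learning rate are illustrative placeholders rather than a complete algorithm; the point is that create_graph=True keeps the inner update differentiable so the outer loss can backpropagate through it.
import torch
# Toy MAML-style inner step (illustrative placeholders, not a full algorithm)
w = torch.randn(3, requires_grad=True)       # meta-parameters
x, y = torch.randn(5, 3), torch.randn(5)     # toy task data
inner_lr = 0.1
# Inner-loop loss and a differentiable parameter update
inner_loss = ((x @ w - y) ** 2).mean()
(g,) = torch.autograd.grad(inner_loss, w, create_graph=True)
w_adapted = w - inner_lr * g                 # update stays in the graph
# Outer (meta) loss uses the adapted parameters; its backward pass flows
# through the gradient g, so w.grad includes second-order terms
outer_loss = ((x @ w_adapted - y) ** 2).mean()
outer_loss.backward()
print(w.grad)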