While PyTorch's autograd engine automatically handles differentiation for a vast range of built-in operations, situations arise where you need more control or need to define the gradient for an operation unknown to PyTorch.
For these scenarios, PyTorch provides a mechanism to define your own differentiable operations by subclassing torch.autograd.Function. This class allows you to specify precisely how the forward computation is performed and how gradients should be calculated during the backward pass.
It's important to distinguish torch.autograd.Function from torch.nn.Module. While nn.Module typically represents layers in a neural network containing parameters (torch.nn.Parameter) and can be composed of other modules or functions, autograd.Function defines a single, specific computational operation and its gradient. It does not hold parameters itself.
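One common way to combine the two can be sketched as follows: the module owns the parameter, and its forward method delegates the computation to a Function. The names ScaleFunction and LearnableScale are hypothetical and chosen only for illustration, and the sketch assumes a 2D input of shape (batch, features).

```python
import torch
from torch import nn


class ScaleFunction(torch.autograd.Function):
    """Hypothetical stateless op: multiplies each feature column by a scale."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.save_for_backward(x, scale)
        return x * scale  # scale of shape (features,) broadcasts over the batch

    @staticmethod
    def backward(ctx, grad_output):
        x, scale = ctx.saved_tensors
        grad_x = grad_output * scale
        # scale was broadcast over the batch, so its gradient sums over dim 0
        grad_scale = (grad_output * x).sum(dim=0)
        return grad_x, grad_scale


class LearnableScale(nn.Module):
    """Hypothetical module: holds the parameter, delegates the math."""

    def __init__(self, num_features):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(num_features))

    def forward(self, x):
        return ScaleFunction.apply(x, self.scale)


layer = LearnableScale(4)
x = torch.randn(2, 4)
layer(x).sum().backward()
print(layer.scale.grad.shape)
```

The Function stays free of state, so the same operation can be reused by any module that supplies its own parameters.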
To create a custom operation, you define a class that inherits from torch.autograd.Function. The core of the forward computation lies in implementing a static method called forward.
import torch

class MyLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor, weight, bias=None):
        # ctx is a context object to save information for backward pass
        # input_tensor, weight, bias are the inputs to the function
        # Perform the operation
        output = input_tensor.mm(weight.t())  # Matrix multiplication
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        # Save tensors needed for backward pass
        # We need input_tensor and weight to compute gradients
        ctx.save_for_backward(input_tensor, weight, bias)
        return output
Key aspects of the forward method:
- Static method: forward must be decorated with @staticmethod. It doesn't operate on an instance of the class but defines the operation itself.
- ctx argument: The first argument is always ctx, a context object. Its primary role is to act as a bridge between the forward and backward passes. You use ctx to store any tensors or information computed during forward that will be needed later to calculate gradients in backward.
- Input arguments: Following ctx, you list the input arguments your function accepts. These can be tensors or other Python objects.
- Computation: Inside forward, you implement the logic of your operation using standard PyTorch tensor operations or potentially calls to external libraries.
- ctx.save_for_backward(*tensors): This is the essential method for saving tensors that are required for the gradient calculation. Only save what's necessary to avoid unnecessary memory consumption. PyTorch handles the bookkeeping to ensure these tensors are available in the backward pass. You can also save non-tensor attributes directly onto ctx (e.g., ctx.some_flag = True), which can be retrieved later in backward.
The counterpart to forward is the static backward method. This method defines how to compute the gradients of the loss function with respect to the inputs of the forward method, given the gradients of the loss with respect to the outputs of the forward method.
import torch

class MyLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor, weight, bias=None):
        output = input_tensor.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        # Save input_tensor and weight. Bias is also saved if provided.
        saved_tensors = [input_tensor, weight]
        if bias is not None:
            saved_tensors.append(bias)
        ctx.save_for_backward(*saved_tensors)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is the gradient of the loss w.r.t. the forward output
        # We need to compute gradients w.r.t. forward's inputs:
        # input_tensor, weight, bias
        # Retrieve saved tensors
        saved_tensors = ctx.saved_tensors
        input_tensor = saved_tensors[0]
        weight = saved_tensors[1]
        bias = saved_tensors[2] if len(saved_tensors) > 2 else None
        # Calculate gradients using the chain rule for output = input @ weight^T + bias
        # dL/d(input) = dL/d(output) @ weight
        grad_input = grad_output.mm(weight)
        # dL/d(weight) = dL/d(output)^T @ input
        grad_weight = grad_output.t().mm(input_tensor)
        # dL/d(bias) = dL/d(output) summed over the batch, since d(output)/d(bias) = 1
        grad_bias = None
        if bias is not None:
            # Sum gradients across the batch dimension
            grad_bias = grad_output.sum(0)
        # Return one value per forward input (ctx excluded), in the same order.
        # Use None for non-Tensor inputs or inputs that don't need a gradient
        # (here, grad_bias stays None when no bias was passed).
        # The number of return values MUST match the number of forward inputs.
        return grad_input, grad_weight, grad_bias
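As a sanity check on the three closed-form expressions used in backward, the following sketch compares them against what autograd derives for the same computation written with built-in operations. It does not use MyLinearFunction itself; the shapes and the random grad_output are arbitrary choices made for illustration.

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 10, dtype=torch.double, requires_grad=True)
w = torch.randn(5, 10, dtype=torch.double, requires_grad=True)
b = torch.randn(5, dtype=torch.double, requires_grad=True)

# Reference: let autograd differentiate the built-in operations.
y = x @ w.t() + b
grad_output = torch.randn_like(y)  # stand-in for dL/d(output)
ref_x, ref_w, ref_b = torch.autograd.grad(y, (x, w, b), grad_outputs=grad_output)

# Closed-form expressions, matching the backward method above.
with torch.no_grad():
    manual_x = grad_output @ w      # (3, 5) @ (5, 10) -> (3, 10)
    manual_w = grad_output.t() @ x  # (5, 3) @ (3, 10) -> (5, 10)
    manual_b = grad_output.sum(0)   # (3, 5) -> (5,)

print(torch.allclose(ref_x, manual_x))
print(torch.allclose(ref_w, manual_w))
print(torch.allclose(ref_b, manual_b))
```

Each gradient has the same shape as the input it belongs to, which is a quick first test for any hand-written backward.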
Key aspects of the backward method:
- Static method: Like forward, it must be a @staticmethod.
- ctx argument: The first argument is again the context object ctx, used to retrieve saved information.
- grad_output arguments: Following ctx, it receives arguments representing the gradient of the final loss with respect to each output of the forward method. If forward returned a single tensor, backward receives a single grad_output tensor. If forward returned multiple tensors, backward receives multiple gradient tensors, one for each output, in the corresponding order. These gradients (∂L/∂output) are provided by the autograd engine during the backward pass.
- ctx.saved_tensors: You retrieve the tensors saved in forward using the saved_tensors attribute of ctx. They are returned as a tuple in the same order they were saved. Any non-tensor attributes saved directly onto ctx can also be accessed (e.g., ctx.some_flag).
- Gradient calculation: You use grad_output (∂L/∂output) and the tensors retrieved from ctx (or the original inputs, if saved) to compute ∂output/∂input. By the chain rule, ∂L/∂input = ∂L/∂output · ∂output/∂input.
- Return values: The backward method must return a gradient for every input argument of the forward method, in the exact same order.
  - If an input requires gradients (requires_grad=True), return the computed gradient tensor.
  - If an input does not require gradients (requires_grad=False), you can return None. PyTorch often optimizes by not saving tensors needed only to compute gradients for inputs that don't require them.
  - If an input is not a tensor, return None.
  - The number of values returned by backward must exactly match the number of arguments accepted by forward (excluding ctx).
The diagram illustrates the flow: inputs go into forward, which computes the output and saves necessary tensors via ctx. Later, the gradient w.r.t. the output (grad_output) flows into backward, which retrieves the saved tensors from ctx and computes the gradients w.r.t. the original inputs.
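The multiple-output case can be illustrated with a small sketch. SinCos is a hypothetical function, not part of PyTorch, that returns two tensors, so its backward receives two incoming gradients and combines them into the single gradient for x.

```python
import torch


class SinCos(torch.autograd.Function):
    """Hypothetical op returning two outputs: sin(x) and cos(x)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sin(x), torch.cos(x)

    @staticmethod
    def backward(ctx, grad_sin, grad_cos):
        # One incoming gradient per forward output, in the same order.
        (x,) = ctx.saved_tensors
        # d sin(x)/dx = cos(x) and d cos(x)/dx = -sin(x); contributions add up.
        return grad_sin * torch.cos(x) - grad_cos * torch.sin(x)


x = torch.tensor([0.0, 1.0], dtype=torch.double, requires_grad=True)
s, c = SinCos.apply(x)
loss = s.sum() + 2 * c.sum()
loss.backward()

expected = torch.cos(x.detach()) - 2 * torch.sin(x.detach())
print(torch.allclose(x.grad, expected))
```

Because forward takes one input, backward returns exactly one value, regardless of how many outputs there are.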
You don't call the forward or backward methods directly. Instead, you use the apply class method. This method takes the same arguments as your forward function (excluding ctx), executes the forward pass, and sets up the necessary bookkeeping so that autograd knows to call your backward method when needed.
# Example Usage
input_features = 10
output_features = 5
batch_size = 3
# Create tensors that require gradients
x = torch.randn(batch_size, input_features, requires_grad=True)
w = torch.randn(output_features, input_features, requires_grad=True) # Note: shape for mm(weight.t())
b = torch.randn(output_features, requires_grad=True)
# Apply the custom function
# Use MyLinearFunction.apply, NOT MyLinearFunction.forward directly
y = MyLinearFunction.apply(x, w, b)
# Example: Calculate a dummy loss and backpropagate
loss = y.mean()
loss.backward()
# Check gradients (optional)
print("Gradient for x:", x.grad is not None)
print("Gradient for w:", w.grad is not None)
print("Gradient for b:", b.grad is not None)
Calling MyLinearFunction.apply(x, w, b) performs the forward computation defined in MyLinearFunction.forward and registers the operation in the computational graph. When loss.backward() is called later, the autograd engine encounters this custom operation and invokes MyLinearFunction.backward with the appropriate grad_output.
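A thin wrapper function is a convenient way to give a custom operation default and keyword arguments while still routing every call through apply. The sketch below uses a hypothetical MulConstant function and mul_constant wrapper, with an arbitrary default of 2.0. It also shows a non-tensor input stored on ctx and the matching None in backward.

```python
import torch


class MulConstant(torch.autograd.Function):
    """Hypothetical op: multiplies a tensor by a plain Python number."""

    @staticmethod
    def forward(ctx, tensor, constant):
        ctx.constant = constant  # non-tensor value, stored as an attribute
        return tensor * constant

    @staticmethod
    def backward(ctx, grad_output):
        # Two forward inputs, so two return values:
        # a gradient for `tensor`, and None for the non-tensor `constant`.
        return grad_output * ctx.constant, None


def mul_constant(tensor, constant=2.0):
    """Wrapper that adds a default and passes arguments to apply positionally."""
    return MulConstant.apply(tensor, constant)


x = torch.ones(3, requires_grad=True)
y = mul_constant(x, constant=5.0)
y.sum().backward()
print(x.grad)  # tensor([5., 5., 5.])
```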
Verifying gradients with gradcheck
Implementing the backward pass correctly is critical and prone to errors. PyTorch provides a utility function, torch.autograd.gradcheck, to help verify your implementation. gradcheck numerically computes the gradients by slightly perturbing each input (finite differences) and compares these numerical gradients to the analytical gradients computed by your backward function.
from torch.autograd import gradcheck
# Create inputs for gradcheck. Often requires double precision for stability.
x_check = torch.randn(batch_size, input_features, dtype=torch.double, requires_grad=True)
w_check = torch.randn(output_features, input_features, dtype=torch.double, requires_grad=True)
b_check = torch.randn(output_features, dtype=torch.double, requires_grad=True)
# Define the function to test (using apply)
test_func = MyLinearFunction.apply
# Perform the check
# inputs is a tuple containing the arguments to the function
inputs = (x_check, w_check, b_check)
is_correct = gradcheck(test_func, inputs, eps=1e-6, atol=1e-4)
print("Gradient check passed:", is_correct)
# Example with bias=None (optional argument handling)
x_check_no_bias = torch.randn(batch_size, input_features, dtype=torch.double, requires_grad=True)
w_check_no_bias = torch.randn(output_features, input_features, dtype=torch.double, requires_grad=True)
# Need a small wrapper if function signature changes based on inputs
def test_func_no_bias(x, w):
    return MyLinearFunction.apply(x, w, None)
inputs_no_bias = (x_check_no_bias, w_check_no_bias)
is_correct_no_bias = gradcheck(test_func_no_bias, inputs_no_bias, eps=1e-6, atol=1e-4)
print("Gradient check (no bias) passed:", is_correct_no_bias)
Using gradcheck is highly recommended whenever you implement a custom autograd.Function. It catches many common errors in gradient formulas. Note that gradcheck typically requires inputs to be torch.double for sufficient numerical precision and may be slow for large inputs. It's usually performed on small, representative test cases.
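To see what a failed check looks like, the sketch below runs gradcheck on two hypothetical functions: GoodSquare with the correct derivative and BadSquare with a deliberately wrong one. By default gradcheck raises an exception on a mismatch; passing raise_exception=False makes it return False instead, which is easier to show here.

```python
import torch
from torch.autograd import gradcheck


class GoodSquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x  # correct: d(x^2)/dx = 2x


class BadSquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * x  # deliberate bug: missing factor of 2


# Fixed, non-zero values so the wrong gradient is clearly outside tolerance.
x = torch.tensor([0.5, -1.0, 2.0, 3.0], dtype=torch.double, requires_grad=True)

good_ok = gradcheck(GoodSquare.apply, (x,), eps=1e-6, atol=1e-4)
bad_ok = gradcheck(BadSquare.apply, (x,), eps=1e-6, atol=1e-4, raise_exception=False)
print("GoodSquare passes:", good_ok)
print("BadSquare passes:", bad_ok)
```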
- Use .apply(): Always invoke your custom function using YourFunction.apply(...). Calling forward directly will bypass the autograd mechanism.
- Memory usage: ctx.save_for_backward stores tensors, consuming memory until the backward pass is complete. Only save what is strictly necessary for the gradient calculation. If intermediate values are cheap to recompute, you might do so in backward instead of saving them.
- In-place operations: Modifying tensors saved with ctx.save_for_backward in place within the backward method is generally unsafe. It's often safer to work with copies or allocate new tensors for results.
- Higher-order gradients: If you need gradients of gradients, the operations used within your backward method must themselves be differentiable. PyTorch's autograd engine can handle this automatically if you use standard differentiable PyTorch operations within backward. Creating custom functions that correctly support higher-order gradients requires careful implementation.
Mastering torch.autograd.Function provides fine-grained control over differentiation, enabling the implementation of complex models and optimization strategies beyond the standard library's offerings. It is a fundamental tool for advanced PyTorch development and research.
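The point about higher-order gradients can be made concrete with a minimal sketch. Cube is a hypothetical function whose backward is written only with ordinary differentiable tensor operations, so autograd can differentiate through it a second time. torch.autograd.gradgradcheck is the second-order counterpart of gradcheck.

```python
import torch
from torch.autograd import gradcheck, gradgradcheck


class Cube(torch.autograd.Function):
    """Hypothetical op: x ** 3 with a hand-written, differentiable backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Ordinary differentiable ops only, so this can be differentiated again.
        return grad_output * 3 * x ** 2


x = torch.tensor([1.0, 2.0], dtype=torch.double, requires_grad=True)
y = Cube.apply(x).sum()

# create_graph=True records the backward computation so it can be differentiated.
(first,) = torch.autograd.grad(y, x, create_graph=True)  # 3 * x**2
(second,) = torch.autograd.grad(first.sum(), x)          # 6 * x

print(first.detach())  # tensor([ 3., 12.], dtype=torch.float64)
print(second)          # tensor([ 6., 12.], dtype=torch.float64)

first_order_ok = gradcheck(Cube.apply, (x,))
second_order_ok = gradgradcheck(Cube.apply, (x,))
print(first_order_ok, second_order_ok)
```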
© 2025 ApX Machine Learning