While PyTorch's autograd engine automatically handles differentiation for a wide range of built-in operations, situations arise where you need more control or need to define the gradient for an operation unknown to PyTorch.
For these scenarios, PyTorch provides a mechanism to define your own differentiable operations by subclassing torch.autograd.Function. This class allows you to specify precisely how the forward computation is performed and how gradients should be calculated during the backward pass.
It's important to distinguish torch.autograd.Function from torch.nn.Module. While nn.Module typically represents layers in a neural network containing parameters (torch.nn.Parameter) and can be composed of other modules or functions, autograd.Function defines a single, specific computational operation and its gradient. It does not hold parameters itself.
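To make the division of labor concrete, here is a minimal sketch in which a module owns a parameter and delegates the computation to a function. The names `ScaleFunction` and `Scale` are hypothetical and chosen only for illustration; the `forward`, `backward`, and `apply` pieces used here are explained in the rest of this section.

```python
import torch
from torch import nn


class ScaleFunction(torch.autograd.Function):
    """Hypothetical op: multiplies its input by a scalar tensor. Holds no state."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.save_for_backward(x, scale)
        return x * scale

    @staticmethod
    def backward(ctx, grad_output):
        x, scale = ctx.saved_tensors
        grad_x = grad_output * scale
        # scale is a 0-dim tensor, so its gradient is summed over all elements
        grad_scale = (grad_output * x).sum()
        return grad_x, grad_scale


class Scale(nn.Module):
    """Hypothetical layer: owns the parameter and calls the function."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(2.0))

    def forward(self, x):
        return ScaleFunction.apply(x, self.scale)


layer = Scale()
x = torch.ones(3, requires_grad=True)
layer(x).sum().backward()
print(x.grad, layer.scale.grad)
```

The parameter lives on the module, where optimizers and `state_dict` can find it, while the function only describes the computation and its gradient.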
To create a custom operation, you define a class that inherits from torch.autograd.Function. The core of the forward computation lies in implementing a static method called forward.
import torch

class MyLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor, weight, bias=None):
        # ctx is a context object to save information for backward pass
        # input_tensor, weight, bias are the inputs to the function
        # Perform the operation
        output = input_tensor.mm(weight.t())  # Matrix multiplication
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        # Save tensors needed for backward pass
        # We need input_tensor and weight to compute gradients
        ctx.save_for_backward(input_tensor, weight, bias)
        return output
Important aspects of the forward method:
- Static Method: forward must be decorated with @staticmethod. It doesn't operate on an instance of the class but defines the operation itself.
- ctx Argument: The first argument is always ctx, a context object. Its primary role is to act as a bridge between the forward and backward passes. You use ctx to store any tensors or information computed during forward that will be needed later to calculate gradients in backward.
- Input Arguments: Following ctx, you list the input arguments your function accepts. These can be tensors or other Python objects.
- Computation: Inside forward, you implement the logic of your operation using standard PyTorch tensor operations or potentially calls to external libraries.
- ctx.save_for_backward(*tensors): This is the essential method for saving tensors that are required for the gradient calculation. Only save what's necessary to avoid unnecessary memory consumption. PyTorch handles the bookkeeping to ensure these tensors are available in the backward pass. You can also save non-tensor attributes directly onto ctx (e.g., ctx.some_flag = True), which can be retrieved later in backward.

The counterpart to forward is the static backward method. This method defines how to compute the gradients of the loss function with respect to the inputs of the forward method, given the gradients of the loss with respect to the outputs of the forward method.
import torch

class MyLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor, weight, bias=None):
        output = input_tensor.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        # Save input_tensor and weight. Bias is also saved if provided.
        saved_tensors = [input_tensor, weight]
        if bias is not None:
            saved_tensors.append(bias)
        ctx.save_for_backward(*saved_tensors)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # grad_output is the gradient of the loss w.r.t. the forward output
        # We need to compute gradients w.r.t. forward's inputs:
        # input_tensor, weight, bias
        # Retrieve saved tensors
        saved_tensors = ctx.saved_tensors
        input_tensor = saved_tensors[0]
        weight = saved_tensors[1]
        bias = saved_tensors[2] if len(saved_tensors) > 2 else None

        # Calculate gradients using the chain rule.
        # Forward computes output = input @ weight^T (+ bias), so:
        # dL/d(input) = dL/d(output) @ weight
        grad_input = grad_output.mm(weight)
        # dL/d(weight) = dL/d(output)^T @ input
        grad_weight = grad_output.t().mm(input_tensor)
        # dL/d(bias): bias is broadcast over the batch, so d(output)/d(bias) = 1
        grad_bias = None
        if bias is not None:
            # Sum gradients across the batch dimension
            grad_bias = grad_output.sum(0)
        # Return one value per input argument of forward (ctx excluded), in the same order.
        # Return None for non-Tensor inputs or inputs that don't require gradients.
        return grad_input, grad_weight, grad_bias
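The three expressions in backward can be read off from the forward pass written in matrix form. As a sketch, with symbols introduced here only for illustration: $X$ is input_tensor with shape $N \times D_{in}$, $W$ is weight with shape $D_{out} \times D_{in}$, $b$ is bias, $Y$ is the output, and $G = \partial L / \partial Y$ is grad_output with shape $N \times D_{out}$.

```latex
\begin{aligned}
Y &= X W^{\top} + \mathbf{1}\, b^{\top} \\[4pt]
\frac{\partial L}{\partial X} &= G\, W
  && \text{(shape } N \times D_{in}\text{, matches } \texttt{grad\_output.mm(weight)}\text{)} \\
\frac{\partial L}{\partial W} &= G^{\top} X
  && \text{(shape } D_{out} \times D_{in}\text{, matches } \texttt{grad\_output.t().mm(input\_tensor)}\text{)} \\
\frac{\partial L}{\partial b} &= G^{\top} \mathbf{1} = \sum_{n=1}^{N} G_{n,:}
  && \text{(shape } D_{out}\text{, matches } \texttt{grad\_output.sum(0)}\text{)}
\end{aligned}
```

A quick sanity check is that each gradient has the same shape as the input it belongs to.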
Important aspects of the backward method:
- Static Method: Like forward, it must be a @staticmethod.
- ctx Argument: The first argument is again the context object ctx, used to retrieve saved information.
- grad_output Arguments: Following ctx, it receives arguments representing the gradient of the final loss with respect to each output of the forward method. If forward returned a single tensor, backward receives a single grad_output tensor. If forward returned multiple tensors, backward receives multiple gradient tensors, one for each output, in the corresponding order. These gradients ($\partial L / \partial \text{output}$) are provided by the autograd engine during the backward pass.
- ctx.saved_tensors: You retrieve the tensors saved in forward using the saved_tensors attribute of ctx. They are returned as a tuple in the same order they were saved. Any non-tensor attributes saved directly onto ctx can also be accessed (e.g., ctx.some_flag).
- Gradient Computation: You use grad_output ($\partial L / \partial \text{output}$) and the tensors retrieved from ctx (or the original inputs, if saved) to compute $\partial \text{output} / \partial \text{input}$ and, through the chain rule, the gradient of the loss with respect to each input.
- Return Values: The backward method must return a gradient for every input argument of the forward method, in the exact same order.
  - If an input requires gradients (requires_grad=True), return the computed gradient tensor.
  - If an input does not require gradients (requires_grad=False), you can return None. PyTorch often optimizes by not saving tensors needed only to compute gradients for inputs that don't require them.
  - For non-tensor inputs, return None.
  - The number of values returned by backward must exactly match the number of arguments accepted by forward (excluding ctx).

The diagram illustrates the flow: inputs go into forward, which computes the output and saves necessary tensors via ctx. Later, the gradient w.r.t. the output (grad_output) flows into backward, which retrieves the saved tensors from ctx and computes the gradients w.r.t. the original inputs.
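Two of the points above, storing a non-tensor value on ctx and returning None for a non-tensor input, are easiest to see in a small case. The following is a minimal sketch; `MulConstant` and its `factor` argument are hypothetical names used only for illustration.

```python
import torch


class MulConstant(torch.autograd.Function):
    """Hypothetical op: multiplies a tensor by a plain Python number."""

    @staticmethod
    def forward(ctx, x, factor):
        # factor is not a tensor, so it is stored as a ctx attribute
        # rather than passed to save_for_backward
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input (ctx excluded):
        # a gradient for x, and None for the non-tensor factor
        return grad_output * ctx.factor, None


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = MulConstant.apply(x, 4.0)
y.sum().backward()
print(x.grad)  # tensor([4., 4., 4.])
```

Because forward takes two arguments after ctx, backward returns exactly two values, even though only one of them can carry a gradient.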
You don't call the forward or backward methods directly. Instead, you use the apply class method. This method takes the same arguments as your forward function (excluding ctx), executes the forward pass, and sets up the necessary bookkeeping so that autograd knows to call your backward method when needed.
# Example Usage
input_features = 10
output_features = 5
batch_size = 3
# Create tensors that require gradients
x = torch.randn(batch_size, input_features, requires_grad=True)
w = torch.randn(output_features, input_features, requires_grad=True) # Note: shape for mm(weight.t())
b = torch.randn(output_features, requires_grad=True)
# Apply the custom function
# Use MyLinearFunction.apply, NOT MyLinearFunction.forward directly
y = MyLinearFunction.apply(x, w, b)
# Example: Calculate a dummy loss and backpropagate
loss = y.mean()
loss.backward()
# Check gradients (optional)
print("Gradient for x:", x.grad is not None)
print("Gradient for w:", w.grad is not None)
print("Gradient for b:", b.grad is not None)
Calling MyLinearFunction.apply(x, w, b) performs the forward computation defined in MyLinearFunction.forward and registers the operation in the computational graph. When loss.backward() is called later, the autograd engine encounters this custom operation and invokes MyLinearFunction.backward with the appropriate grad_output.
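Since this particular operation has a built-in equivalent, one informal way to confirm that autograd really routes through the custom backward and gets sensible values is to compare against torch.nn.functional.linear. The sketch below repeats a condensed copy of MyLinearFunction so that it runs on its own; the condensed version tracks the presence of the bias with a ctx attribute (`has_bias`, a name introduced here) instead of saving the bias tensor, which is an assumption of this example rather than the definition given earlier.

```python
import torch
import torch.nn.functional as F


class MyLinearFunction(torch.autograd.Function):
    # Condensed copy for a standalone example; has_bias is an illustrative attribute
    @staticmethod
    def forward(ctx, input_tensor, weight, bias=None):
        ctx.save_for_backward(input_tensor, weight)
        ctx.has_bias = bias is not None
        output = input_tensor.mm(weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_tensor, weight = ctx.saved_tensors
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(input_tensor)
        grad_bias = grad_output.sum(0) if ctx.has_bias else None
        return grad_input, grad_weight, grad_bias


torch.manual_seed(0)
x = torch.randn(3, 10, requires_grad=True)
w = torch.randn(5, 10, requires_grad=True)
b = torch.randn(5, requires_grad=True)

# torch.autograd.grad returns the gradients directly instead of accumulating into .grad
custom_grads = torch.autograd.grad(MyLinearFunction.apply(x, w, b).mean(), (x, w, b))
reference_grads = torch.autograd.grad(F.linear(x, w, b).mean(), (x, w, b))

for name, g_custom, g_ref in zip("xwb", custom_grads, reference_grads):
    print(name, torch.allclose(g_custom, g_ref, atol=1e-5))
```

This kind of comparison is only available when a reference implementation exists; for operations without one, a numerical check is the usual fallback.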
Implementing the backward pass correctly is critical and prone to errors. PyTorch provides a utility function, torch.autograd.gradcheck, to help verify your implementation. gradcheck numerically computes the gradients by slightly perturbing each input (finite differences) and compares these numerical gradients to the analytical gradients computed by your backward function.
from torch.autograd import gradcheck
# Create inputs for gradcheck. Often requires double precision for stability.
x_check = torch.randn(batch_size, input_features, dtype=torch.double, requires_grad=True)
w_check = torch.randn(output_features, input_features, dtype=torch.double, requires_grad=True)
b_check = torch.randn(output_features, dtype=torch.double, requires_grad=True)
# Define the function to test (using apply)
test_func = MyLinearFunction.apply
# Perform the check
# inputs is a tuple containing the arguments to the function
inputs = (x_check, w_check, b_check)
is_correct = gradcheck(test_func, inputs, eps=1e-6, atol=1e-4)
print("Gradient check passed:", is_correct)
# Example with bias=None (optional argument handling)
x_check_no_bias = torch.randn(batch_size, input_features, dtype=torch.double, requires_grad=True)
w_check_no_bias = torch.randn(output_features, input_features, dtype=torch.double, requires_grad=True)
# Need a small wrapper if function signature changes based on inputs
def test_func_no_bias(x, w):
    return MyLinearFunction.apply(x, w, None)
inputs_no_bias = (x_check_no_bias, w_check_no_bias)
is_correct_no_bias = gradcheck(test_func_no_bias, inputs_no_bias, eps=1e-6, atol=1e-4)
print("Gradient check (no bias) passed:", is_correct_no_bias)
Using gradcheck is highly recommended whenever you implement a custom autograd.Function. It catches many common errors in gradient formulas. Note that gradcheck typically requires inputs to be torch.double for sufficient numerical precision and may be slow for large inputs. It's usually performed on small, representative test cases.
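To illustrate the kind of mistake gradcheck catches, here is a small sketch with a deliberately wrong gradient formula next to a correct one. `BrokenSquare` and `Square` are hypothetical names; passing raise_exception=False makes gradcheck return False instead of raising on a mismatch.

```python
import torch
from torch.autograd import gradcheck


class BrokenSquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Deliberate bug: the derivative of x^2 is 2x, not x
        return grad_output * x


class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x


# Fixed double-precision inputs keep the check deterministic
x = torch.tensor([0.5, -1.0, 2.0, 3.0], dtype=torch.double, requires_grad=True)

broken_ok = gradcheck(BrokenSquare.apply, (x,), eps=1e-6, atol=1e-4, raise_exception=False)
correct_ok = gradcheck(Square.apply, (x,), eps=1e-6, atol=1e-4)
print("broken passes:", broken_ok)    # False
print("correct passes:", correct_ok)  # True
```

With the default raise_exception=True, the broken version would instead raise an error describing the mismatch between the numerical and analytical Jacobians, which is often more useful while debugging.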
- Use .apply(): Always invoke your custom function using YourFunction.apply(...). Calling forward directly will bypass the autograd mechanism.
- Memory Usage: ctx.save_for_backward stores tensors, consuming memory until the backward pass is complete. Only save what is strictly necessary for the gradient calculation. If intermediate values are cheap to recompute, you might do so in backward instead of saving them.
- In-Place Operations: Modifying tensors saved via ctx.save_for_backward in place within the backward method is generally unsafe. It's often safer to work with copies or allocate new tensors for results.
- Higher-Order Gradients: To support gradients of gradients, the operations inside the backward method must themselves be differentiable. PyTorch's autograd engine can handle this automatically if you use standard differentiable PyTorch operations within backward. Creating custom functions that correctly support higher-order gradients requires careful implementation.

Mastering torch.autograd.Function provides fine-grained control over differentiation, enabling the implementation of complex models and optimization strategies outside the standard library's offerings. It is a fundamental tool for advanced PyTorch development and research.
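As a sketch of the higher-order gradient point above: when backward is written only with differentiable tensor operations on the saved inputs, autograd can differentiate through it a second time, and torch.autograd.gradgradcheck can verify the result numerically. `Cube` is a hypothetical function used only for illustration.

```python
import torch
from torch.autograd import gradcheck, gradgradcheck


class Cube(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Built only from differentiable tensor ops, so autograd can
        # differentiate through this expression a second time
        return grad_output * 3 * x ** 2


x = torch.tensor([0.5, -1.0, 2.0], dtype=torch.double, requires_grad=True)
y = Cube.apply(x)

# create_graph=True records the backward computation so it can be differentiated
(first,) = torch.autograd.grad(y.sum(), x, create_graph=True)  # equals 3 * x**2
(second,) = torch.autograd.grad(first.sum(), x)                # equals 6 * x
print(first, second)

# Numerical checks of the first- and second-order gradients
print(gradcheck(Cube.apply, (x,)), gradgradcheck(Cube.apply, (x,)))
```

If backward instead called non-differentiable code (for example, an external library working on raw arrays), the first-order gradient could still be correct while the second-order gradient would be missing or wrong.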
© 2026 ApX Machine Learning