Automatic Differentiation (AD) is the technique that powers jax.grad. It's a set of methods for numerically evaluating the derivative of a function specified by a computer program. Unlike symbolic differentiation (which manipulates mathematical expressions) or numerical differentiation (which uses finite differences and can suffer from approximation errors), AD computes exact derivatives efficiently by systematically applying the chain rule of calculus at the level of elementary operations.
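As a minimal sketch, here is jax.grad computing an exact derivative for the example function used throughout this section, $f(x) = \sin(x^2)$, alongside a central finite-difference estimate of the same derivative:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x ** 2)

x = 2.0
exact = jax.grad(f)(x)                           # reverse-mode AD: exact (up to float precision)
eps = 1e-4
approx = (f(x + eps) - f(x - eps)) / (2 * eps)   # finite difference: an approximation

print(exact)   # ~ -2.6146, i.e. 2 * x * cos(x**2)
print(approx)  # close, but its accuracy depends on the choice of eps
```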
There are two primary modes of AD: forward mode and reverse mode. jax.grad primarily utilizes reverse-mode automatic differentiation, often referred to as backpropagation in the machine learning community. Let's understand how it operates.
At its core, reverse-mode AD views any computation as a sequence of elementary operations (like addition, multiplication, sin, cos, etc.) applied to some input values. This sequence can be represented as a computational graph, where nodes represent either input variables or operations, and directed edges represent the flow of data.
Consider a simple function: $f(x) = \sin(x^2)$. We can break this down into intermediate steps: first compute $a = x^2$, then compute $y = \sin(a)$.
The computational graph looks like this: $x \rightarrow \text{square} \rightarrow a \rightarrow \sin \rightarrow y$.
A simple computational graph for $f(x) = \sin(x^2)$. Data flows from the input $x$ through the intermediate value $a = x^2$ to the final output $y = \sin(a)$.
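You can inspect the graph JAX records for this function. As a small sketch, jax.make_jaxpr prints the sequence of elementary operations (a jaxpr) extracted from the Python code; the exact printout may vary slightly across JAX versions:

```python
import jax
import jax.numpy as jnp

def f(x):
    a = x ** 2          # square node: intermediate value a
    return jnp.sin(a)   # sin node: output y

# Show the elementary operations JAX records for f.
print(jax.make_jaxpr(f)(2.0))
# Output resembles:
# { lambda ; a:f32[]. let b:f32[] = integer_pow[y=2] a; c:f32[] = sin b in (c,) }
```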
When you execute the function normally, say f(2.0), you perform a forward pass through this graph: first $a = x^2 = 2.0^2 = 4.0$, then $y = \sin(a) = \sin(4.0) \approx -0.7568$.
During this forward pass, AD systems like JAX not only compute the final value but also typically store the intermediate values (like $a = 4.0$) and the structure of the graph. These are needed for the next phase.
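Written out by hand, the forward pass for f(2.0) is just these two steps; the tuple at the end stands in for the residuals an AD system keeps for the reverse pass:

```python
import jax.numpy as jnp

x = 2.0
a = x ** 2         # square node: a = 4.0
y = jnp.sin(a)     # sin node:    y = sin(4.0) ≈ -0.7568

# Values the reverse pass will need later (the "residuals").
residuals = (x, a)
```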
The goal is to compute the gradient of the final output $y$ with respect to the initial input $x$, which is $\frac{dy}{dx}$. Reverse mode achieves this by propagating derivatives backward through the graph, starting from the output.
Initialization: The derivative of the output with respect to itself is trivially 1. We represent this sensitivity, or adjoint, as $\bar{y} = \frac{dy}{dy} = 1$. This is the initial "gradient signal" entering the graph from the end.
Backward Step (sin node): We move backward from $y$ to $a$. We need to find how a change in $a$ affects $y$, which is the local derivative $\frac{\partial y}{\partial a}$. We then use the chain rule to find the gradient signal arriving at $a$: $\bar{a} = \bar{y} \cdot \frac{\partial y}{\partial a}$.
This represents the sensitivity of the final output $y$ to changes in the intermediate variable $a$. Since $y = \sin(a)$, the local derivative is $\frac{\partial y}{\partial a} = \cos(a)$. So, $\bar{a} = 1 \cdot \cos(a)$. Using the value $a = 4.0$ stored from the forward pass, $\bar{a} = \cos(4.0) \approx -0.6536$.
Backward Step (square node): We continue backward from $a$ to $x$. We need the local derivative $\frac{\partial a}{\partial x}$. We use the chain rule again, multiplying the incoming gradient signal $\bar{a}$ (the sensitivity of $y$ w.r.t. $a$) by the local derivative: $\bar{x} = \bar{a} \cdot \frac{\partial a}{\partial x}$.
Since $a = x^2$, the local derivative is $\frac{\partial a}{\partial x} = 2x$. So, $\bar{x} = \bar{a} \cdot 2x$. Using $\bar{a} = \cos(4.0)$ and the input value $x = 2.0$ (also potentially stored or recomputed), $\bar{x} = \cos(4.0) \cdot 2 \cdot 2.0 \approx -2.6146$. This is our desired gradient $\frac{dy}{dx} = 2x\cos(x^2)$ evaluated at $x = 2.0$.
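Putting the walkthrough together, here is the same forward and reverse pass written by hand and checked against jax.grad:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x ** 2)

x = 2.0
a = x ** 2                   # forward pass, storing the intermediate a

y_bar = 1.0                  # initialization: dy/dy = 1
a_bar = y_bar * jnp.cos(a)   # backward step through the sin node
x_bar = a_bar * 2 * x        # backward step through the square node

print(x_bar)              # ≈ -2.6146
print(jax.grad(f)(x))     # same value, computed automatically
```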
The reverse pass essentially calculates how sensitive the final output is to each intermediate variable and input, working backward from the output and using the chain rule at each step.
Flow of forward pass (computing values) and reverse pass (computing gradients using the chain rule). The reverse pass requires values (x, a) computed during the forward pass and propagates sensitivities backward.
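JAX exposes this two-phase structure directly through jax.vjp: it runs the forward pass and returns both the output and a function that performs the reverse pass when given an output sensitivity. A brief sketch:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x ** 2)

# Forward pass: computes y and records what the reverse pass needs.
y, f_vjp = jax.vjp(f, 2.0)

# Reverse pass: feed in the output sensitivity (1.0) to get dy/dx.
(x_bar,) = f_vjp(1.0)
print(y, x_bar)   # ≈ -0.7568, ≈ -2.6146
```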
Why is reverse mode the default for jax.grad and prevalent in machine learning? Consider a typical neural network loss function: $L = \text{loss}(\text{model}(W_1, \dots, W_n, x), y_{\text{target}})$, where $W_1, \dots, W_n$ are many weight matrices/vectors (parameters), $x$ is input data, and $y_{\text{target}}$ is the target. This function takes potentially millions of inputs (parameters and data) but produces only a single scalar output (the loss $L$).
Reverse mode's computational cost is relatively insensitive to the number of inputs. It requires one forward pass through the computational graph (to compute intermediate values) and one backward pass (to compute gradients). The total cost is typically a small constant factor (e.g., 2-4x) times the cost of evaluating the original function. This makes it extremely efficient for computing the gradient with respect to all parameters simultaneously, which is exactly what's needed for gradient-based optimization algorithms like gradient descent.
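As an illustration with a made-up linear model (the parameter names, shapes, and loss are arbitrary), a single jax.grad call returns gradients for every parameter from one forward and one backward pass:

```python
import jax
import jax.numpy as jnp

# A toy model whose parameters live in a dict (a pytree).
params = {
    "W": jnp.ones((1000, 100)),
    "b": jnp.zeros((100,)),
}
x = jnp.ones((1000,))
y_target = jnp.zeros((100,))

def loss(params, x, y_target):
    pred = x @ params["W"] + params["b"]
    return jnp.mean((pred - y_target) ** 2)   # single scalar output

# Gradients for all parameters in one reverse pass.
grads = jax.grad(loss)(params, x, y_target)
print(grads["W"].shape, grads["b"].shape)   # (1000, 100) (100,)
```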
Forward mode, in contrast, computes derivatives with respect to one input at a time. Its cost scales linearly with the number of inputs you need gradients for. While useful in some contexts, it becomes computationally prohibitive for training large models with millions of parameters.
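A small sketch of the difference, using a toy scalar-valued function: forward mode (jax.jvp) yields one directional derivative per pass, so recovering a full gradient requires one pass per input dimension, while reverse mode gets the whole gradient at once:

```python
import jax
import jax.numpy as jnp

def loss(w):
    return jnp.sum(jnp.sin(w) ** 2)   # many inputs -> one scalar

w = jnp.arange(5.0)

# Forward mode: one pass per input direction to assemble the gradient.
basis = jnp.eye(5)
grad_fwd = jnp.stack([jax.jvp(loss, (w,), (e,))[1] for e in basis])

# Reverse mode: the whole gradient in a single backward pass.
grad_rev = jax.grad(loss)(w)

print(jnp.allclose(grad_fwd, grad_rev))   # True
```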
When you apply jax.grad to your Python function, JAX performs these steps under the hood:
Tracing: JAX traces your Python function, recording the sequence of elementary operations it performs (the computational graph, represented internally as a jaxpr).
Forward pass: JAX evaluates those operations, storing the intermediate values needed for differentiation (if you use jax.value_and_grad, this pass also computes the final function output).
Backward pass: JAX propagates sensitivities from the output back through the recorded operations, applying the chain rule at each elementary operation to produce the gradient with respect to the inputs.
You, as the user, define the forward computation using familiar Python and NumPy-like syntax. JAX takes care of deriving and executing the efficient gradient computation via reverse-mode AD based on that definition. This separation of concerns allows you to focus on the model logic while relying on JAX for performant differentiation.
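In training loops you usually want both the loss value and its gradient; jax.value_and_grad returns both without evaluating the function twice. A minimal sketch using the section's example function:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x ** 2)

# One traced evaluation yields both the output and the gradient.
value, grad = jax.value_and_grad(f)(2.0)
print(value)   # ≈ -0.7568
print(grad)    # ≈ -2.6146
```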