Automatic Differentiation (AD) is the technique that powers jax.grad. It's a set of methods for numerically evaluating the derivative of a function specified by a computer program. Unlike symbolic differentiation (which manipulates mathematical expressions) or numerical differentiation (which uses finite differences and can suffer from approximation errors), AD computes exact derivatives efficiently by systematically applying the chain rule of calculus at the level of elementary operations.
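To make the contrast concrete, here is a minimal sketch (the function, evaluation point, and step size are chosen purely for illustration) comparing a finite-difference estimate with the exact derivative returned by jax.grad:

```python
import jax
import jax.numpy as jnp

def g(x):
    return jnp.sin(x ** 2)

x = 2.0
h = 1e-4  # finite-difference step size, chosen only for illustration

# Numerical differentiation: central finite difference, subject to approximation error
fd_estimate = (g(x + h) - g(x - h)) / (2 * h)

# Automatic differentiation: exact derivative (up to floating-point precision)
ad_value = jax.grad(g)(x)

print(fd_estimate)  # approximately -2.61
print(ad_value)     # cos(4.0) * 2 * 2.0 ≈ -2.6146
```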
There are two primary modes of AD: forward mode and reverse mode. jax.grad primarily uses reverse-mode automatic differentiation, often referred to as backpropagation in the machine learning community. Let's understand how it operates.
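Both modes are available directly in JAX. The sketch below (the one-input function is illustrative) shows forward mode via jax.jvp and reverse mode via jax.vjp, the machinery jax.grad is built on:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x ** 2)

x = 2.0

# Forward mode: push a tangent vector through the computation
y, y_dot = jax.jvp(f, (x,), (1.0,))

# Reverse mode: pull a cotangent back through the computation
y2, f_vjp = jax.vjp(f, x)
(x_bar,) = f_vjp(1.0)

print(y_dot)   # dy/dx computed in forward mode
print(x_bar)   # dy/dx computed in reverse mode; same value
```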
At its core, reverse-mode AD views any computation as a sequence of elementary operations (like addition, multiplication, sin, cos, etc.) applied to some input values. This sequence can be represented as a computational graph, where nodes represent either input variables or operations, and directed edges represent the flow of data.
Consider a simple function: $f(x) = \sin(x^2)$. We can break this down into intermediate steps:

$a = x^2$ (square the input)
$y = \sin(a)$ (take the sine of the intermediate result)
The computational graph looks like this: x → square → a → sin → y.
A simple computational graph for $f(x) = \sin(x^2)$. Data flows from input x through the intermediate value a to the final output y.
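You can inspect this graph yourself: jax.make_jaxpr traces a function and prints the sequence of primitives JAX records for it. A small sketch for this example:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x ** 2)

# make_jaxpr traces f and returns its recorded computation as a jaxpr
print(jax.make_jaxpr(f)(2.0))
# Expect a squaring primitive (integer_pow) followed by sin,
# mirroring the graph x -> square -> a -> sin -> y.
```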
When you execute the function normally, say f(2.0), you perform a forward pass through this graph:

$a = x^2 = 2.0^2 = 4.0$
$y = \sin(a) = \sin(4.0) \approx -0.757$
During this forward pass, AD systems like JAX not only compute the final value y but also typically store the intermediate values (like a=4.0) and the structure of the graph. These are needed for the next phase.
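As a rough sketch of that bookkeeping (not JAX's actual internal mechanism), a hand-written forward pass can simply return the intermediate value alongside the output:

```python
import jax.numpy as jnp

def forward_with_intermediates(x):
    a = x ** 2       # intermediate value, kept for the reverse pass
    y = jnp.sin(a)   # final output
    return y, a

y, a = forward_with_intermediates(2.0)
print(a)  # 4.0
print(y)  # sin(4.0) ≈ -0.757
```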
The goal is to compute the gradient of the final output y with respect to the initial input x, which is dy/dx. Reverse mode achieves this by propagating derivatives backward through the graph, starting from the output.
Initialization: The derivative of the output with respect to itself is trivially 1. We represent this sensitivity, or adjoint, as $\bar{y} = dy/dy = 1$. This is the initial "gradient signal" entering the graph from the end.
Backward Step (sin node): We move backward from y to a. We need to find how a change in a affects y, which is the local derivative dy/da. We then use the chain rule to find the gradient signal arriving at a:

$$\bar{a} = \frac{dy}{da} = \frac{dy}{dy}\,\frac{dy}{da} = \bar{y}\,\frac{dy}{da}$$

This $\bar{a}$ represents the sensitivity of the final output y to changes in the intermediate variable a. Since $y = \sin(a)$, the local derivative is $dy/da = \cos(a)$. So $\bar{a} = 1 \times \cos(a)$. Using the value a = 4.0 stored from the forward pass, $\bar{a} = \cos(4.0) \approx -0.654$.
Backward Step (square node): We continue backward from a to x. We need the local derivative da/dx. We use the chain rule again, multiplying the incoming gradient signal $\bar{a}$ (the sensitivity of y with respect to a) by the local derivative da/dx:

$$\bar{x} = \frac{dy}{dx} = \frac{dy}{da}\,\frac{da}{dx} = \bar{a}\,\frac{da}{dx}$$

Since $a = x^2$, the local derivative is $da/dx = 2x$. So $\bar{x} = \bar{a} \times 2x$. Using $\bar{a} \approx -0.654$ and the input value x = 2.0 (also potentially stored or recomputed), $\bar{x} \approx -0.654 \times (2 \times 2.0) = -0.654 \times 4.0 \approx -2.616$. This $\bar{x}$ is our desired gradient dy/dx.
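The two backward steps translate directly into a few lines of code. This sketch redoes the forward pass, replays the reverse pass by hand, and checks the result against jax.grad:

```python
import jax
import jax.numpy as jnp

x = 2.0

# Forward pass, storing the intermediate value a
a = x ** 2          # a = 4.0
y = jnp.sin(a)      # y = sin(4.0)

# Reverse pass: propagate adjoints backward using the chain rule
y_bar = 1.0                    # dy/dy
a_bar = y_bar * jnp.cos(a)     # local derivative dy/da = cos(a)
x_bar = a_bar * (2 * x)        # local derivative da/dx = 2x

print(x_bar)                                    # manual reverse-mode gradient
print(jax.grad(lambda x: jnp.sin(x ** 2))(x))   # same value from jax.grad
```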
The reverse pass essentially calculates how sensitive the final output is to each intermediate variable and input, working backward from the output and using the chain rule at each step.
Flow of forward pass (computing values) and reverse pass (computing gradients using the chain rule). The reverse pass requires values (x, a) computed during the forward pass and propagates sensitivities backward.
Why is reverse mode the default for jax.grad and prevalent in machine learning? Consider a typical neural network loss function: $L = f(W_1, W_2, \ldots, W_n, x, y)$, where the $W_i$ are many weight matrices/vectors (parameters), x is input data, and y is the target. This function takes potentially millions of inputs (parameters and data) but produces only a single scalar output (the loss L).
Reverse mode's computational cost is relatively insensitive to the number of inputs. It requires one forward pass through the computational graph (to compute intermediate values) and one backward pass (to compute gradients). The total cost is typically a small constant factor (e.g., 2-4x) times the cost of evaluating the original function. This makes it extremely efficient for computing the gradient ∇L with respect to all parameters simultaneously, which is exactly what's needed for gradient-based optimization algorithms like gradient descent.
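A small sketch illustrates this (the parameter shapes, data, and loss function are made up for the example): a single jax.grad call yields gradients for every parameter from one forward and one backward pass.

```python
import jax
import jax.numpy as jnp

# Toy two-layer model; shapes are illustrative
params = {
    "W1": jnp.ones((512, 256)),
    "W2": jnp.ones((256, 1)),
}

def loss(params, x, y):
    hidden = jnp.tanh(x @ params["W1"])
    pred = hidden @ params["W2"]
    return jnp.mean((pred - y) ** 2)   # single scalar output

x = jnp.ones((32, 512))
y = jnp.zeros((32, 1))

# One forward pass + one backward pass gives gradients for all parameters
grads = jax.grad(loss)(params, x, y)
print(grads["W1"].shape, grads["W2"].shape)  # (512, 256) (256, 1)
```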
Forward mode, in contrast, computes derivatives with respect to one input at a time. Its cost scales linearly with the number of inputs you need gradients for. While useful in some contexts, it becomes computationally prohibitive for training large models with millions of parameters.
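The scaling difference is visible even in a tiny example (the three-input function below is illustrative): forward mode needs one jax.jvp evaluation per input direction, while reverse mode recovers the full gradient from a single jax.grad call.

```python
import jax
import jax.numpy as jnp

def f(w):
    return jnp.sum(jnp.sin(w) * w)   # scalar output, three inputs

w = jnp.array([0.5, 1.0, 2.0])

# Forward mode: one JVP per input direction, so cost grows with the number of inputs
basis = jnp.eye(3)
fwd_grad = jnp.stack([jax.jvp(f, (w,), (e,))[1] for e in basis])

# Reverse mode: the full gradient from one forward + one backward pass
rev_grad = jax.grad(f)(w)

print(jnp.allclose(fwd_grad, rev_grad))  # True
```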
When you apply jax.grad to your Python function, JAX performs these steps under the hood:

1. It traces your function into a computational graph of elementary operations.
2. It runs a forward pass through that graph, storing the intermediate values the reverse pass will need (with jax.value_and_grad, this pass also computes the final function output).
3. It runs a reverse pass, applying the chain rule at each node to propagate sensitivities from the output back to the inputs, producing the gradient.

You, as the user, define the forward computation using familiar Python and NumPy-like syntax. JAX takes care of deriving and executing the efficient gradient computation via reverse-mode AD based on that definition. This separation of concerns allows you to focus on the model logic while relying on JAX for performant differentiation.
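In training loops you typically want both the loss value (for logging) and the gradient (for the parameter update); jax.value_and_grad returns both from the same pass, as in this sketch (the model and data are placeholders):

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

w = jnp.zeros(3)
x = jnp.ones((8, 3))
y = jnp.ones(8)

# Returns the forward-pass output and the reverse-mode gradient together
value, grad = jax.value_and_grad(loss)(w, x, y)
print(value)       # current loss
print(grad.shape)  # (3,)
```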