Automatic differentiation (AD) is the mechanism that allows frameworks like JAX to compute derivatives of potentially complex functions accurately and efficiently. While you might be familiar with using jax.grad to get the gradient of a function with respect to its first argument, this convenience is built upon more fundamental operations corresponding to the two primary modes of AD: forward mode and reverse mode. Understanding these modes is important for tackling more advanced differentiation tasks, optimizing performance, and defining custom derivative rules.
This section provides a refresher on these two modes. AD operates by decomposing a function into a sequence of elementary operations (like addition, multiplication, sin, exp) and applying the chain rule repeatedly at the level of these operations. Unlike numerical differentiation, which approximates derivatives using finite differences (and suffers from truncation and round-off errors), or symbolic differentiation, which can lead to exponentially large expressions, AD computes exact derivatives (up to machine precision) with manageable computational cost.
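As a small illustration of that accuracy point, the sketch below compares a central finite-difference estimate with the derivative JAX computes via AD; the function and step size here are arbitrary choices made up for this example:

```python
import jax
import jax.numpy as jnp

def f(x):
    # An arbitrary smooth function chosen for illustration.
    return jnp.sin(x ** 2)

x = 1.5
eps = 1e-4

# Central finite difference: subject to truncation and round-off error.
fd_estimate = (f(x + eps) - f(x - eps)) / (2 * eps)

# Automatic differentiation: exact up to machine precision.
ad_derivative = jax.grad(f)(x)

# Analytical reference: d/dx sin(x^2) = 2x * cos(x^2).
analytical = 2 * x * jnp.cos(x ** 2)

print(fd_estimate, ad_derivative, analytical)
```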
Forward-mode AD computes derivatives alongside the original function's evaluation. Think of it as carrying the "sensitivity" of the intermediate variables with respect to the inputs forward through the computation graph.
Consider a function $f: \mathbb{R}^n \to \mathbb{R}^m$. Forward mode is particularly adept at computing Jacobian-vector products (JVPs). The Jacobian matrix $J$ of $f$ at a point $x$ contains all the partial derivatives $\partial f_i / \partial x_j$. A JVP computes the product of this Jacobian $J$ with a "tangent" vector $v$, where $v$ typically represents a direction in the input space:

$$\text{JVP}(f, x, v) = J \cdot v = \left. \frac{d}{d\alpha} f(x + \alpha v) \right|_{\alpha = 0}$$

This operation tells you how the output of the function $f$ changes when the input $x$ is perturbed infinitesimally in the direction $v$. Forward mode calculates this JVP efficiently without explicitly forming the full Jacobian matrix $J$. It does this by propagating the tangent vector $v$ through the chain rule alongside the function evaluation.
The computational cost of computing one JVP using forward mode is roughly proportional to the cost of computing the original function f(x) itself, regardless of the output dimension m. However, if you need the full Jacobian, you would need to compute n JVPs, one for each standard basis vector in the input space. Therefore, forward mode is computationally efficient when the number of inputs n is significantly smaller than the number of outputs m (a "tall" Jacobian).
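To make this concrete, here is a minimal sketch using jax.jvp. The function g and the evaluation point are invented for illustration; one call returns both the primal output and the directional derivative J·v, and assembling the full Jacobian this way takes one JVP per input dimension:

```python
import jax
import jax.numpy as jnp

def g(x):
    # Illustrative function g: R^2 -> R^3.
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), x[1] ** 2])

x = jnp.array([1.0, 2.0])   # point at which to differentiate
v = jnp.array([1.0, 0.0])   # tangent vector: a direction in the input space

# One forward pass yields both g(x) and the directional derivative J @ v.
y, jvp_out = jax.jvp(g, (x,), (v,))

# The full Jacobian requires n = 2 JVPs, one per standard basis vector of the input.
columns = [jax.jvp(g, (x,), (e,))[1] for e in jnp.eye(2)]
jacobian = jnp.stack(columns, axis=1)   # same result as jax.jacfwd(g)(x)
```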
Reverse-mode AD, often synonymous with backpropagation in the context of neural networks, works differently. It first performs a forward pass to compute the function's output value f(x) and potentially stores intermediate values (activations). Then, it performs a backward pass, starting from the output and propagating derivatives backward through the computation graph using the chain rule.
Reverse mode excels at computing vector-Jacobian products (VJPs). A VJP computes the product of a "cotangent" vector $v^T$ with the Jacobian $J$:

$$\text{VJP}(f, x, v) = v^T \cdot J$$

This operation takes a vector $v$ representing sensitivities or gradients at the output of the function and computes the corresponding sensitivities or gradients with respect to the inputs of the function.
The computational cost of computing one VJP using reverse mode is roughly proportional to the cost of computing the original function f(x), regardless of the input dimension n. This is a significant advantage when dealing with functions that have many inputs and few outputs, which is the common scenario in machine learning model training where the function is the loss computation (many parameters as input, a single scalar loss as output).
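As a counterpart to the JVP sketch above, here is a minimal example with jax.vjp, again using an invented function. The forward pass returns both the output and a function that performs the backward pass for any output-space cotangent:

```python
import jax
import jax.numpy as jnp

def g(x):
    # Illustrative function g: R^2 -> R^3.
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), x[1] ** 2])

x = jnp.array([1.0, 2.0])

# Forward pass: evaluate g(x) and obtain a function that runs the backward pass.
y, vjp_fun = jax.vjp(g, x)

# Backward pass: pull an output-space cotangent back to the input space.
u = jnp.array([1.0, 0.0, 0.0])
(u_J,) = vjp_fun(u)   # equals u^T @ J, i.e. the first row of the Jacobian here
```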
To compute the gradient of a scalar-valued function $L(x)$ (where $m = 1$), we only need one VJP calculation. Since $L(x)$ is scalar, its Jacobian is a row vector (the gradient transpose, $\nabla L(x)^T$). The VJP with the scalar "sensitivity" $v = 1$ is then:

$$\text{VJP}(L, x, 1) = 1 \cdot \nabla L(x)^T = \nabla L(x)^T$$

The result is the transpose of the gradient vector. This is why reverse mode is the foundation for jax.grad. Computing the full Jacobian using reverse mode would require $m$ VJP computations, one for each standard basis vector in the output space.
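The sketch below shows this correspondence directly. The loss function and evaluation point are made up for illustration; the single VJP with a cotangent of one matches what jax.grad returns:

```python
import jax
import jax.numpy as jnp

def loss(w):
    # Scalar-valued function of several inputs, loss: R^3 -> R.
    return jnp.sum(jnp.tanh(w) ** 2)

w = jnp.array([0.1, -0.5, 2.0])

value, vjp_fun = jax.vjp(loss, w)
(grad_via_vjp,) = vjp_fun(jnp.ones_like(value))   # single backward pass with v = 1

grad_via_grad = jax.grad(loss)(w)                 # what jax.grad computes

print(jnp.allclose(grad_via_vjp, grad_via_grad))  # True (up to floating point)
```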
Flow of Forward-Mode (JVP) and Reverse-Mode (VJP) Automatic Differentiation. Forward mode computes function values and derivatives simultaneously in one forward pass. Reverse mode computes function values in a forward pass, then computes derivatives in a separate backward pass.
Having reviewed these foundational concepts, we can now explore how JAX provides explicit control over JVPs and VJPs through jax.jvp and jax.vjp, enabling more advanced differentiation techniques beyond the standard gradient calculation.