Reverse-mode automatic differentiation is the workhorse behind gradient descent in deep learning. It efficiently computes the gradient of a scalar output (like a loss function) with respect to potentially millions of inputs (model parameters). While jax.grad provides a user-friendly way to get these gradients, understanding the underlying mechanism, the Vector-Jacobian Product (VJP), unlocks more advanced capabilities. jax.vjp is the JAX function that computes VJPs.
Recall that for a function $f: \mathbb{R}^n \to \mathbb{R}^m$, its Jacobian matrix $J_f(x)$ at a point $x \in \mathbb{R}^n$ contains all the partial derivatives $\partial f_i / \partial x_j$. The Jacobian is an $m \times n$ matrix.
Reverse-mode AD computes the product of a cotangent vector $v \in \mathbb{R}^m$ with the Jacobian matrix, yielding $v^T J_f(x)$. This product is the Vector-Jacobian Product (VJP). The result is a vector in $\mathbb{R}^n$ (or, technically, a $1 \times n$ row vector).
Why is this useful? Consider a typical machine learning scenario where we have a composition of functions leading to a scalar loss $L$. Let $f$ be the last function in this chain, mapping from some intermediate representation $y \in \mathbb{R}^m$ to the final scalar loss $L \in \mathbb{R}$, so $L = f(y)$. The gradient $\nabla_y L$ is a row vector $\frac{\partial L}{\partial y} \in \mathbb{R}^{1 \times m}$. If $y = g(x)$ where $x \in \mathbb{R}^n$, the chain rule tells us $\nabla_x L = \nabla_y L \cdot J_g(x)$. This is exactly a VJP! The vector $v^T$ is the gradient (or cotangent) flowing backward from the output, and $J_g(x)$ is the Jacobian of the function $g$ whose input gradients we want.
Reverse mode works by propagating these cotangent vectors backward through the computation graph, starting with the initial cotangent $\frac{dL}{dL} = 1.0$ for a scalar loss function $L$.
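Before looking at the API, it can help to see that a VJP is just an ordinary matrix product once the Jacobian is written out. The following is a minimal sketch with made-up values (not part of the examples later in this section):

import jax.numpy as jnp

# Made-up example: g(x) = (x0 * x1, x0 + x1), evaluated at x = (2, 3)
x = jnp.array([2.0, 3.0])

# Jacobian of g at x, written out by hand:
# J = [[x1, x0],
#      [1,  1 ]]
J = jnp.array([[3.0, 2.0],
               [1.0, 1.0]])

# An arbitrary cotangent vector arriving from the output side
v = jnp.array([1.0, 10.0])

# The VJP is v^T J: one entry per input dimension
print(v @ J)  # [1*3 + 10*1, 1*2 + 10*1] = [13., 12.]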
jax.vjp
JAX provides direct access to VJP computation through jax.vjp. Its basic signature is:
y, vjp_fun = jax.vjp(fun, *primals)
Let's break this down:
- fun: The Python callable (function) to differentiate.
- *primals: The input arguments (x, y, etc.) at which to evaluate the function and compute the VJP. These are the points where the Jacobian is implicitly calculated.
- y: The output of the function evaluated at the primals, i.e., y = fun(*primals). This is the result of the forward pass.
- vjp_fun: A new function (a closure) that computes the VJP. This function takes one argument, the cotangent vector v, which must have the same structure (shape and dtype) as the output y. When called as vjp_fun(v), it returns the VJP $v^T J_{fun}(\text{primals})$. The result is a tuple containing the computed gradients with respect to each primal input.

Let's start with a simple function $f(x, y) = x^2 \sin(y)$.
import jax
import jax.numpy as jnp
def f(x, y):
    return x**2 * jnp.sin(y)
# Define primal inputs
x_primal = 3.0
y_primal = jnp.pi / 2.0
# Compute VJP
y_out, vjp_fun = jax.vjp(f, x_primal, y_primal)
print(f"Forward pass output y: {y_out}")
# The output y is scalar, so the cotangent v is also scalar.
# Let's use v = 1.0, which corresponds to finding d(1.0 * y) / dx and d(1.0 * y) / dy
# This is equivalent to finding the gradient of y.
cotangent_v = 1.0
primal_grads = vjp_fun(cotangent_v)
print(f"Gradient wrt x: {primal_grads[0]}") # Should be 2 * x * sin(y) = 2 * 3.0 * sin(pi/2) = 6.0
print(f"Gradient wrt y: {primal_grads[1]}") # Should be x^2 * cos(y) = 3.0^2 * cos(pi/2) = 0.0
# Compare with jax.grad
grad_f = jax.grad(f, argnums=(0, 1)) # Get gradients wrt both x and y
grads_via_grad = grad_f(x_primal, y_primal)
print(f"Gradients via jax.grad: {grads_via_grad}")
In this case, calling vjp_fun(1.0) gives the same result as jax.grad(f, argnums=(0, 1))(x_primal, y_primal). This highlights the fundamental relationship: jax.grad for scalar functions is essentially a convenience wrapper around jax.vjp that automatically provides the initial cotangent of 1.0.
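To make that relationship concrete, here is a minimal sketch of how a grad-like helper for scalar-output functions could be built on top of jax.vjp. The name my_grad is made up for illustration; this is not how jax.grad is actually implemented internally.

import jax
import jax.numpy as jnp

def my_grad(fun, argnums=(0,)):
    # Illustrative stand-in for jax.grad, for scalar-output functions only
    def grad_fun(*primals):
        out, vjp_fun = jax.vjp(fun, *primals)
        # Seed the backward pass with the initial cotangent 1.0
        cotangents = vjp_fun(1.0)
        return tuple(cotangents[i] for i in argnums)
    return grad_fun

def f(x, y):
    return x**2 * jnp.sin(y)

print(my_grad(f, argnums=(0, 1))(3.0, jnp.pi / 2.0))  # (6.0, ~0.0)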
Now consider a function with a vector output: $g(x) = (x_0^2, x_1^3)$, where $x = (x_0, x_1)$.
import jax
import jax.numpy as jnp
def g(x):
    return jnp.array([x[0]**2, x[1]**3])
# Define primal input
x_primal = jnp.array([2.0, 3.0])
# Compute VJP
y_out, vjp_fun = jax.vjp(g, x_primal)
print(f"Forward pass output y: {y_out}") # Should be [4., 27.]
# The output y is a vector of shape (2,), so the cotangent v must also have shape (2,)
cotangent_v = jnp.array([10.0, 20.0]) # An arbitrary cotangent vector
# Compute the VJP: v^T J
primal_grad = vjp_fun(cotangent_v)
print(f"VJP result (gradient wrt x): {primal_grad}")
# Expected:
# Jacobian J = [[2*x0, 0], [0, 3*x1^2]] = [[4, 0], [0, 27]]
# v^T J = [10, 20] * [[4, 0], [0, 27]] = [10*4, 20*27] = [40, 540]
# jax.vjp returns a tuple, even for a single primal, so result is ([40., 540.],)
Here, vjp_fun takes the vector cotangent_v and multiplies it by the Jacobian of g evaluated at x_primal to produce the resulting gradient contributions back to x. The shape of the cotangent must match the shape of the function's output y_out.
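One useful consequence, shown as a short sketch below (assuming the same g and x_primal as above): pulling back the standard basis vectors one at a time recovers the rows of the Jacobian, which is essentially what reverse-mode Jacobian construction (jax.jacrev) does.

import jax
import jax.numpy as jnp

def g(x):
    return jnp.array([x[0]**2, x[1]**3])

x_primal = jnp.array([2.0, 3.0])
y_out, vjp_fun = jax.vjp(g, x_primal)

# Each basis cotangent e_i picks out row i of the Jacobian: e_i^T J = J[i, :]
rows = [vjp_fun(e)[0] for e in jnp.eye(len(y_out))]
print(jnp.stack(rows))          # [[ 4.  0.]
                                #  [ 0. 27.]]
print(jax.jacrev(g)(x_primal))  # same matrix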
Why Use jax.vjp Directly?

While jax.grad suffices for standard gradient descent where you need the gradient of a scalar loss, jax.vjp offers more control:
- You can pull back arbitrary cotangent vectors after a single forward pass, rather than relying on the implicit 1.0 seed behind scalar grad calls.
- jax.custom_vjp allows you to define bespoke VJP rules for specific functions, and understanding jax.vjp is necessary for this (a brief sketch follows below).
- Working directly with jax.vjp clarifies how reverse-mode AD operates and how jax.grad is constructed.
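As a quick illustration of the second point, here is a minimal sketch of a custom VJP rule. The function my_log1p and its rule are made up for illustration and simply reproduce the standard derivative; the point is that writing the backward rule means writing a VJP.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def my_log1p(x):
    return jnp.log1p(x)

def my_log1p_fwd(x):
    # Forward pass: return the primal output plus residuals saved for the backward pass
    return jnp.log1p(x), x

def my_log1p_bwd(residual_x, cotangent):
    # Backward pass: the VJP, cotangent * d(log1p(x))/dx = cotangent / (1 + x)
    return (cotangent / (1.0 + residual_x),)

my_log1p.defvjp(my_log1p_fwd, my_log1p_bwd)

print(jax.grad(my_log1p)(1.0))  # 1 / (1 + 1) = 0.5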
In summary, jax.vjp provides a lower-level interface to JAX's powerful reverse-mode automatic differentiation system. It computes Vector-Jacobian Products, which represent the core operation of backpropagation. While jax.grad handles the common case of scalar loss gradients, mastering jax.vjp gives you the tools to tackle more complex differentiation tasks and understand the system more deeply.