While Jacobian-vector products (JVPs) and vector-Jacobian products (VJPs) are computationally efficient for many applications, particularly gradient-based optimization where only the product is needed, sometimes you require the complete Jacobian or Hessian matrix. This might be necessary for certain second-order optimization algorithms, sensitivity analysis, or for understanding the local geometry of a function. JAX provides convenient functions to compute these full matrices, building upon the underlying JVP and VJP mechanisms. However, be mindful that computing and storing these full matrices can be significantly more computationally expensive and memory-intensive than computing products, especially for high-dimensional functions common in machine learning.
The Jacobian matrix J of a vector-valued function f:Rn→Rm contains all the first-order partial derivatives. Its entry Jij is the partial derivative of the i-th output component fi with respect to the j-th input component xj:

$$J_{ij} = \frac{\partial f_i}{\partial x_j}$$

The full Jacobian is therefore an m×n matrix. JAX offers two primary ways to compute it, based on forward-mode and reverse-mode automatic differentiation.
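As a concrete instance, worked by hand for the small R^3 → R^2 function used in the code examples below, the Jacobian and its value at the input point (1, 2, π/2) are:

$$f(x) = \begin{pmatrix} x_0^2 x_1 \\ \sin(x_2) \end{pmatrix}, \qquad J(x) = \begin{pmatrix} 2 x_0 x_1 & x_0^2 & 0 \\ 0 & 0 & \cos(x_2) \end{pmatrix}, \qquad J\!\left(1, 2, \tfrac{\pi}{2}\right) = \begin{pmatrix} 4 & 1 & 0 \\ 0 & 0 & 0 \end{pmatrix}$$

This is the matrix you should see printed by both of the JAX computations that follow.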
jax.jacfwd

The function jax.jacfwd computes the Jacobian using forward-mode automatic differentiation. Conceptually, it computes the JVP for each standard basis vector corresponding to the input dimensions. If the input is x∈Rn, it computes J⋅ej for each basis vector ej (where ej has a 1 at index j and 0s elsewhere). The result J⋅ej gives the j-th column of the Jacobian matrix J.
import jax
import jax.numpy as jnp
# Example function: R^3 -> R^2
def func(x):
    return jnp.array([x[0]**2 * x[1], jnp.sin(x[2])])
# Input point
x_in = jnp.array([1.0, 2.0, jnp.pi / 2])
# Compute Jacobian using forward-mode AD
jacobian_fwd = jax.jacfwd(func)(x_in)
print("Input:", x_in)
print("Output:", func(x_in))
print("Jacobian (jacfwd):\n", jacobian_fwd)
# Expected shape: (2, 3) -> (output_dim, input_dim)
Forward-mode AD generally has a computational cost proportional to the number of inputs. Therefore, jax.jacfwd can be more efficient when the number of inputs (n) is smaller than the number of outputs (m), i.e., for "tall" Jacobians. Since it builds the Jacobian column by column, one JVP per input dimension, its total cost is roughly n forward evaluations, so its advantage appears when n is small.
jax.jacrev

Alternatively, jax.jacrev computes the Jacobian using reverse-mode automatic differentiation. This approach leverages VJPs. Conceptually, it computes the VJP for each standard basis vector corresponding to the output dimensions. If the output is f(x)∈Rm, it computes eiT⋅J for each basis vector ei (where ei has a 1 at index i and 0s elsewhere). The result eiT⋅J gives the i-th row of the Jacobian matrix J.
import jax
import jax.numpy as jnp
# Example function: R^3 -> R^2
def func(x):
    return jnp.array([x[0]**2 * x[1], jnp.sin(x[2])])
# Input point
x_in = jnp.array([1.0, 2.0, jnp.pi / 2])
# Compute Jacobian using reverse-mode AD
jacobian_rev = jax.jacrev(func)(x_in)
print("Input:", x_in)
print("Output:", func(x_in))
print("Jacobian (jacrev):\n", jacobian_rev)
# Expected shape: (2, 3) -> (output_dim, input_dim)
Reverse-mode AD typically has a computational cost proportional to the number of outputs. Therefore, jax.jacrev is often more efficient when the number of outputs (m) is smaller than the number of inputs (n), i.e., for "wide" Jacobians. This is a common scenario in machine learning, where a loss function maps high-dimensional parameters to a scalar loss (m=1). Computing the gradient (jax.grad) is a special case of jax.jacrev for scalar-output functions.
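As a quick sanity check, both transforms produce the same matrix for the example function; only their cost profiles differ. A minimal sketch, reusing func and x_in from the examples above:

import jax
import jax.numpy as jnp

def func(x):
    return jnp.array([x[0]**2 * x[1], jnp.sin(x[2])])

x_in = jnp.array([1.0, 2.0, jnp.pi / 2])

# Forward- and reverse-mode give the same (2, 3) Jacobian
print(jnp.allclose(jax.jacfwd(func)(x_in), jax.jacrev(func)(x_in)))  # True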
Constructing Jacobians with vmap (Illustrative)

You can also construct the Jacobian manually by applying vmap to jvp or vjp (or grad for scalar outputs). While jacfwd and jacrev are usually preferred for their optimized implementations, understanding the vmap approach can provide insight.

For a function f:Rn→Rm, applying vmap to jvp across standard basis input tangents yields the columns of the Jacobian:
import jax
import jax.numpy as jnp
# Example function: R^3 -> R^2
def func(x):
    return jnp.array([x[0]**2 * x[1], jnp.sin(x[2])])
# Input point
x_in = jnp.array([1.0, 2.0, jnp.pi / 2])
n = x_in.shape[0] # Input dimension
# Standard basis vectors for input tangents
basis_vectors_in = jnp.eye(n)
# Compute Jacobian column-by-column using vmap over jvp
# jax.jvp requires primal_in and tangent_in
# We fix primal_in and map over tangent_in
primals_out, jac_cols = jax.vmap(lambda tangent: jax.jvp(func, (x_in,), (tangent,)),
                                 in_axes=0)(basis_vectors_in)
# Transpose to get the standard m x n Jacobian matrix
jacobian_vmap_jvp = jac_cols.T
print("Jacobian (vmap + jvp):\n", jacobian_vmap_jvp)
Similarly, applying vmap to vjp across standard basis output cotangents yields the rows of the Jacobian:
import jax
import jax.numpy as jnp
# Example function: R^3 -> R^2
def func(x):
    return jnp.array([x[0]**2 * x[1], jnp.sin(x[2])])
# Input point
x_in = jnp.array([1.0, 2.0, jnp.pi / 2])
primals_out, vjp_fun = jax.vjp(func, x_in)
m = primals_out.shape[0] # Output dimension
# Standard basis vectors for output cotangents
basis_vectors_out = jnp.eye(m)
# Compute Jacobian row-by-row using vmap over vjp
jac_rows = jax.vmap(vjp_fun, in_axes=0)(basis_vectors_out)
# The result is already the m x n Jacobian matrix (each output from vmap is a row)
# Note: vjp_fun returns a tuple, we need the first element
jacobian_vmap_vjp = jac_rows[0]
print("Jacobian (vmap + vjp):\n", jacobian_vmap_vjp)
For a scalar-valued function (f:Rn→R), the Jacobian is a 1×n row vector whose transpose is the gradient vector. You can compute it using jax.grad directly, or via jax.jacrev (which is generally preferred over jax.jacfwd for scalar outputs).
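A short sketch of this equivalence, using an illustrative scalar-valued function (the name loss is hypothetical):

import jax
import jax.numpy as jnp

# Illustrative scalar-valued function R^3 -> R
def loss(x):
    return jnp.sum(x**2) + jnp.prod(x)

x_in = jnp.array([1.0, 2.0, 3.0])

grad_vec = jax.grad(loss)(x_in)    # shape (3,)
jac_vec = jax.jacrev(loss)(x_in)   # identical values and shape for a scalar output

print("grad:  ", grad_vec)
print("jacrev:", jac_vec)
print("equal: ", jnp.allclose(grad_vec, jac_vec))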
Choosing Between jacfwd and jacrev

- Use jax.jacrev when the function's output dimension (m) is significantly smaller than its input dimension (n). This is common for loss functions in machine learning (m=1).
- Use jax.jacfwd when the input dimension (n) is significantly smaller than the output dimension (m).
- Benchmarking both (with %timeit or the more advanced tools covered in Chapter 2) is the best way to determine the optimal choice for a specific use case; a timing sketch follows this list.
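A minimal timing sketch along these lines, assuming a "wide" example function (the names wide_func and time_it and the dimensions are illustrative, and the measured gap depends on your backend):

import time
import jax
import jax.numpy as jnp

# "Wide" function: many inputs (n = 5000), few outputs (m = 2)
def wide_func(x):
    return jnp.array([jnp.sum(x**2), jnp.sum(jnp.sin(x))])

x_wide = jnp.ones(5000)

# Compile both variants first so we time execution, not tracing/compilation
jac_fwd = jax.jit(jax.jacfwd(wide_func))
jac_rev = jax.jit(jax.jacrev(wide_func))
jac_fwd(x_wide).block_until_ready()
jac_rev(x_wide).block_until_ready()

def time_it(fn, arg, reps=10):
    start = time.perf_counter()
    for _ in range(reps):
        fn(arg).block_until_ready()
    return (time.perf_counter() - start) / reps

print("jacfwd:", time_it(jac_fwd, x_wide))  # ~n JVPs
print("jacrev:", time_it(jac_rev, x_wide))  # ~m VJPs, typically faster here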
The Hessian Matrix

The Hessian matrix H of a scalar-valued function f:Rn→R contains all the second-order partial derivatives. Its entry Hij is given by:

$$H_{ij} = \frac{\partial^2 f}{\partial x_i \, \partial x_j}$$

The Hessian is an n×n matrix. For functions with continuous second derivatives (which is common in ML contexts), the Hessian matrix is symmetric (Hij = Hji).
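As a concrete instance, worked by hand for the function f(x) = x0^2 x1 + x1^3 used in the code example below, the gradient, the Hessian, and its value at the input point (1, 2) are:

$$\nabla f(x) = \begin{pmatrix} 2 x_0 x_1 \\ x_0^2 + 3 x_1^2 \end{pmatrix}, \qquad H(x) = \begin{pmatrix} 2 x_1 & 2 x_0 \\ 2 x_0 & 6 x_1 \end{pmatrix}, \qquad H(1, 2) = \begin{pmatrix} 4 & 2 \\ 2 & 12 \end{pmatrix}$$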
JAX computes the Hessian by composing differentiation transforms. Specifically, the Hessian is the Jacobian of the gradient function.
jax.hessian
The most straightforward way to compute the Hessian is using jax.hessian:
import jax
import jax.numpy as jnp
# Example scalar function: R^2 -> R
def scalar_func(x):
    # f(x, y) = x^2 * y + y^3
    return x[0]**2 * x[1] + x[1]**3
# Input point
x_in = jnp.array([1.0, 2.0])
# Compute the Hessian matrix
hessian_matrix = jax.hessian(scalar_func)(x_in)
print("Input:", x_in)
print("Output:", scalar_func(x_in))
print("Hessian:\n", hessian_matrix)
# Expected shape: (2, 2) -> (input_dim, input_dim)
Internally, jax.hessian(f) is implemented as jax.jacfwd(jax.jacrev(f)). It first builds the gradient function (g=∇f) using reverse-mode AD (jax.jacrev, which coincides with jax.grad for scalar outputs), and then computes the Jacobian of this gradient function using forward-mode AD (jax.jacfwd). You could also compute it as jax.jacrev(jax.grad(f)). The choice between jacfwd and jacrev for the outer call follows the same logic as for Jacobians: the gradient function g:Rn→Rn has equal input and output dimensions, and the forward-over-reverse composition is generally the most efficient, which is why jacfwd is used for the outer step.
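A small sketch verifying that these compositions agree with jax.hessian for the example function above:

import jax
import jax.numpy as jnp

def scalar_func(x):
    return x[0]**2 * x[1] + x[1]**3

x_in = jnp.array([1.0, 2.0])

h_builtin = jax.hessian(scalar_func)(x_in)
h_fwd_over_rev = jax.jacfwd(jax.jacrev(scalar_func))(x_in)
h_rev_over_rev = jax.jacrev(jax.grad(scalar_func))(x_in)

print(jnp.allclose(h_builtin, h_fwd_over_rev))  # True
print(jnp.allclose(h_builtin, h_rev_over_rev))  # True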
Computing the full Hessian involves calculating O(n²) second derivatives. This can become computationally prohibitive very quickly as the input dimension n grows. Storing the n×n matrix also requires O(n²) memory.
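For a rough sense of scale, a back-of-envelope estimate assuming float32 storage and n = 10^6 parameters (small by deep-learning standards):

$$n^2 \times 4 \text{ bytes} = (10^6)^2 \times 4 \text{ bytes} = 4 \times 10^{12} \text{ bytes} \approx 4\ \text{TB}$$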
For many applications, particularly in optimization, the full Hessian matrix is not required. Instead, algorithms often rely on Hessian-vector products (HvPs), which compute H⋅v for a given vector v. HvPs can be computed much more efficiently without explicitly forming H, typically using a combination of forward- and reverse-mode AD. For example, one way to compute H⋅v=∇x((∇xf(x))⋅v) involves one forward-mode pass and one reverse-mode pass, roughly costing the same as two gradient calculations.
import jax
import jax.numpy as jnp
# Example scalar function: R^2 -> R
def scalar_func(x):
    # f(x, y) = x^2 * y + y^3
    return x[0]**2 * x[1] + x[1]**3
# Input point and vector
x_in = jnp.array([1.0, 2.0])
v = jnp.array([0.5, -0.5])
# Method 1: Compute full Hessian then product (inefficient)
hessian_matrix = jax.hessian(scalar_func)(x_in)
hvp_explicit = hessian_matrix @ v
print("Full Hessian:\n", hessian_matrix)
print("Explicit HvP:", hvp_explicit)
# Method 2: Efficient Hessian-vector product
# Compute gradient function first
grad_f = jax.grad(scalar_func)
# Compute JVP of the gradient function
# jax.jvp(grad_f, (x_in,), (v,)) returns (grad_f(x_in), H @ v)
_, hvp_efficient = jax.jvp(grad_f, (x_in,), (v,))
print("Efficient HvP:", hvp_efficient)
Therefore, while JAX provides jax.hessian for convenience, always consider whether a Hessian-vector product is sufficient for your task, as it offers substantial performance benefits for large n.
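For instance, here is a minimal sketch of a reusable, jitted forward-over-reverse HvP helper applied to a larger problem (the names hvp and big_loss and the dimension are illustrative):

import jax
import jax.numpy as jnp

def hvp(f, x, v):
    # JVP of the gradient function: computes H @ v without forming H
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

# Illustrative scalar loss over a large parameter vector
def big_loss(w):
    return jnp.sum(jnp.tanh(w) ** 2)

w = jnp.ones(100_000)
v = jnp.linspace(-1.0, 1.0, 100_000)

# Mark the function argument as static so jit traces once per loss function
hvp_jit = jax.jit(hvp, static_argnums=0)
print(hvp_jit(big_loss, w, v)[:3])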
Computing full Jacobians and Hessians is appropriate when the input and output dimensions are modest and you genuinely need the entire matrix, for example in certain second-order optimization algorithms, for sensitivity analysis, or for inspecting the local geometry of a function.
In many large-scale machine learning scenarios, direct computation and storage of these matrices are infeasible. Techniques relying on JVPs, VJPs, and Hessian-vector products are the standard approach, allowing differentiation calculations to scale effectively. Understanding how jax.jacfwd, jax.jacrev, and jax.hessian work provides valuable insight into JAX's differentiation capabilities, even if you primarily use gradient calculations or vector products in practice.