JAX's functional transformations, including grad, jvp, and vjp, are designed to be composable. This means you can apply transformations to functions that are themselves the result of other transformations. This composability allows for the computation of derivatives beyond the first order, opening the door to sophisticated analysis and optimization techniques.
The most straightforward way to think about higher-order derivatives is to apply differentiation functions repeatedly. For a simple scalar function f: R → R, computing the second derivative is as simple as applying jax.grad twice:
import jax
import jax.numpy as jnp
def f(x):
    return x**3 + 2*x**2 - 3*x + 1
# First derivative
grad_f = jax.grad(f)
# Second derivative
grad_grad_f = jax.grad(grad_f) # Or equivalently, jax.grad(jax.grad(f))
# Evaluate at x = 2.0
x_val = 2.0
first_deriv = grad_f(x_val)
second_deriv = grad_grad_f(x_val)
print(f"f(x) = x^3 + 2x^2 - 3x + 1")
print(f"f'(x) = 3x^2 + 4x - 3")
print(f"f''(x) = 6x + 4")
print(f"f'({x_val}) = {first_deriv}") # Expected: 3*(2^2) + 4*2 - 3 = 12 + 8 - 3 = 17
print(f"f''({x_val}) = {second_deriv}") # Expected: 6*2 + 4 = 12 + 4 = 16
This composition works exactly as you might expect from calculus. grad(f) returns a function that computes the first derivative, and applying grad to that function yields another function computing the second derivative. This principle extends readily to third- and higher-order derivatives of scalar functions.
For multivariate functions f: R^n → R, simply composing grad twice does not yield the Hessian matrix (the matrix of second-order partial derivatives). The gradient of f is a vector-valued function from R^n to R^n, and jax.grad only applies to scalar-output functions, so grad(grad(f)) fails; the Hessian is the Jacobian of the gradient, not the gradient of the gradient.
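To make this concrete, here is a minimal sketch (g and x below are illustrative, not part of the later examples) showing that composing grad twice fails for a multivariate function:

import jax
import jax.numpy as jnp

def g(x):
    # Illustrative scalar-valued function of a 2-element input
    return x[0]**2 * x[1] + jnp.sin(x[1])

x = jnp.array([1.0, 2.0])

# grad(g) returns a vector-valued function, so grad cannot be applied to it again
try:
    jax.grad(jax.grad(g))(x)
except TypeError as err:
    print("grad(grad(g)) failed:", err)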
Often in optimization algorithms (like Newton's method or truncated Newton methods), we don't need the full Hessian matrix H. Instead, we need to compute the product of the Hessian matrix with a vector v, known as a Hessian-vector product (HVP): Hv. Computing HVPs can be significantly more efficient than forming the full Hessian, especially when n is large.
JAX allows computing HVPs efficiently by composing forward-mode (jvp) and reverse-mode (vjp or grad) differentiation. Recall that jax.grad(f) is essentially built on top of jax.vjp.
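As a rough sketch of that relationship (my_grad below is an illustrative helper, not part of the JAX API), the gradient of a scalar-output function can be obtained by pulling back a cotangent of ones through jax.vjp:

import jax
import jax.numpy as jnp

def my_grad(f):
    # Illustrative sketch: gradient of a scalar-output f via vjp
    def gradient(x):
        y, pullback = jax.vjp(f, x)
        # Pull back a cotangent matching the scalar output to get df/dx
        return pullback(jnp.ones_like(y))[0]
    return gradient

f_demo = lambda x: jnp.sum(jnp.sin(x) * x)
x_demo = jnp.array([1.0, 2.0])
print(my_grad(f_demo)(x_demo))
print(jax.grad(f_demo)(x_demo))  # should match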
There are two primary ways to compute Hv:
Forward-over-Reverse: Compute the JVP of the gradient function. Let g(x)=∇f(x). We want to compute the JVP of g with the tangent vector v.
JVP(g, x, v) = (∂g(x)/∂x) v = (∂(∇f(x))/∂x) v = H(x) v

In JAX, this translates to:
def hvp_forward_over_reverse(f, primals, tangents):
    # Compute H @ v using jvp(grad(f))
    return jax.jvp(jax.grad(f), primals, tangents)[1]
The [1] selects the tangent output of the jvp result (jax.jvp returns a (primal_out, tangent_out) pair), which corresponds to the product H(x)v.
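For example, the helper above can be called as follows (f_cubic, x, and v are illustrative choices; jax and jnp are assumed to be imported as before):

import jax.numpy as jnp

# Assumes hvp_forward_over_reverse from above is in scope
def f_cubic(x):
    return jnp.sum(x**3)   # gradient 3x^2, Hessian diag(6x)

x = jnp.array([1.0, 2.0])
v = jnp.array([0.0, 1.0])
print(hvp_forward_over_reverse(f_cubic, (x,), (v,)))  # Expected: [0., 12.]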
Reverse-over-Reverse: Differentiate a scalar built from the gradient. This is slightly less direct. Consider the function h(x) = (∇f(x))⊤ v = jnp.dot(∇f(x), v). Its gradient is ∇h(x) = H(x) v (using the symmetry of the Hessian). In JAX, this translates to:
def hvp_reverse_over_reverse(f, primals, tangents):
    # Compute H @ v using grad(lambda x: jnp.dot(grad(f)(x), v))
    x, = primals
    v, = tangents
    # Need lambda to close over v
    return jax.grad(lambda x_lambda: jnp.dot(jax.grad(f)(x_lambda), v))(x)
Note the use of a lambda function to properly capture the vector v within the gradient computation.
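As a quick sanity check (assuming both helpers above are defined), the two formulations produce the same result; f_quartic, x, and v below are illustrative:

import jax.numpy as jnp

def f_quartic(x):
    return jnp.sum(x**4)   # gradient 4x^3, Hessian diag(12x^2)

x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 1.0])
print(hvp_forward_over_reverse(f_quartic, (x,), (v,)))  # Expected: [12., 48.]
print(hvp_reverse_over_reverse(f_quartic, (x,), (v,)))  # Expected: [12., 48.]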
Generally, the forward-over-reverse approach (jvp(grad(f), ...)) is preferred for HVPs, as it is often more computationally efficient and maps more directly to the mathematical definition.
Let's see an example:
import jax
import jax.numpy as jnp
def func(x):
    # f(x, y) = x^2 * y + y^3
    return x[0]**2 * x[1] + x[1]**3
x_primal = jnp.array([1.0, 2.0])
v_tangent = jnp.array([1.0, 0.0]) # Vector to multiply Hessian with
# Using jvp(grad(f)) - Forward-over-Reverse
hvp_val = jax.jvp(jax.grad(func), (x_primal,), (v_tangent,))[1]
print(f"Function: f(x,y) = x^2 * y + y^3")
# Gradient: nabla f = [2xy, x^2 + 3y^2]
# Hessian: H = [[2y, 2x], [2x, 6y]]
# At (1, 2): H = [[4, 2], [2, 12]]
# H @ v = [[4, 2], [2, 12]] @ [1, 0] = [4, 2]
print(f"Primals (x): {x_primal}")
print(f"Tangents (v): {v_tangent}")
print(f"Hessian-vector product (H @ v): {hvp_val}") # Expected: [4., 2.]
While HVPs are efficient for many applications, sometimes the explicit representation of the full Hessian matrix H is required. The Hessian is the Jacobian of the gradient function: H(x) = J_{∇f}(x).
We can leverage JAX's functions for computing Jacobians (jacfwd and jacrev), applied to the gradient function (jax.grad(f)), to obtain the Hessian.
Using jax.hessian: JAX provides a convenience function, jax.hessian, that directly computes the Hessian matrix.
hessian_matrix = jax.hessian(func)(x_primal)
print("\nFull Hessian using jax.hessian:")
print(hessian_matrix)
# Expected: [[4., 2.], [2., 12.]]
Internally, jax.hessian often combines forward- and reverse-mode differentiation (like jacfwd(jacrev(f)) or jacrev(jacfwd(f))) for efficiency, similar to how jacfwd and jacrev compute Jacobians.
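For instance, composing these by hand (reusing func and x_primal from the example above) reproduces the same matrix:

# Assumes func and x_primal defined above
hessian_manual = jax.jacfwd(jax.jacrev(func))(x_primal)
print(hessian_manual)  # Expected: [[4., 2.], [2., 12.]]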
Using jacfwd(grad(f)): This computes the Hessian by applying forward-mode automatic differentiation (jacfwd) to the gradient function. Conceptually, it computes each column of the Hessian via a JVP.
hessian_jacfwd = jax.jacfwd(jax.grad(func))(x_primal)
print("\nFull Hessian using jacfwd(grad(f)):")
print(hessian_jacfwd)
Using jacrev(grad(f)): This computes the Hessian by applying reverse-mode automatic differentiation (jacrev) to the gradient function. Conceptually, it computes each row of the Hessian via a VJP.
hessian_jacrev = jax.jacrev(jax.grad(func))(x_primal)
print("\nFull Hessian using jacrev(grad(f)):")
print(hessian_jacrev)
The choice between jacfwd(grad(f)) and jacrev(grad(f)) can have performance implications depending on the relative dimensions involved, similar to the choice between jacfwd and jacrev for first-order Jacobians. jax.hessian attempts to choose a reasonable strategy automatically.
The principle of composition extends naturally. You can compute third derivatives, fourth derivatives, and various mixed partial derivatives by further composing grad, jvp, jacfwd, and jacrev as needed. For instance, the gradient of the Hessian could be computed, although direct applications in standard machine learning are less common than first- and second-order derivatives.
Example: third derivative of the initial scalar function f:
# Third derivative
grad_grad_grad_f = jax.grad(grad_grad_f) # Or jax.grad(jax.grad(jax.grad(f)))
third_deriv = grad_grad_grad_f(x_val)
print(f"\nf'''(x) = 6")
print(f"f'''({x_val}) = {third_deriv}") # Expected: 6.0
The ability to compose differentiation functions arbitrarily is a defining characteristic of JAX's autodiff system, providing substantial flexibility for implementing advanced numerical and machine learning algorithms.