You've seen how jax.grad takes a Python function f that computes a scalar value and returns a new Python function that computes the gradient $\nabla f$. A significant aspect of JAX's design is that its transformations, including jax.grad, are themselves implemented in Python and operate on Python functions. This means we can apply transformations to the output of other transformations.
What happens if we apply jax.grad to the gradient function produced by jax.grad? We get a function that computes the second derivative.
Let's consider a simple scalar function:
$f(x) = x^3$. Its first derivative is $f'(x) = 3x^2$, and its second derivative is $f''(x) = 6x$.
We can compute the first derivative function using jax.grad:
import jax
import jax.numpy as jnp
def f(x):
    return x**3
# Get the function that computes the first derivative
grad_f = jax.grad(f)
# Evaluate the original function and its first derivative at x=2.0
x_val = 2.0
print(f"f({x_val}) =", f(x_val))
print(f"f'({x_val}) =", grad_f(x_val))
Running this gives:
f(2.0) = 8.0
f'(2.0) = 12.0
This matches our analytical calculation: $f(2) = 2^3 = 8$ and $f'(2) = 3(2^2) = 12$.
Now, since grad_f is just another Python function (that happens to compute the gradient), we can differentiate it as well:
# Get the function that computes the second derivative
grad_grad_f = jax.grad(grad_f)
# Evaluate the second derivative at x=2.0
print(f"f''({x_val}) =", grad_grad_f(x_val))
The output is:
f''(2.0) = 12.0
This also matches our analytical result: $f''(2) = 6(2) = 12$.
We can express this more compactly by nesting the calls to jax.grad:
# Directly define the second derivative function
grad_grad_f_nested = jax.grad(jax.grad(f))
print(f"f''({x_val}) using nested grad =", grad_grad_f_nested(x_val))
f''(2.0) using nested grad = 12.0
This process can be continued to compute third, fourth, or even higher-order derivatives, limited only by computational cost and numerical stability.
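As a quick sketch of that idea, nesting jax.grad a third time gives the third derivative; for $f(x) = x^3$ this is the constant 6, so evaluating at the same x_val as above should give 6.0:

# Third derivative: nest jax.grad three times
grad3_f = jax.grad(jax.grad(jax.grad(f)))
print(f"f'''({x_val}) =", grad3_f(x_val))

This should print f'''(2.0) = 6.0.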
This nesting applies equally to functions with multiple arguments. When dealing with multivariate scalar functions, applying grad twice can be used to compute elements of the Hessian matrix. The Hessian matrix $H$ of a scalar-valued function $f(x_1, x_2, \dots, x_n)$ is the square matrix of second-order partial derivatives:
$$H_{ij} = \frac{\partial^2 f}{\partial x_i \, \partial x_j}$$
For instance, consider $f(x, y) = x^2 y + y^3$. Let's find $\frac{\partial^2 f}{\partial x^2}$. We first differentiate with respect to x (treating y as constant), then differentiate the result with respect to x again.
def f_multi(x, y):
    return x**2 * y + y**3
# First derivative with respect to x (arg 0)
grad_f_wrt_x = jax.grad(f_multi, argnums=0)
# Second derivative: differentiate grad_f_wrt_x with respect to x (its first arg, index 0)
grad_grad_f_wrt_xx = jax.grad(grad_f_wrt_x, argnums=0)
# Analytically:
# df/dx = 2xy
# d^2f/dx^2 = 2y
x_val, y_val = 2.0, 3.0
print(f"d^2f/dx^2 at ({x_val}, {y_val}) =", grad_grad_f_wrt_xx(x_val, y_val))
print(f"Analytical result (2*y):", 2 * y_val)
d^2f/dx^2 at (2.0, 3.0) = 6.0
Analytical result (2*y): 6.0
To compute mixed partial derivatives, like $\frac{\partial^2 f}{\partial y \, \partial x}$, we change the argnums in the second differentiation:
# Second derivative: differentiate grad_f_wrt_x with respect to y (its second arg, index 1)
grad_grad_f_wrt_xy = jax.grad(grad_f_wrt_x, argnums=1)
# Analytically:
# df/dx = 2xy
# d^2f/dydx = 2x
print(f"d^2f/dydx at ({x_val}, {y_val}) =", grad_grad_f_wrt_xy(x_val, y_val))
print(f"Analytical result (2*x):", 2 * x_val)
d^2f/dydx at (2.0, 3.0) = 4.0
Analytical result (2*x): 4.0
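As a small consistency check (a sketch beyond the walkthrough above), we can also differentiate in the opposite order, first with respect to y and then with respect to x. For a smooth function like this one, the mixed partials agree, so we expect the same value of 2x:

# Differentiate with respect to y first (index 1), then x (index 0)
grad_f_wrt_y = jax.grad(f_multi, argnums=1)
grad_grad_f_wrt_yx = jax.grad(grad_f_wrt_y, argnums=0)
# Analytically: df/dy = x^2 + 3y^2, so d^2f/dxdy = 2x
print(f"d^2f/dxdy at ({x_val}, {y_val}) =", grad_grad_f_wrt_yx(x_val, y_val))

This should print 4.0 as well, matching the mixed partial computed above.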
While nesting grad works perfectly well for computing specific second derivatives or even higher-order derivatives, JAX also provides convenience functions like jax.hessian for directly computing the full Hessian matrix, which might be more efficient if you need all second partial derivatives. However, understanding the grad(grad(...)) composition is fundamental to grasping the composability of JAX transformations.
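For reference, here is a minimal sketch of jax.hessian applied to the same function, rewritten as a hypothetical helper f_vec that takes a single array argument so the full 2x2 Hessian comes back in one call:

# Hypothetical array-input version of f_multi
def f_vec(v):
    return v[0]**2 * v[1] + v[1]**3

point = jnp.array([2.0, 3.0])
# Analytically the Hessian is [[2y, 2x], [2x, 6y]], i.e. [[6, 4], [4, 18]] at (2, 3)
print(jax.hessian(f_vec)(point))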
This ability to arbitrarily compose grad highlights the functional nature of JAX. Each transformation takes a function and returns a new function, ready to be used or transformed further.
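As one final sketch of that composability, a nested grad function can itself be passed to another transformation such as jax.jit, which compiles the second-derivative computation without changing its result:

# Compile the second-derivative function; the value at x = 2.0 is still 12.0
fast_second_derivative = jax.jit(jax.grad(jax.grad(f)))
print(fast_second_derivative(2.0))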