The jax.grad function takes a Python function f that computes a scalar value and returns a new Python function that computes the gradient ∇f. A significant aspect of JAX's design is that its transformations, including jax.grad, are themselves implemented in Python and operate on Python functions. This means we can apply transformations to the output of other transformations.
What happens if we apply jax.grad to the gradient function produced by jax.grad? We get a function that computes the second derivative.
Let's consider a simple scalar function:
f(x) = x³. Its first derivative is f′(x) = 3x², and its second derivative is f′′(x) = 6x.
We can compute the first derivative function using jax.grad:
import jax
import jax.numpy as jnp
def f(x):
    return x**3
# Get the function that computes the first derivative
grad_f = jax.grad(f)
# Evaluate the original function and its first derivative at x=2.0
x_val = 2.0
print(f"f({x_val}) =", f(x_val))
print(f"f'({x_val}) =", grad_f(x_val))
Running this gives:
f(2.0) = 8.0
f'(2.0) = 12.0
This matches our analytical calculation: f(2) = 2³ = 8 and f′(2) = 3(2²) = 12.
Now, since grad_f is just another Python function (that happens to compute the gradient), we can differentiate it as well:
# Get the function that computes the second derivative
grad_grad_f = jax.grad(grad_f)
# Evaluate the second derivative at x=2.0
print(f"f''({x_val}) =", grad_grad_f(x_val))
The output is:
f''(2.0) = 12.0
This also matches our analytical result: f′′(2)=6(2)=12.
We can express this more compactly by nesting the calls to jax.grad:
# Directly define the second derivative function
grad_grad_f_nested = jax.grad(jax.grad(f))
print(f"f''({x_val}) using nested grad =", grad_grad_f_nested(x_val))
f''(2.0) using nested grad = 12.0
This process can be continued to compute third, fourth, or even higher-order derivatives, limited only by computational cost and numerical stability.
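For example, nesting three times yields the third derivative. A minimal sketch, reusing the f and x_val defined above (for f(x) = x³, the third derivative is the constant 6):
# Third derivative: nest grad three times
grad_grad_grad_f = jax.grad(jax.grad(jax.grad(f)))
print(f"f'''({x_val}) =", grad_grad_grad_f(x_val))
This prints f'''(2.0) = 6.0, matching the analytical third derivative of x³.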
This nesting applies equally to functions with multiple arguments. When dealing with multivariate scalar functions, applying grad twice can be used to compute elements of the Hessian matrix. The Hessian matrix H of a scalar-valued function f(x₁, x₂, ..., xₙ) is the square matrix of second-order partial derivatives, whose (i, j) entry is ∂²f/∂xᵢ∂xⱼ.
For instance, consider f(x, y) = x²y + y³. Let's find ∂²f/∂x². We first differentiate with respect to x (treating y as constant), then differentiate the result with respect to x again.
def f_multi(x, y):
    return x**2 * y + y**3
# First derivative with respect to x (arg 0)
grad_f_wrt_x = jax.grad(f_multi, argnums=0)
# Second derivative: differentiate grad_f_wrt_x with respect to x (its first arg, index 0)
grad_grad_f_wrt_xx = jax.grad(grad_f_wrt_x, argnums=0)
# Analytically:
# df/dx = 2xy
# d^2f/dx^2 = 2y
x_val, y_val = 2.0, 3.0
print(f"d^2f/dx^2 at ({x_val}, {y_val}) =", grad_grad_f_wrt_xx(x_val, y_val))
print(f"Analytical result (2*y):", 2 * y_val)
d^2f/dx^2 at (2.0, 3.0) = 6.0
Analytical result (2*y): 6.0
To compute mixed partial derivatives, like ∂²f/∂y∂x, we change the argnums in the second differentiation:
# Second derivative: differentiate grad_f_wrt_x with respect to y (its second arg, index 1)
grad_grad_f_wrt_xy = jax.grad(grad_f_wrt_x, argnums=1)
# Analytically:
# df/dx = 2xy
# d^2f/dydx = 2x
print(f"d^2f/dydx at ({x_val}, {y_val}) =", grad_grad_f_wrt_xy(x_val, y_val))
print(f"Analytical result (2*x):", 2 * x_val)
d^2f/dydx at (2.0, 3.0) = 4.0
Analytical result (2*x): 4.0
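Because mixed partial derivatives of smooth functions are symmetric, differentiating in the opposite order should give the same value. A quick sketch, reusing f_multi, x_val, and y_val from above:
# Differentiate with respect to y first (arg 1), then x (arg 0)
grad_f_wrt_y = jax.grad(f_multi, argnums=1)
grad_grad_f_wrt_yx = jax.grad(grad_f_wrt_y, argnums=0)
# Analytically: df/dy = x^2 + 3y^2, so d^2f/dxdy = 2x
print(f"d^2f/dxdy at ({x_val}, {y_val}) =", grad_grad_f_wrt_yx(x_val, y_val))
This prints 4.0, the same value as ∂²f/∂y∂x computed above.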
While nesting grad works perfectly well for computing specific second derivatives or even higher-order derivatives, JAX also provides convenience functions like jax.hessian for directly computing the full Hessian matrix, which might be more efficient if you need all second partial derivatives. However, understanding the grad(grad(...)) composition is fundamental to grasping the composability of JAX transformations.
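As a brief sketch of that convenience function, one common pattern is to pack the inputs into a single array so jax.hessian returns one matrix. The helper f_vec below is introduced here just for illustration; it computes the same function as f_multi:
# Same function as f_multi, but taking a single length-2 array
def f_vec(v):
    return v[0]**2 * v[1] + v[1]**3

hessian_at_point = jax.hessian(f_vec)(jnp.array([2.0, 3.0]))
print(hessian_at_point)
# Expected analytically: [[2y, 2x], [2x, 6y]] = [[6., 4.], [4., 18.]]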
This ability to arbitrarily compose grad highlights the functional nature of JAX. Each transformation takes a function and returns a new function, ready to be used or transformed further.
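For example, a nested gradient function can itself be passed to jax.jit; a small sketch, reusing f from above:
# Compile the second-derivative function for faster repeated evaluation
fast_second_derivative = jax.jit(jax.grad(jax.grad(f)))
print(fast_second_derivative(2.0))  # 12.0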