As discussed, calculating gradients is fundamental for many tasks, particularly in optimizing machine learning models. JAX provides a powerful and elegant way to perform automatic differentiation through function transformations. The primary tool for this is jax.grad.
jax.grad is itself a function. Its core purpose is to transform a numerical Python function that returns a scalar value into a new function that computes the gradient of the original function. Think of it like this: you give jax.grad a function f, and it returns a function ∇f.
Let's look at a simple example. Consider the mathematical function f(x) = x³. We know from calculus that its derivative (gradient in 1D) is f′(x) = 3x². Let's see how to compute this using JAX.
First, we define the function f using standard Python syntax, potentially using jax.numpy for numerical operations:
import jax
import jax.numpy as jnp
def f(x):
    """Computes x cubed."""
    return x**3
# Test the original function
print(f"f(2.0) = {f(2.0)}")
f(2.0) = 8.0
Now, we apply the jax.grad transformation to our function f:
# Get the gradient function using jax.grad
grad_f = jax.grad(f)
print(f"Type of f: {type(f)}")
print(f"Type of grad_f: {type(grad_f)}")
Type of f: <class 'function'>
Type of grad_f: <class 'function'>
Notice that jax.grad(f) returns a new Python function, which we've named grad_f. This new function grad_f is ready to compute the gradient. To get the gradient value at a specific point, say x = 2.0, we call grad_f with that value:
# Calculate the gradient of f at x=2.0
gradient_at_2 = grad_f(2.0)
print(f"Gradient of f at x=2.0: {gradient_at_2}")
Gradient of f at x=2.0: 12.0
The result is 12.0, which matches our manual calculation: f′(x) = 3x², so f′(2.0) = 3 × (2.0)² = 3 × 4.0 = 12.0.
It's important to remember the input and output types:
- Input: you call jax.grad with a Python function that takes one or more arguments and returns a single scalar value.
- Output: jax.grad returns a new Python function.
What happens if the function takes multiple arguments? Let's define g(x, y) = x² × y:
def g(x, y):
"""Computes x squared times y."""
return (x**2) * y
# Get the gradient function for g
grad_g = jax.grad(g)
# Calculate the gradient at (x=2.0, y=3.0)
# By default, this is the gradient with respect to the first argument (x)
gradient_g_wrt_x = grad_g(2.0, 3.0)
print(f"g(2.0, 3.0) = {g(2.0, 3.0)}")
print(f"Gradient of g w.r.t x at (2.0, 3.0): {gradient_g_wrt_x}")
g(2.0, 3.0) = 12.0
Gradient of g w.r.t x at (2.0, 3.0): 12.0
Here, grad_g(2.0, 3.0) computes the partial derivative ∂g/∂x evaluated at x = 2.0, y = 3.0. Manually, ∂g/∂x = 2xy. Evaluating at (2.0, 3.0) gives 2 × 2.0 × 3.0 = 12.0, matching the JAX output. We will explore how to compute gradients with respect to other arguments in a later section.
An essential requirement for jax.grad is that the function you are differentiating must return a single scalar number (like an integer or a float, though floats are typical for differentiation). If your function returns an array, a tuple, or any non-scalar output, using jax.grad directly will result in an error. This aligns with the mathematical definition of a gradient, which applies to scalar fields (functions mapping vectors or numbers to scalars). Techniques for differentiating functions with multiple outputs exist (like jax.jacobian), often building upon jax.grad.
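To make the scalar-output rule concrete, here is a small sketch; the vector-valued function vec_f below is an illustrative example introduced here, not one defined earlier:
# A vector-valued function: it returns two values rather than a single scalar
def vec_f(x):
    """Returns the array [x**2, x**3]."""
    return jnp.array([x**2, x**3])
# jax.grad(vec_f)(2.0) would raise an error because the output is not a scalar.
# jax.jacobian handles vector-valued outputs, returning all partial derivatives:
jac_f = jax.jacobian(vec_f)
print(jac_f(2.0))  # roughly [4.0, 12.0], i.e. [2x, 3x**2] evaluated at x = 2.0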
One final point: automatic differentiation operates on floating-point numbers. JAX will typically reject integer inputs to a gradient function with an error, so make sure the inputs to the functions you differentiate (and thus to the gradient functions) are floats to get meaningful derivative results.
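As a quick check, reusing grad_f from above (recent JAX versions raise a TypeError for integer inputs, though the exact message may vary), compare an integer argument with a float one:
# Integer inputs to a gradient function are typically rejected by JAX
try:
    grad_f(2)  # int input
except TypeError as e:
    print(f"Integer input failed: {e}")
# Float inputs work as intended
print(f"grad_f(2.0) = {grad_f(2.0)}")  # 12.0, as computed earlier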
# Example using jax.numpy within the function
def h(x):
    """Computes sin(x)."""
    return jnp.sin(x)
grad_h = jax.grad(h)
# Calculate gradient at pi/2 (where cos(x) = 0)
# Use jnp.pi for precision and ensure float input
gradient_h_at_pi_half = grad_h(jnp.pi / 2.0)
print(f"h(pi/2) = {h(jnp.pi / 2.0)}")
print(f"Gradient of h at x=pi/2: {gradient_h_at_pi_half}")
h(pi/2) = 1.0
Gradient of h at x=pi/2: 0.0
As expected, the derivative of sin(x) is cos(x), and cos(π/2) = 0. (In practice, the printed value may be a number extremely close to zero rather than exactly 0.0, due to floating-point precision.)
In summary, jax.grad is the foundational tool in JAX for automatic differentiation. It transforms a scalar-output Python function into a new function that computes its gradient with respect to the first argument. This transformation handles the complexities of differentiation automatically, allowing you to focus on defining your computations.
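As a closing sketch of how this connects back to optimization, a gradient function can drive a simple gradient descent loop; the quadratic loss below is an illustrative stand-in, not a function from this section:
# A simple scalar loss with its minimum at w = 3.0 (illustrative only)
def loss(w):
    return (w - 3.0) ** 2
grad_loss = jax.grad(loss)
w = 0.0             # initial parameter value
learning_rate = 0.1
for _ in range(5):
    w = w - learning_rate * grad_loss(w)  # one gradient descent update per step
print(w)  # moves toward 3.0 (roughly 2.0 after five steps)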