Let's solidify your understanding of jax.grad with some practical exercises. We'll start with simple functions and gradually apply the concepts covered in this chapter, such as differentiating with respect to specific arguments and calculating higher-order derivatives.
Make sure you have JAX installed and imported:
import jax
import jax.numpy as jnp
import numpy as np # Often useful for comparison or initial data
# Enable 64-bit precision for closer agreement with analytical results if needed
jax.config.update("jax_enable_x64", True)
print(f"JAX version: {jax.__version__}")
print(f"Default backend: {jax.default_backend()}")
Consider the polynomial function f(x) = 3x² + 2x + 5. Analytically, its derivative is f'(x) = 6x + 2. Let's verify this using jax.grad.
Define the function:
def poly_func(x):
    """A simple polynomial function."""
    return 3 * x**2 + 2 * x + 5
Create the gradient function: Use jax.grad to get a new function that computes the gradient of poly_func.
grad_poly_func = jax.grad(poly_func)
Compute the gradient at a specific point: Let's evaluate the gradient at x=4.
x_value = 4.0 # Use float for differentiation
gradient_at_x = grad_poly_func(x_value)
print(f"Gradient of f(x) at x = {x_value}: {gradient_at_x}")
# Verify analytically: f'(4) = 6*4 + 2 = 24 + 2 = 26
analytical_gradient = 6 * x_value + 2
print(f"Analytical gradient at x = {x_value}: {analytical_gradient}")
You should see that the output from jax.grad matches the analytical result (26.0).
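If you prefer an automated check over reading the printout, a small assertion works as well. This is just an illustrative sketch that reuses gradient_at_x and analytical_gradient from the snippet above:
# Optional sanity check: confirm numerically that the two values agree
assert jnp.allclose(gradient_at_x, analytical_gradient), "Gradients disagree"
print("jax.grad matches the analytical derivative at x = 4.0")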
Now, let's work with a function of two variables: g(x, y) = x³y + 2x². We want to compute the partial derivatives ∂g/∂x and ∂g/∂y.
Analytically: ∂g/∂x = 3x²y + 4x and ∂g/∂y = x³.
Define the function:
def multi_var_func(x, y):
    """Function with two variables."""
    return x**3 * y + 2 * x**2
Compute the gradient with respect to the first argument (x): By default, jax.grad differentiates with respect to the first argument (argnums=0).
grad_g_wrt_x = jax.grad(multi_var_func, argnums=0)
x_val = 2.0
y_val = 3.0
gradient_x = grad_g_wrt_x(x_val, y_val)
print(f"Gradient w.r.t. x at ({x_val}, {y_val}): {gradient_x}")
# Verify analytically: 3*(2^2)*3 + 4*2 = 3*4*3 + 8 = 36 + 8 = 44
analytical_grad_x = 3 * x_val**2 * y_val + 4 * x_val
print(f"Analytical gradient w.r.t. x: {analytical_grad_x}")
Compute the gradient with respect to the second argument (y): Use argnums=1.
grad_g_wrt_y = jax.grad(multi_var_func, argnums=1)
gradient_y = grad_g_wrt_y(x_val, y_val)
print(f"\nGradient w.r.t. y at ({x_val}, {y_val}): {gradient_y}")
# Verify analytically: 2^3 = 8
analytical_grad_y = x_val**3
print(f"Analytical gradient w.r.t. y: {analytical_grad_y}")
Compute the gradient with respect to both arguments: Use argnums=(0, 1). This returns a tuple of gradients.
grad_g_wrt_xy = jax.grad(multi_var_func, argnums=(0, 1))
gradient_xy = grad_g_wrt_xy(x_val, y_val)
print(f"\nGradient w.r.t. (x, y) at ({x_val}, {y_val}): {gradient_xy}")
print(f"Analytical gradient w.r.t. (x, y): ({analytical_grad_x}, {analytical_grad_y})")
The results should match the analytical partial derivatives.
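Because argnums=(0, 1) returns a plain tuple, you can unpack it directly. The following sketch reuses gradient_xy and the analytical values from the snippets above; the names dg_dx and dg_dy are just illustrative:
# Unpack the tuple of partial derivatives returned with argnums=(0, 1)
dg_dx, dg_dy = gradient_xy
assert jnp.allclose(dg_dx, analytical_grad_x)
assert jnp.allclose(dg_dy, analytical_grad_y)
print("Both partial derivatives match the analytical results")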
Let's find the second derivative of our original polynomial f(x) = 3x² + 2x + 5. The first derivative is f'(x) = 6x + 2, and the second derivative is f''(x) = 6.
Compute the second derivative: Apply jax.grad twice.
# We already have grad_poly_func = jax.grad(poly_func)
grad_grad_poly_func = jax.grad(grad_poly_func) # Apply grad again
x_value = 4.0 # The point doesn't matter for f''(x) = 6
second_derivative = grad_grad_poly_func(x_value)
print(f"\nSecond derivative of f(x) at x = {x_value}: {second_derivative}")
print(f"Analytical second derivative: 6.0")
The result should be 6.0, independent of the input x_value.
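To see this constancy directly, you can evaluate the second derivative at a few different points, and even apply jax.grad once more: for a quadratic, the third derivative is zero. This sketch reuses grad_grad_poly_func from above; the sample points and the name third_derivative_func are arbitrary choices for illustration:
# Evaluate f''(x) at several points; it should be 6.0 everywhere
for x in [0.0, 1.0, -3.5, 10.0]:
    print(f"f''({x}) = {grad_grad_poly_func(x)}")
# One more application of jax.grad gives the third derivative, f'''(x) = 0
third_derivative_func = jax.grad(grad_grad_poly_func)
print(f"f'''(4.0) = {third_derivative_func(4.0)}")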
jax.value_and_grad
Often in optimization, you need both the function's value (e.g., the loss) and its gradient. jax.value_and_grad computes both simultaneously, which is more efficient than calling the function and its gradient function separately.
Let's use the function h(w) = (w - 5)², a simple quadratic often used to represent a basic loss function where we want to find the w that minimizes h(w) (the minimum is at w = 5). The gradient is h'(w) = 2(w - 5).
Define the function:
def simple_loss(w):
    """A simple quadratic loss function."""
    return (w - 5.0)**2
Create the value-and-gradient function:
value_and_grad_loss = jax.value_and_grad(simple_loss)
Evaluate at a specific point: Let's try w=2.0.
w_value = 2.0
value, gradient = value_and_grad_loss(w_value)
print(f"\nUsing jax.value_and_grad at w = {w_value}:")
print(f" Function value h(w): {value}")
print(f" Gradient h'(w): {gradient}")
# Verify analytically:
# h(2) = (2 - 5)^2 = (-3)^2 = 9
# h'(2) = 2 * (2 - 5) = 2 * (-3) = -6
analytical_value = (w_value - 5.0)**2
analytical_gradient_h = 2 * (w_value - 5.0)
print(f"Analytical value: {analytical_value}")
print(f"Analytical gradient: {analytical_gradient_h}")
You'll get both the value (9.0) and the gradient (-6.0) in a single call.
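To connect this to optimization, here is a minimal gradient descent sketch that reuses value_and_grad_loss from above. The starting point, learning rate, and step count are arbitrary choices for illustration, not part of the exercise itself:
# Simple gradient descent on h(w) = (w - 5)^2 using jax.value_and_grad
w = 0.0             # starting guess, chosen arbitrarily
learning_rate = 0.1
for step in range(50):
    loss, grad_w = value_and_grad_loss(w)
    w = w - learning_rate * grad_w
    if step % 10 == 0:
        print(f"step {step:2d}: w = {float(w):.4f}, loss = {float(loss):.4f}")
print(f"Final w: {float(w):.4f} (the minimum is at w = 5.0)")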
jax.numpy
jax.grad works seamlessly with functions built using jax.numpy. Let's compute the gradient of a function involving jnp.sum and jnp.sin.
Consider k(v) = Σᵢ sin(vᵢ), where v is a vector. The gradient ∇k(v) is a vector whose j-th element is ∂k/∂vⱼ = cos(vⱼ).
Define the function using jnp:
def sum_of_sines(v):
    """Calculates the sum of sines of vector elements."""
    return jnp.sum(jnp.sin(v))
Create the gradient function:
grad_sum_of_sines = jax.grad(sum_of_sines)
Evaluate with a sample vector:
v_vector = jnp.array([0.0, jnp.pi/2, jnp.pi])
gradient_vector = grad_sum_of_sines(v_vector)
print(f"\nGradient of sum_of_sines for v = {v_vector}:")
print(f" Gradient: {gradient_vector}")
# Verify analytically: The gradient should be [cos(0), cos(pi/2), cos(pi)]
analytical_gradient_k = jnp.cos(v_vector)
print(f"Analytical gradient: {analytical_gradient_k}")
The output gradient vector should be approximately [1. 0. -1.], matching jnp.cos(v_vector). (The middle entry will print as a tiny value near zero rather than exactly 0, since π/2 is not exactly representable in floating point.)
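As a final check, the gradient has the same shape as the input vector and agrees elementwise with jnp.cos. This sketch reuses v_vector, gradient_vector, and analytical_gradient_k from the snippets above:
# The gradient matches the input shape and jnp.cos(v_vector) elementwise
print(f"Input shape: {v_vector.shape}, gradient shape: {gradient_vector.shape}")
assert gradient_vector.shape == v_vector.shape
assert jnp.allclose(gradient_vector, analytical_gradient_k)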
These exercises demonstrate the practical application of jax.grad for various scenarios, from simple polynomials to functions involving jax.numpy operations and multiple arguments. Mastering these patterns is fundamental for using JAX in optimization and machine learning tasks.