Often, when performing optimization or analysis, you need not only the gradient of a function but also the function's original output value at the same point. For instance, in training machine learning models, you typically want to log the loss value while simultaneously computing its gradient to update model parameters.
You could achieve this by making two separate calls: one to your original function to get the value, and another to the `jax.grad`-transformed function to get the gradient.
```python
import jax
import jax.numpy as jnp

# Example function
def polynomial(x):
    return x**3 + 2*x**2 - 3*x + 1

# Calculate value
x_val = 2.0
value = polynomial(x_val)

# Calculate gradient separately
grad_fn = jax.grad(polynomial)
gradient = grad_fn(x_val)

print(f"Function value at x={x_val}: {value}")
print(f"Gradient at x={x_val}: {gradient}")
# Expected gradient: 3*x**2 + 4*x - 3 = 3*(2**2) + 4*2 - 3 = 12 + 8 - 3 = 17
```
While this works, it can be inefficient. Recall from the discussion of reverse-mode autodiff that computing a gradient requires a forward pass, which evaluates the function and stores intermediate values, followed by a backward pass. That forward pass already produces the function's output. Calling the function and its gradient function separately therefore performs the forward-pass work twice.
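The shared forward pass can be made visible with `jax.vjp`, JAX's lower-level reverse-mode function: it runs the forward pass once and returns the function's output together with a "pullback" function that performs the backward pass. The sketch below reuses the `polynomial` example to illustrate the idea; it is not meant as a description of how `jax.value_and_grad` is implemented internally.

```python
import jax

def polynomial(x):
    return x**3 + 2*x**2 - 3*x + 1

x_val = 2.0

# Forward pass (runs once): returns the primal output and a pullback
# function that holds the intermediate values needed for the backward pass.
value, pullback = jax.vjp(polynomial, x_val)

# Backward pass: seed with 1.0 (d output / d output) for a scalar output.
# The pullback returns one cotangent per primal input, as a tuple.
(gradient,) = pullback(1.0)

# value == 11.0 and gradient == 17.0, from a single forward pass
print(f"Value: {value}, gradient: {gradient}")
```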
JAX provides a more efficient transformation for this common pattern: `jax.value_and_grad`. It takes your original function and returns a new function that, when called, computes both the original function's value and its gradient, reusing a single forward pass for both.
Here's how you use it:
```python
import jax
import jax.numpy as jnp

# Example function (same as before)
def polynomial(x):
    return x**3 + 2*x**2 - 3*x + 1

# Create a function that returns both value and gradient
value_and_grad_fn = jax.value_and_grad(polynomial)

# Call the new function
x_val = 2.0
value, gradient = value_and_grad_fn(x_val)

print("Using jax.value_and_grad:")
print(f"  Function value: {value}")
print(f"  Gradient: {gradient}")
```
The `value_and_grad_fn` returns a tuple where the first element is the result of `polynomial(x_val)` and the second element is the result of `jax.grad(polynomial)(x_val)`. This avoids redundant computation by sharing the work of the forward pass.
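A quick way to confirm this equivalence is to compare the tuple against the two separate calls at a few points. The sketch below uses `jnp.allclose` because the results are floating-point values; the test points are arbitrary choices.

```python
import jax
import jax.numpy as jnp

def polynomial(x):
    return x**3 + 2*x**2 - 3*x + 1

value_and_grad_fn = jax.value_and_grad(polynomial)
grad_fn = jax.grad(polynomial)

for x in [-1.5, 0.0, 2.0, 3.25]:  # arbitrary test points
    value, gradient = value_and_grad_fn(x)
    # The first element should match calling the function directly...
    assert jnp.allclose(value, polynomial(x))
    # ...and the second should match the separately transformed gradient.
    assert jnp.allclose(gradient, grad_fn(x))

print("value_and_grad matches the separate calls")
```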
Just like `jax.grad`, `jax.value_and_grad` accepts the `argnums` argument to specify which positional argument(s) to differentiate with respect to.

If `argnums` is an integer, the gradient returned will correspond to that single argument.
```python
import jax
import jax.numpy as jnp

def multi_arg_func(a, b):
    return a**2 * jnp.sin(b)

# Differentiate w.r.t. the first argument (index 0)
value_and_grad_a_fn = jax.value_and_grad(multi_arg_func, argnums=0)

a_val, b_val = 3.0, jnp.pi / 2
value, grad_a = value_and_grad_a_fn(a_val, b_val)
# Expected value: a**2 * sin(b) = 9 * 1 = 9
# Expected grad_a: 2*a*sin(b) = 2*3*sin(pi/2) = 6 * 1 = 6

print(f"Value: {value}")
print(f"Gradient w.r.t. 'a': {grad_a}")
```
If `argnums` is a tuple of integers, the gradient returned will be a tuple, with each element corresponding to the gradient with respect to the argument at the specified index.
```python
# Differentiate w.r.t. both arguments (indices 0 and 1),
# reusing multi_arg_func from the previous example
value_and_grad_ab_fn = jax.value_and_grad(multi_arg_func, argnums=(0, 1))

a_val, b_val = 3.0, jnp.pi / 2
value, (grad_a, grad_b) = value_and_grad_ab_fn(a_val, b_val)
# Expected grad_a: 2*a*sin(b) = 6
# Expected grad_b: a**2*cos(b) = 3**2*cos(pi/2) = 9 * 0 = 0
# (in float32, cos(pi/2) is not exactly 0, so grad_b prints as a tiny
#  nonzero number rather than 0.0)

print(f"\nValue: {value}")
print(f"Gradient w.r.t. 'a': {grad_a}")
print(f"Gradient w.r.t. 'b': {grad_b}")
```
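At b = pi/2 the partial derivative with respect to `b` vanishes, which makes the tuple form hard to check by eye. The sketch below evaluates the same `multi_arg_func` at an arbitrarily chosen point, a = 3.0 and b = 0.5, where neither partial is zero, and compares both gradients against the analytic partials 2*a*sin(b) and a**2*cos(b).

```python
import jax
import jax.numpy as jnp

def multi_arg_func(a, b):
    return a**2 * jnp.sin(b)

a_val, b_val = 3.0, 0.5  # arbitrary point where neither partial is zero

# Tuple argnums: the gradient comes back as (d/da, d/db)
value, (grad_a, grad_b) = jax.value_and_grad(multi_arg_func, argnums=(0, 1))(a_val, b_val)

# Analytic partial derivatives of a**2 * sin(b)
expected_grad_a = 2 * a_val * jnp.sin(b_val)  # d/da
expected_grad_b = a_val**2 * jnp.cos(b_val)   # d/db

print(f"Value: {value}")
print(f"grad_a: {grad_a} (expected {expected_grad_a})")
print(f"grad_b: {grad_b} (expected {expected_grad_b})")
```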
Using `jax.value_and_grad` is standard practice when implementing optimization algorithms (like gradient descent), where both the loss value and its gradient are needed at each step. It integrates smoothly with other JAX transformations like `jax.jit`, allowing you to write efficient, differentiable, and compilable code.
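As a sketch of that pattern, the loop below fits a one-parameter linear model `y ≈ w * x` by gradient descent and logs the loss as it goes. The toy data, the learning rate of 0.05, the step count, and the names `loss_fn`, `update`, and `LEARNING_RATE` are illustrative assumptions; only `jax.value_and_grad` and `jax.jit` come from JAX.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data generated from y = 3 * x (true weight is 3.0)
xs = jnp.array([1.0, 2.0, 3.0, 4.0])
ys = 3.0 * xs

LEARNING_RATE = 0.05  # assumed step size for this toy problem

def loss_fn(w, x, y):
    # Mean squared error of the linear model w * x
    return jnp.mean((w * x - y) ** 2)

@jax.jit
def update(w, x, y):
    # One call yields the loss (for logging) and dL/dw (for the update)
    loss, grad_w = jax.value_and_grad(loss_fn)(w, x, y)
    return w - LEARNING_RATE * grad_w, loss

w = jnp.array(0.0)
for step in range(20):
    w, loss = update(w, xs, ys)  # loss is measured before this step's update
    if step % 5 == 0:
        print(f"step {step:2d}  loss {float(loss):.6f}  w {float(w):.4f}")

print(f"final w: {float(w):.4f}")
```

Because the whole step is compiled with `jax.jit`, the forward and backward passes are traced and optimized together, and the loss is returned alongside the updated parameter, so logging it costs no extra forward pass.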
© 2025 ApX Machine Learning