While jax.grad and the underlying jax.vjp mechanism handle automatic differentiation for a vast range of standard JAX operations, there are situations where you need more control over how gradients are computed. JAX provides the jax.custom_vjp decorator to let you define a bespoke vector-Jacobian product (VJP) rule for your Python function.
Why would you need a custom VJP? Here are a few common scenarios:

- Numerical stability: the gradient that automatic differentiation derives from your implementation may overflow or lose precision even when a mathematically equivalent, stable formula exists (the log(1 + exp(x)) example below illustrates this).
- Performance: a hand-derived gradient can sometimes be computed more cheaply than the one produced by differentiating the forward computation step by step.
- Specialized gradient behavior: you may want gradients that deviate from the exact mathematical derivative of your implementation (jax.lax.stop_gradient is often used for simply blocking gradients).

The jax.custom_vjp Decorator

The @jax.custom_vjp decorator allows you to associate a function with custom forward and backward pass implementations for reverse-mode automatic differentiation.
Let's break down how it works. Suppose you have a function f that you want to customize:
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
    # Original implementation
    result = x * jnp.sin(y)
    return result
To make this work, you need to define two auxiliary functions:

1. f_fwd(x, y): This function defines the forward pass. It computes the original function's output (primals_out) and any intermediate values (residuals) needed for the backward pass. It must return (primals_out, residuals).
2. f_bwd(residuals, cotangents_in): This function defines the backward pass (the VJP). It takes the residuals saved by f_fwd and the incoming cotangent vector (cotangents_in, which represents the gradient of the final loss with respect to the output of f). It must return a tuple containing the gradients with respect to the primal inputs of f (in this case, (dL/dx, dL/dy)). The number of elements in the returned tuple must match the number of primal inputs to f.

You then associate these functions with the original f using f.defvjp(f_fwd, f_bwd).
Here's the complete structure:
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
    # This original implementation is often simple, representing the
    # mathematical function. When f is differentiated, JAX uses f_fwd and
    # f_bwd rather than differentiating this body.
    return x * jnp.sin(y)

# Forward pass: computes the result and saves necessary values
def f_fwd(x, y):
    # Calculate the primal output
    primals_out = f(x, y)  # Or recalculate: x * jnp.sin(y)
    # Save values needed for the backward pass:
    # we need x, sin(y), and cos(y) to compute the gradients
    residuals = (x, y, jnp.sin(y), jnp.cos(y))  # cos(y) could also be computed in f_bwd
    return primals_out, residuals

# Backward pass: computes gradients w.r.t. inputs
def f_bwd(residuals, cotangents_in):
    # Unpack residuals
    x, y, sin_y, cos_y = residuals
    # g represents the incoming gradient dL/d(f(x,y))
    g = cotangents_in
    # Gradient w.r.t. x: dL/dx = dL/df * df/dx, where df/dx = sin(y)
    grad_x = g * sin_y
    # Gradient w.r.t. y: dL/dy = dL/df * df/dy, where df/dy = x * cos(y)
    grad_y = g * x * cos_y
    # Return gradients corresponding to inputs (x, y)
    return (grad_x, grad_y)

# Register the forward and backward functions with f
f.defvjp(f_fwd, f_bwd)

# Example usage:
x_val = jnp.float32(2.0)
y_val = jnp.float32(jnp.pi / 2.0)

# Calculate the value and gradient of a loss built from f
value_and_grad_fn = jax.value_and_grad(lambda x, y: f(x, y)**2, argnums=(0, 1))
value, grads = value_and_grad_fn(x_val, y_val)
print(f"Function value: {value}")
print(f"Gradients (dL/dx, dL/dy): {grads}")

# Verify against automatic differentiation of the original implementation
value_and_grad_fn_auto = jax.value_and_grad(lambda x, y: (x * jnp.sin(y))**2, argnums=(0, 1))
value_auto, grads_auto = value_and_grad_fn_auto(x_val, y_val)
print(f"Auto Function value: {value_auto}")
print(f"Auto Gradients (dL/dx, dL/dy): {grads_auto}")
When jax.grad (or jax.vjp) encounters the call to f, it doesn't differentiate the original Python code for f. Instead, it executes f_fwd during the forward pass, storing the residuals. During the backward pass, it retrieves the residuals and calls f_bwd with the incoming cotangents_in to get the gradients with respect to x and y.
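You can see this machinery directly by calling jax.vjp yourself. The minimal sketch below reuses f, x_val, and y_val from the example above; obtaining the pullback runs f_fwd, and applying it to a cotangent of 1.0 invokes f_bwd:

# jax.vjp runs f_fwd and returns the primal output plus a pullback function
primal_out, f_vjp = jax.vjp(f, x_val, y_val)
cotangent = jnp.float32(1.0)       # stand-in for dL/df
grad_x, grad_y = f_vjp(cotangent)  # internally calls f_bwd(residuals, cotangent)
print(primal_out, grad_x, grad_y)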
Example: A Numerically Stable log(1 + exp(x))
A classic example where custom VJPs are useful is the function f(x) = log(1 + e^x), often called softplus. For large positive x, e^x can overflow standard floating-point representations, and the gradient derived by automatic differentiation involves e^x / (1 + e^x), which becomes nan once e^x overflows to infinity, even though the true gradient approaches 1. Even when e^x does not overflow, the intermediate 1 + e^x can lose precision. The gradient can instead be rewritten as 1 / (1 + e^(-x)), which is the sigmoid function σ(x); this form avoids computing e^x for positive x and is numerically stable.
Let's implement this using jax.custom_vjp:
import jax
import jax.numpy as jnp

@jax.custom_vjp
def stable_log1pexp(x):
    # Primal function implementation
    return jnp.log1p(jnp.exp(x))

def stable_log1pexp_fwd(x):
    # Forward pass: calculate output and save x for the backward pass
    exp_x = jnp.exp(x)
    output = jnp.log1p(exp_x)
    # We only need x to compute the stable gradient (sigmoid(x))
    residuals = (x,)
    return output, residuals

def stable_log1pexp_bwd(residuals, cotangent_in):
    # Backward pass: compute the gradient using the stable formula
    (x,) = residuals
    # d/dx log(1 + exp(x)) = sigmoid(x) = 1 / (1 + exp(-x))
    grad_x = cotangent_in * (1. / (1. + jnp.exp(-x)))
    # Return the gradient corresponding to the input x (as a tuple)
    return (grad_x,)

# Register the custom VJP rules
stable_log1pexp.defvjp(stable_log1pexp_fwd, stable_log1pexp_bwd)

# --- Comparison ---
# Function using standard JAX operations, differentiated automatically
def unstable_log1pexp(x):
    return jnp.log(1.0 + jnp.exp(x))

# Gradient functions
grad_stable = jax.grad(stable_log1pexp)
grad_unstable = jax.grad(unstable_log1pexp)

# Test values
x_small = jnp.float32(1.0)
x_large = jnp.float32(100.0)       # exp(100) overflows float32
x_problematic = jnp.float32(35.0)  # exp(35) fits in float32, but 1 + exp(x) loses precision

print(f"--- Input: {x_small} ---")
print(f"Stable grad: {grad_stable(x_small)}")
print(f"Unstable grad: {grad_unstable(x_small)}")

print(f"\n--- Input: {x_problematic} ---")
print(f"Stable grad: {grad_stable(x_problematic)}")
print(f"Unstable grad: {grad_unstable(x_problematic)}")  # May be less accurate

print(f"\n--- Input: {x_large} ---")
print(f"Stable grad: {grad_stable(x_large)}")
# For very large inputs, exp(x) overflows to inf and the automatically derived
# gradient evaluates to nan; the stable version correctly returns a value
# approaching 1.
print(f"Unstable grad: {grad_unstable(x_large)}")  # Expect nan

# JAX also has a built-in numerically stable version:
grad_jax_builtin = jax.grad(jax.nn.softplus)
print(f"\n--- JAX built-in softplus grad ({x_large}) ---")
print(f"Built-in grad: {grad_jax_builtin(x_large)}")
In this example, stable_log1pexp_fwd calculates the result and saves the input x. The stable_log1pexp_bwd function uses the saved x and the incoming cotangent_in (which is 1.0 when called directly by jax.grad) to compute the gradient using the numerically stable sigmoid formula. The comparison shows that for large inputs, the custom VJP remains stable while the automatically derived gradient can suffer from overflow or precision issues.
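Before relying on a custom rule like this, it is worth checking it against numerical differentiation and confirming it composes with other transformations. A minimal sketch, assuming the stable_log1pexp definition above (modes=['rev'] is needed because a custom_vjp only defines reverse-mode behavior):

from jax.test_util import check_grads

# Compare the custom VJP against numerical differentiation (reverse mode only)
check_grads(stable_log1pexp, (jnp.float32(1.0),), order=1, modes=['rev'])

# The custom rule also composes with jit and vmap
batched_grad = jax.jit(jax.vmap(jax.grad(stable_log1pexp)))
print(batched_grad(jnp.array([1.0, 35.0, 100.0], dtype=jnp.float32)))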
Practical Considerations

- Correctness: the gradients returned by f_bwd must match the true gradient of the original function f. Thoroughly test your implementation against numerical differentiation or known results (for example with jax.test_util.check_grads, as sketched above).
- Memory: be mindful of what you save in residuals. Saving large intermediate arrays can increase memory consumption. Sometimes it's better to recompute values in f_bwd if they are cheap to compute and saving them would consume significant memory.
- Composability: jax.custom_vjp can generally be composed with other JAX transformations like jit and vmap, provided the f_fwd and f_bwd functions themselves are JAX-traceable and follow standard JAX function rules. Note, however, that the custom rule only defines reverse-mode behavior; applying forward-mode autodiff (jax.jvp) to the function raises an error.
- Non-differentiable arguments: if f takes arguments that should not be differentiated with respect to (e.g., static parameters or non-floating-point types), you can declare them with the nondiff_argnums argument of jax.custom_vjp, or have f_bwd return None in the corresponding positions of the output tuple (see the sketch just after this list).
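As an illustration of the non-differentiable-arguments point, here is a minimal, hypothetical sketch using nondiff_argnums; the function scaled_power and its rules are invented for this example. Arguments declared in nondiff_argnums are passed to the backward rule ahead of the residuals, and the backward rule returns gradients only for the remaining inputs:

from functools import partial
import jax
import jax.numpy as jnp

# `power` is a static, non-differentiable configuration parameter,
# declared via nondiff_argnums (argument index 0).
@partial(jax.custom_vjp, nondiff_argnums=(0,))
def scaled_power(power, x):
    return jnp.sum(x ** power)

def scaled_power_fwd(power, x):
    # Save x for the backward pass; `power` is handled as a static argument
    return jnp.sum(x ** power), (x,)

def scaled_power_bwd(power, residuals, g):
    # Non-differentiable arguments arrive first, before the residuals
    (x,) = residuals
    # Return gradients only for the differentiable input x
    return (g * power * x ** (power - 1),)

scaled_power.defvjp(scaled_power_fwd, scaled_power_bwd)

# Differentiate with respect to x (argument index 1)
print(jax.grad(scaled_power, argnums=1)(2, jnp.arange(3.0)))  # expect [0. 2. 4.]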
jax.custom_vjp is a powerful tool for gaining fine-grained control over reverse-mode differentiation in JAX. While not needed for everyday use, it becomes invaluable when dealing with numerical stability, performance bottlenecks, or specialized gradient calculations in advanced models and algorithms. Remember to also consider jax.custom_jvp if you need control over forward-mode differentiation or if the JVP rule is simpler to define for your function.