Okay, let's put theory into practice. You've learned that jax.custom_vjp allows you to define your own vector-Jacobian product rule for a function. This is useful in several scenarios, for example when the automatically derived gradient is numerically unstable or unnecessarily expensive. In this section, we'll implement a custom VJP for a common function, focusing on the mechanics and verification.
Consider the Softplus function, defined as:

softplus(x) = log(1 + e^x)

This function is often used as a smooth approximation of the ReLU activation function.
Let's first write a standard JAX implementation:
import jax
import jax.numpy as jnp
import numpy as np # Used for comparison/testing
# Naive implementation - potentially unstable for large x
def softplus_naive(x):
    return jnp.log(1 + jnp.exp(x))
# Example usage
x_val = 10.0
print(f"softplus_naive({x_val}) = {softplus_naive(x_val)}")
# Calculate gradient using JAX's autodiff
grad_softplus_naive = jax.grad(softplus_naive)
print(f"Gradient at {x_val} (JAX auto): {grad_softplus_naive(x_val)}")
For very large positive values of x, e^x can overflow standard floating-point representations (like float32). While JAX and XLA often handle such cases intelligently, let's assume we want to explicitly provide the gradient rule, whether for pedagogical purposes or for a scenario where the automatic gradient does cause issues.
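As a quick aside, if the forward value itself were the concern, a numerically stable implementation is straightforward; the sketch below (softplus_stable is an arbitrary name chosen here) uses jnp.logaddexp and compares it with JAX's built-in jax.nn.softplus:

# log(1 + exp(x)) == logaddexp(x, 0), computed without overflowing
def softplus_stable(x):
    return jnp.logaddexp(x, 0.0)

print(softplus_stable(100.0))    # ~100.0
print(jax.nn.softplus(100.0))    # JAX's built-in softplus, also stable
print(softplus_naive(100.0))     # inf in float32: exp(100) overflows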
The analytical derivative of softplus(x) is:

d/dx softplus(x) = e^x / (1 + e^x) = 1 / (1 + e^(-x)) = sigmoid(x)

The sigmoid function is numerically stable. We can use this analytical form in our custom VJP.
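As a quick sanity check, we can confirm that JAX's automatic gradient of softplus_naive matches sigmoid(x) at a moderate input:

# The autodiff gradient and the analytical sigmoid should agree
x_check = 2.0
print(jnp.allclose(jax.grad(softplus_naive)(x_check), jax.nn.sigmoid(x_check)))  # True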
To define a custom VJP with jax.custom_vjp, we need three components:

1. The original function, decorated with @jax.custom_vjp.
2. A forward pass function (conventionally suffixed _fwd) that computes the original output and saves any intermediate values (residuals) needed for the backward pass. It must return (output, residuals).
3. A backward pass function (suffixed _bwd) that takes the residuals and the upstream gradient g (representing dL/dy, where y is the output of our function) and computes the downstream gradient with respect to the original inputs, dL/dx = (dL/dy) × (dy/dx). It must return a tuple of gradients, one for each positional argument of the original function.

Let's implement this for softplus:
import jax
import jax.numpy as jnp
import numpy as np

# 1. Define the function and decorate it
@jax.custom_vjp
def softplus_custom(x):
    """Softplus implementation intended for custom VJP."""
    # For demonstration, we use the potentially unstable forward pass here.
    # In a real scenario, you might use a more stable forward pass too.
    return jnp.log(1 + jnp.exp(x))

# 2. Define the forward pass function (_fwd)
# It takes the same arguments as the original function (x).
# It returns the output (y = softplus(x)) and residuals needed for the backward pass.
# We need the original input 'x' to compute sigmoid(x) in the backward pass.
def softplus_fwd(x):
    y = softplus_custom(x)  # Compute the primal output using the decorated function
    residuals = x           # Save 'x' for the backward pass
    return y, residuals

# 3. Define the backward pass function (_bwd)
# It takes the residuals saved by _fwd and the upstream gradient 'g'.
# It returns a tuple of gradients w.r.t. the inputs of the original function.
# Here, the only input is 'x', so we return a tuple with one element.
def softplus_bwd(residuals, g):
    x = residuals  # Unpack the residuals
    # Compute the gradient: g * sigmoid(x)
    # Use jax.nn.sigmoid for numerical stability
    grad_x = g * jax.nn.sigmoid(x)
    return (grad_x,)  # Return gradient w.r.t. 'x' as a tuple

# Associate the fwd and bwd functions with the custom_vjp function
softplus_custom.defvjp(softplus_fwd, softplus_bwd)
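# As a quick aside, we can exercise the new rule directly with jax.vjp
# before using jax.grad: jax.vjp returns the primal output together with
# a function that applies softplus_bwd to an upstream cotangent.
y_demo, vjp_fn = jax.vjp(softplus_custom, 2.0)
print(y_demo)                          # softplus(2.0) ~ 2.1269
print(vjp_fn(jnp.ones_like(y_demo)))   # (sigmoid(2.0),) ~ (0.8808,), returned as a tuple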
# --- Now let's test it ---

# Calculate gradients using the custom VJP implementation and the naive autodiff
grad_softplus_custom = jax.grad(softplus_custom)
grad_softplus_naive = jax.grad(softplus_naive)

# Test values, including large ones where exp(x) overflows in float32
test_values = jnp.array([-10.0, 0.0, 10.0, 80.0, 100.0])

# Compare naive autodiff gradient with custom gradient
print("Input | Naive Grad (Autodiff) | Custom Grad (VJP)")
print("------|-----------------------|-------------------")
for x_val in test_values:
    # The naive version does not raise an exception on overflow; instead,
    # exp(x) becomes inf and the resulting gradient becomes nan.
    naive_grad = grad_softplus_naive(x_val)
    custom_grad = grad_softplus_custom(x_val)
    print(f"{float(x_val):<5.1f} | {float(naive_grad):<21.8f} | {float(custom_grad):<19.8f}")

# JIT compilation works as expected
jit_grad_softplus_custom = jax.jit(jax.grad(softplus_custom))
print(f"\nJIT-compiled custom gradient at 10.0: {jit_grad_softplus_custom(10.0):.8f}")
When you run the code above, you should observe that the gradients calculated by jax.grad(softplus_naive) and jax.grad(softplus_custom) are identical (within floating-point precision) wherever exp(x) does not overflow. Once it does (for example at x = 100.0 in float32), the naive autodiff gradient degenerates to nan, while the custom VJP still returns the correct value of approximately 1.0.
A few points are worth highlighting:

- Numerical stability: The backward pass uses jax.nn.sigmoid(x), which is numerically stable. While JAX's default autodiff for log(1 + exp(x)) also often yields a stable result for the gradient, defining it explicitly guarantees this behaviour. Note that the forward pass in softplus_custom is still the naive version; a production scenario might implement a stable forward pass and provide a custom VJP if needed.
- Mechanics: We used softplus_fwd to compute the original function's value and cache the input x. The softplus_bwd function then used this cached x along with the incoming gradient g to compute the final gradient g * sigmoid(x). The defvjp call links these two functions to the @jax.custom_vjp decorator.
- Residuals: Be deliberate about what you save as residuals. Saving unnecessary large intermediate arrays increases memory consumption. Save only what is strictly needed for the backward pass. Sometimes, recomputing values in the backward pass can be a trade-off for lower memory usage.
- Multiple inputs and outputs: The fwd function receives all inputs, and the bwd function must return a tuple of gradients with the same length as the number of positional inputs to the original function. For outputs, the bwd function receives a g that matches the structure of the primal output (or is None for outputs not involved in the gradient path). The first sketch after this list illustrates this, together with the residuals point, for a two-input function.
- jax.custom_jvp: For forward-mode differentiation, the process is analogous using @jax.custom_jvp, defjvp, and defining a function that computes (y, y_dot) given (x, x_dot); the second sketch below shows this for softplus.
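The following sketch is purely illustrative (scaled_softplus and its helper names are invented for this example): a two-input function whose forward pass saves only the values the backward pass actually needs, and whose backward pass returns one gradient per positional input.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def scaled_softplus(x, w):
    return w * jnp.logaddexp(x, 0.0)   # w * softplus(x), with a stable forward pass

def scaled_softplus_fwd(x, w):
    sp = jnp.logaddexp(x, 0.0)
    # Save only what _bwd needs: sigmoid(x), softplus(x), and w.
    # Alternatively, save just (x, w) and recompute these values in _bwd,
    # trading extra compute for lower memory use.
    return w * sp, (jax.nn.sigmoid(x), sp, w)

def scaled_softplus_bwd(residuals, g):
    sig_x, sp, w = residuals
    # One gradient per positional input, in order: (d/dx, d/dw)
    return (g * w * sig_x, g * sp)

scaled_softplus.defvjp(scaled_softplus_fwd, scaled_softplus_bwd)

print(jax.grad(scaled_softplus, argnums=(0, 1))(2.0, 3.0))
# (3.0 * sigmoid(2.0), softplus(2.0)) ~ (2.6424, 2.1269)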
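And a minimal custom_jvp sketch for softplus (again, the function names here are invented for illustration):

import jax
import jax.numpy as jnp

@jax.custom_jvp
def softplus_fwdmode(x):
    return jnp.logaddexp(x, 0.0)

@softplus_fwdmode.defjvp
def softplus_fwdmode_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = softplus_fwdmode(x)             # primal output
    y_dot = jax.nn.sigmoid(x) * x_dot   # dy = sigmoid(x) * dx
    return y, y_dot

# jax.jvp evaluates (y, y_dot) at the given primal and tangent values
print(jax.jvp(softplus_fwdmode, (2.0,), (1.0,)))   # (~2.1269, ~0.8808)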
This practical exercise demonstrates the core mechanics of defining custom gradients in JAX. It's a powerful tool for handling specific numerical or performance requirements. Remember to thoroughly test your custom rules against JAX's automatic differentiation or numerical gradients where feasible; one convenient option is sketched below.
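For example, jax.test_util.check_grads compares a function's gradients against numerical differencing. For a custom_vjp function, restrict it to reverse mode, since custom_vjp functions do not support forward-mode differentiation:

from jax.test_util import check_grads

# Raises an informative error if the custom rule disagrees with numerical
# gradients; modes=['rev'] because custom_vjp rules cannot be forward-differentiated.
check_grads(softplus_custom, (2.0,), order=1, modes=['rev'])
print("check_grads passed")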