Okay, let's put theory into practice. You've learned that jax.custom_vjp
allows you to define your own vector-Jacobian product rule for a function. This is useful in several scenarios, for example when the automatically derived gradient is numerically unstable, or when a hand-derived rule is more efficient.
In this section, we'll implement a custom VJP for a common function, focusing on the mechanics and verification.
Consider the Softplus function, defined as:
$$\text{softplus}(x) = \log(1 + e^x)$$

This function is often used as a smooth approximation of the ReLU activation function.
Let's first write a standard JAX implementation:
import jax
import jax.numpy as jnp
import numpy as np # Used for comparison/testing
# Naive implementation - potentially unstable for large x
def softplus_naive(x):
    return jnp.log(1 + jnp.exp(x))
# Example usage
x_val = 10.0
print(f"softplus_naive({x_val}) = {softplus_naive(x_val)}")
# Calculate gradient using JAX's autodiff
grad_softplus_naive = jax.grad(softplus_naive)
print(f"Gradient at {x_val} (JAX auto): {grad_softplus_naive(x_val)}")
For very large positive values of x, $e^x$ can overflow standard floating-point representations (like float32
). While JAX and XLA often handle such cases intelligently, let's assume we want to explicitly provide the gradient rule for pedagogical purposes or perhaps for a scenario where the automatic gradient does cause issues.
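To see the problem concretely, continuing from the snippet above, you can evaluate the naive implementation at a large input. This quick check assumes JAX's default float32 precision; jax.nn.softplus, JAX's built-in numerically stable version, is shown only for comparison:

big_x = 100.0
print(softplus_naive(big_x))            # inf: exp(100) overflows float32
print(jax.grad(softplus_naive)(big_x))  # nan: the chain rule yields 0 * inf
print(jax.nn.softplus(big_x))           # 100.0: the numerically stable built-in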
The analytical derivative of softplus(x) is:

$$\frac{d}{dx}\,\text{softplus}(x) = \frac{e^x}{1 + e^x} = \frac{1}{1 + e^{-x}} = \sigma(x)$$

where $\sigma(x)$ is the sigmoid function. The sigmoid function is numerically stable. We can use this analytical form in our custom VJP.
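Before wiring this into a custom VJP, a quick finite-difference check confirms the analytical form. This is a minimal sketch; the test point and step size are arbitrary choices:

x0, eps = 1.5, 1e-3
numeric = (softplus_naive(x0 + eps) - softplus_naive(x0 - eps)) / (2 * eps)
analytic = jax.nn.sigmoid(x0)
print(numeric, analytic)  # the two values should agree to a few decimal places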
To define a custom VJP with jax.custom_vjp, we need three components:

1. The original function, decorated with @jax.custom_vjp.
2. A forward pass function (suffixed _fwd by convention) that computes the original output and saves any intermediate values (residuals) needed for the backward pass. It must return (output, residuals).
3. A backward pass function (suffixed _bwd by convention) that takes the residuals and the upstream gradient g (representing $dL/dy$, where y is the output of our function) and computes the downstream gradient with respect to the original inputs, $dL/dx = (dL/dy) \times (dy/dx)$. It must return a tuple of gradients, one for each positional argument of the original function.

Let's implement this for softplus:
import jax
import jax.numpy as jnp
import numpy as np
# 1. Define the function and decorate it
@jax.custom_vjp
def softplus_custom(x):
"""Softplus implementation intended for custom VJP."""
# For demonstration, we use the potentially unstable forward pass here.
# In a real scenario, you might use a more stable forward pass too.
return jnp.log(1 + jnp.exp(x))
# 2. Define the forward pass function (_fwd)
# It takes the same arguments as the original function (x).
# It returns the output (y = softplus(x)) and residuals needed for backward pass.
# We need the original input 'x' to compute sigmoid(x) in the backward pass.
def softplus_fwd(x):
    y = softplus_custom(x)  # Compute the primal output using the decorated function
    residuals = x           # Save 'x' for the backward pass
    return y, residuals
# 3. Define the backward pass function (_bwd)
# It takes the residuals saved by _fwd and the upstream gradient 'g'.
# It returns a tuple of gradients w.r.t the inputs of the original function.
# Here, the only input is 'x', so we return a tuple with one element.
def softplus_bwd(residuals, g):
    x = residuals  # Unpack the residuals
    # Compute the gradient: g * sigmoid(x)
    # Use jax.nn.sigmoid for numerical stability
    grad_x = g * jax.nn.sigmoid(x)
    return (grad_x,)  # Return gradient w.r.t. 'x' as a tuple
# Associate the fwd and bwd functions with the custom_vjp function
softplus_custom.defvjp(softplus_fwd, softplus_bwd)
# --- Now let's test it ---
# Calculate gradient using the custom VJP implementation
grad_softplus_custom = jax.grad(softplus_custom)
# Test values
test_values = jnp.array([-10.0, 0.0, 10.0, 80.0, 100.0]) # Include large values
# Compare naive autodiff gradient with custom gradient
print("Input | Naive Grad (Autodiff) | Custom Grad (VJP)")
print("------|-----------------------|-------------------")
for x_val in test_values:
    # The naive version does not raise an exception on overflow; inf/nan simply
    # propagate through the computation and show up in the printed gradient.
    naive_grad = jax.grad(softplus_naive)(x_val)
    custom_grad = grad_softplus_custom(x_val)
    print(f"{x_val:<5.1f} | {naive_grad:<21.8f} | {custom_grad:<19.8f}")
# JIT compilation works as expected
jit_grad_softplus_custom = jax.jit(jax.grad(softplus_custom))
print(f"\nJIT-compiled custom gradient at 10.0: {jit_grad_softplus_custom(10.0):.8f}")
When you run the code above, you should observe that the gradients calculated by jax.grad(softplus_naive) and jax.grad(softplus_custom) are identical (within floating-point precision) for the small and moderate inputs. For the largest input (100.0), $e^x$ overflows under JAX's default float32 precision, so the naive automatic gradient evaluates to nan, while the custom rule still returns 1.0.
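As an additional sanity check, you can compare the custom rule against numerical gradients with jax.test_util.check_grads, continuing from the code above. This is a minimal sketch; the test point is arbitrary, and we restrict the check to reverse mode because a custom_vjp function does not define a forward-mode rule:

from jax.test_util import check_grads

# Compare the custom VJP against finite-difference gradients (reverse mode only).
check_grads(softplus_custom, (jnp.asarray(1.5),), order=1, modes=["rev"])
print("check_grads passed for softplus_custom")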
A few observations about this implementation:

Numerical stability: The custom backward pass uses jax.nn.sigmoid(x), which is numerically stable. While JAX's default autodiff for log(1 + exp(x)) also often yields a stable result for the gradient, defining it explicitly guarantees this behaviour. Note that the forward pass in softplus_custom is still the naive version; a production scenario might implement a stable forward pass and provide a custom VJP if needed.

Mechanics: We defined softplus_fwd to compute the original function's value and cache the input x. The softplus_bwd function then used this cached x along with the incoming gradient g to compute the final gradient g * sigmoid(x). The defvjp call links these two functions to the @jax.custom_vjp decorator.

Residuals: Saving unnecessarily large intermediate arrays increases memory consumption. Save only what is strictly needed for the backward pass. Sometimes, recomputing values in the backward pass can be a worthwhile trade-off for lower memory usage.

Multiple inputs and outputs: The fwd function receives all inputs, and the bwd function must return a tuple of gradients with the same length as the number of positional inputs to the original function. For outputs, the bwd function receives a g that matches the structure of the primal output (or is None for outputs not involved in the gradient path).

jax.custom_jvp: For forward-mode differentiation, the process is analogous using @jax.custom_jvp, defjvp, and defining a function that computes (y, y_dot) given (x, x_dot); a minimal sketch appears at the end of this section.

This practical exercise demonstrates the core mechanics of defining custom gradients in JAX. It's a powerful tool for handling specific numerical or performance requirements that go beyond standard automatic differentiation. Remember to thoroughly test your custom rules against JAX's automatic differentiation or numerical gradients where feasible.
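As referenced in the jax.custom_jvp note above, here is a minimal forward-mode sketch for softplus. The function names are illustrative, and the primal is the same naive implementation used earlier:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def softplus_fwdmode(x):
    return jnp.log(1 + jnp.exp(x))

@softplus_fwdmode.defjvp
def softplus_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    y = softplus_fwdmode(x)
    # Tangent rule: dy = sigmoid(x) * dx
    y_dot = jax.nn.sigmoid(x) * x_dot
    return y, y_dot

# Forward-mode differentiation now uses the custom rule.
print(jax.jacfwd(softplus_fwdmode)(10.0))  # expected: roughly 0.99995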