While `jax.grad` and the underlying `jax.vjp` mechanism handle automatic differentiation for a broad range of standard JAX operations, there are situations where you need more control over how gradients are computed. JAX provides the `jax.custom_vjp` decorator to let you define a bespoke vector-Jacobian product (VJP) rule for your Python function.

Why would you need a custom VJP? Here are a few common scenarios:

- **Numerical Stability:** The standard sequence of operations in your function might lead to numerically unstable gradients when automatically differentiated, especially near boundary conditions or for large input values. You might know an alternative, mathematically equivalent formulation of the gradient that is more stable.
- **Performance Optimization:** The automatically derived VJP might involve more computation or memory access than necessary. If you know a more efficient analytical gradient, implementing it directly can speed up the backward pass.
- **Mathematical Simplification:** Sometimes the analytical gradient is significantly simpler than what automatic differentiation would derive, involving cancellations or identities that JAX might not recognize automatically.
- **Implicit Functions:** For functions defined implicitly (e.g., through an optimization problem or a solver), you might know the gradient via the implicit function theorem, even if differentiating the solver steps directly is difficult or inefficient.
- **Approximations:** You might want to use an approximate gradient for specific purposes, such as in certain optimization algorithms or for regularizing behavior (though `jax.lax.stop_gradient` is often the right tool for simply blocking gradients).

## The jax.custom_vjp Decorator

The `@jax.custom_vjp` decorator allows you to associate a function with custom forward and backward pass implementations for reverse-mode automatic differentiation.

Let's break down how it works. Suppose you have a function `f` that you want to customize:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
    # Original implementation
    result = x * jnp.sin(y)
    return result
```

To make this work, you need to define two auxiliary functions:

- `f_fwd(x, y)`: Defines the forward pass. It computes the original function's output (`primals_out`) and any intermediate values (`residuals`) needed for the backward pass. It must return `(primals_out, residuals)`.
- `f_bwd(residuals, cotangents_in)`: Defines the backward pass (the VJP). It takes the residuals saved by `f_fwd` and the incoming cotangent vector (`cotangents_in`, which represents the gradient of the final loss with respect to the output of `f`). It must return a tuple containing the gradients with respect to the primal inputs of `f` (in this case, `(dL/dx, dL/dy)`). The number of elements in the returned tuple must match the number of primal inputs to `f`.

You then associate these functions with the original `f` using `f.defvjp(f_fwd, f_bwd)`.

Here's the complete structure:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
    # This original implementation is often simple,
    # representing the mathematical function.
    # It might not even be called directly by JAX's AD system
    # if f_fwd computes the result efficiently.
    return x * jnp.sin(y)

# Forward pass: computes the result and saves the values needed later
def f_fwd(x, y):
    # Calculate the primal output
    primals_out = f(x, y)  # Or recalculate: x * jnp.sin(y)
    # Save the values needed for the backward pass
    # (sin(y) and cos(y) could also be recomputed in f_bwd instead)
    residuals = (x, y, jnp.sin(y), jnp.cos(y))
    return primals_out, residuals

# Backward pass: computes gradients w.r.t. the inputs
def f_bwd(residuals, cotangents_in):
    # Unpack residuals
    x, y, sin_y, cos_y = residuals
    # g represents the incoming gradient dL/d(f(x, y))
    g = cotangents_in
    # Gradient w.r.t. x: dL/dx = dL/df * df/dx, with df/dx = sin(y)
    grad_x = g * sin_y
    # Gradient w.r.t. y: dL/dy = dL/df * df/dy, with df/dy = x * cos(y)
    grad_y = g * x * cos_y
    # Return gradients corresponding to the inputs (x, y)
    return (grad_x, grad_y)

# Register the forward and backward functions with f
f.defvjp(f_fwd, f_bwd)

# Example usage:
x_val = jnp.float32(2.0)
y_val = jnp.pi / 2.0

# Gradient of a small loss built on top of f
value_and_grad_fn = jax.value_and_grad(lambda x, y: f(x, y)**2, argnums=(0, 1))
value, grads = value_and_grad_fn(x_val, y_val)
print(f"Function value: {value}")
print(f"Gradients (dL/dx, dL/dy): {grads}")

# Verify against automatic differentiation of the plain implementation
value_and_grad_fn_auto = jax.value_and_grad(lambda x, y: (x * jnp.sin(y))**2, argnums=(0, 1))
value_auto, grads_auto = value_and_grad_fn_auto(x_val, y_val)
print(f"Auto Function value: {value_auto}")
print(f"Auto Gradients (dL/dx, dL/dy): {grads_auto}")
```
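A custom rule can also be checked against finite differences rather than only against another autodiff formulation. The following is a minimal sketch, assuming the `f`, `f_fwd`, and `f_bwd` definitions above; it uses JAX's `jax.test_util.check_grads` utility, restricted to reverse mode because a `jax.custom_vjp` function does not define a forward-mode (JVP) rule:

```python
from jax.test_util import check_grads

# Compares the reverse-mode gradient of f (i.e., our f_bwd) against
# numerical finite differences; raises an assertion error on mismatch.
check_grads(f, (2.0, 0.7), order=1, modes=["rev"])
```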
When `jax.grad` (or `jax.vjp`) encounters the call to `f`, it doesn't differentiate the original Python code of `f`. Instead, it executes `f_fwd` during the forward pass, storing the residuals. During the backward pass, it retrieves the residuals and calls `f_bwd` with the incoming `cotangents_in` to obtain the gradients with respect to `x` and `y`.
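Because `f_fwd` and `f_bwd` are themselves ordinary traceable JAX functions, the customized `f` also composes with other transformations. A small sketch, assuming the definitions above:

```python
# Batched, compiled gradient of f with respect to x; each element should be sin(0.3).
batched_grad_x = jax.jit(jax.vmap(jax.grad(f, argnums=0), in_axes=(0, None)))

xs = jnp.linspace(0.0, 1.0, 4)
print(batched_grad_x(xs, 0.3))  # approximately [0.2955, 0.2955, 0.2955, 0.2955]
```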
## Example: Numerically Stable log(1 + exp(x))

A classic example where custom VJPs are useful is the function $f(x) = \log(1 + e^x)$, often called softplus. For large positive $x$, $e^x$ can overflow standard floating-point representations. The automatically derived gradient computes $e^x / (1 + e^x)$, so for large $x$ both numerator and denominator overflow and the result is `nan`, even though the true derivative simply approaches 1. The same quantity can be rewritten as $1 / (1 + e^{-x})$, the sigmoid function $\sigma(x)$, which is numerically stable for large positive $x$.

Let's implement this using `jax.custom_vjp`:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def stable_log1pexp(x):
    # Primal function implementation
    return jnp.log1p(jnp.exp(x))

def stable_log1pexp_fwd(x):
    # Forward pass: calculate the output and save x for the backward pass
    output = jnp.log1p(jnp.exp(x))
    # Only x is needed to compute the stable gradient, sigmoid(x)
    residuals = (x,)
    return output, residuals

def stable_log1pexp_bwd(residuals, cotangent_in):
    # Backward pass: compute the gradient using the stable sigmoid formula
    (x,) = residuals
    # d/dx log(1 + exp(x)) = sigmoid(x) = 1 / (1 + exp(-x))
    grad_x = cotangent_in * (1. / (1. + jnp.exp(-x)))
    # Return the gradient corresponding to the input x (as a tuple)
    return (grad_x,)

# Register the custom VJP rules
stable_log1pexp.defvjp(stable_log1pexp_fwd, stable_log1pexp_bwd)

# --- Comparison ---
# The same function written with standard JAX operations
def unstable_log1pexp(x):
    return jnp.log(1.0 + jnp.exp(x))

# Gradient functions
grad_stable = jax.grad(stable_log1pexp)
grad_unstable = jax.grad(unstable_log1pexp)

# Test values
x_small = jnp.float32(1.0)
x_problematic = jnp.float32(35.0)   # exp(35) fits in float32, but 1 + exp(x) loses precision
x_large = jnp.float32(100.0)        # exp(100) overflows float32

print(f"--- Input: {x_small} ---")
print(f"Stable grad:   {grad_stable(x_small)}")
print(f"Unstable grad: {grad_unstable(x_small)}")

print(f"\n--- Input: {x_problematic} ---")
print(f"Stable grad:   {grad_stable(x_problematic)}")
print(f"Unstable grad: {grad_unstable(x_problematic)}")  # May be less accurate

print(f"\n--- Input: {x_large} ---")
print(f"Stable grad:   {grad_stable(x_large)}")    # Correctly approaches 1.0
print(f"Unstable grad: {grad_unstable(x_large)}")  # nan: exp(100) overflows, giving inf / inf

# JAX has a built-in numerically stable version:
grad_jax_builtin = jax.grad(jax.nn.softplus)
print(f"\n--- JAX built-in softplus grad ({x_large}) ---")
print(f"Built-in grad: {grad_jax_builtin(x_large)}")
```

In this example, `stable_log1pexp_fwd` calculates the result and saves the input `x`. The `stable_log1pexp_bwd` function uses the saved `x` and the incoming `cotangent_in` (which is 1.0 when called directly by `jax.grad`) to compute the gradient with the numerically stable sigmoid formula. The comparison shows that for large inputs the custom VJP stays stable, while the automatically derived gradient suffers from overflow and precision issues.

## Custom VJP Implementation Considerations

- **Correctness:** The most important requirement is that your custom VJP is mathematically correct: the gradient computed by `f_bwd` must match the true gradient of the original function `f`. Thoroughly test your implementation against numerical differentiation (e.g., with `jax.test_util.check_grads`, as sketched earlier) or known results.
- **Residuals:** Only save the minimum necessary data in `residuals`. Saving large intermediate arrays increases memory consumption; sometimes it is better to recompute cheap values inside `f_bwd` than to store them.
- **Composition:** Functions using `jax.custom_vjp` can generally be composed with other JAX transformations such as `jit`, `vmap`, and even nested differentiation (as in the `jit`/`vmap` sketch earlier), provided that `f_fwd` and `f_bwd` are themselves JAX-traceable and follow standard JAX function rules.
- **Non-Differentiable Inputs:** If `f` takes arguments that should not be differentiated with respect to (e.g., static parameters or non-floating-point types), `f_bwd` should return `None` in the corresponding positions of the output tuple; a sketch of the related `nondiff_argnums` mechanism follows after this list.
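Relatedly, `jax.custom_vjp` also accepts a `nondiff_argnums` parameter for arguments that should be treated as static, non-differentiable values; they are then passed to the backward rule directly instead of receiving a cotangent. A minimal sketch using a hypothetical `scaled_sin` function, separate from the examples above:

```python
from functools import partial

import jax
import jax.numpy as jnp

# `scale` is a static, hashable Python value we never differentiate with respect to.
@partial(jax.custom_vjp, nondiff_argnums=(0,))
def scaled_sin(scale, x):
    return scale * jnp.sin(x)

def scaled_sin_fwd(scale, x):
    return scaled_sin(scale, x), (x,)

# Non-differentiable arguments are passed to the backward rule first,
# before the residuals and the incoming cotangent.
def scaled_sin_bwd(scale, residuals, g):
    (x,) = residuals
    # Only one cotangent is returned: the one for x.
    return (g * scale * jnp.cos(x),)

scaled_sin.defvjp(scaled_sin_fwd, scaled_sin_bwd)

print(jax.grad(scaled_sin, argnums=1)(3.0, 0.5))  # 3 * cos(0.5)
```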
`jax.custom_vjp` is a powerful tool for gaining fine-grained control over reverse-mode differentiation in JAX. While not needed for everyday use, it becomes invaluable when dealing with numerical stability, performance bottlenecks, or specialized gradient calculations in advanced models and algorithms. Remember to also consider `jax.custom_jvp` if you need control over forward-mode differentiation, or if the JVP rule is simpler to define for your function.
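For comparison, here is a minimal sketch of the softplus example expressed with `jax.custom_jvp`; reverse-mode gradients are then derived automatically from the custom JVP rule:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    primal_out = log1pexp(x)
    # Stable derivative: sigmoid(x) times the incoming tangent
    tangent_out = x_dot / (1. + jnp.exp(-x))
    return primal_out, tangent_out

print(jax.grad(log1pexp)(100.0))  # 1.0, stable in both forward and reverse mode
```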