While jax.grad and jax.vjp provide the foundation for reverse-mode automatic differentiation, essential for training most neural networks, there are scenarios where direct control over forward-mode differentiation (Jacobian-vector products) is needed. Forward mode calculates how a change in the input, represented by a "tangent" vector, propagates forward through the function to affect the output. The jax.jvp function computes this directly.
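For instance, here is a quick sketch of jax.jvp on a simple cubic (the function is just an illustrative stand-in):

import jax

# Push the tangent dx = 1.0 through f at x = 2.0
f = lambda x: x ** 3
primal_out, tangent_out = jax.jvp(f, (2.0,), (1.0,))
print(primal_out)   # 8.0
print(tangent_out)  # 12.0, since f'(2) = 3 * 2**2 = 12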
Sometimes, JAX's automatic derivation of the JVP rule might not be optimal or even possible. This can happen for several reasons:

- Numerical stability: the automatically derived derivative can blow up or produce NaN at certain inputs (for example, the derivatives of sqrt(x) or log(x) near zero).
- Non-differentiable points: you may want to prescribe a particular derivative for functions with kinks or clamped regions (for example, jnp.clip).
- Opaque computations: if your function calls out to external code (for example, via jax.pure_callback), JAX has no way to see inside it to compute the derivative.

In these situations, you can define your own JVP rule using the jax.custom_jvp decorator. This allows you to tell JAX precisely how to compute the forward-mode derivative for your function.
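To see the first problem concretely for sqrt, JAX's built-in derivative rule divides by sqrt(x), so differentiating at zero produces a non-finite tangent (a quick check):

import jax
import jax.numpy as jnp

# The built-in JVP for sqrt divides by sqrt(x), so the tangent at x = 0 blows up
_, tangent = jax.jvp(jnp.sqrt, (0.0,), (1.0,))
print(tangent)  # inf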
To define a custom JVP, you decorate your function with @jax.custom_jvp and then define a separate JVP rule function associated with it using the @[your_function_name].defjvp decorator.

Let's illustrate with a simple example: implementing a numerically stable version of jnp.sqrt with a custom JVP rule that handles the derivative at x=0.
import jax
import jax.numpy as jnp

@jax.custom_jvp
def stable_sqrt(x):
    """Computes sqrt(x); the custom JVP rule below supplies a finite derivative at x = 0."""
    # The primal value sqrt(0) = 0 is fine as-is; only the derivative needs
    # special handling, and that is done in the custom JVP rule below.
    return jnp.sqrt(x)

# Define the custom JVP rule for stable_sqrt
@stable_sqrt.defjvp
def stable_sqrt_jvp(primals, tangents):
    """Custom JVP rule for stable_sqrt."""
    x, = primals
    t, = tangents  # t is the tangent vector dx

    # Primal computation (the forward pass result)
    primal_out = stable_sqrt(x)

    # Tangent computation (the derivative part).
    # The derivative of sqrt(x) is 0.5 * x**(-0.5), which is infinite at x = 0.
    # Handle x = 0 explicitly to avoid NaN/inf; here we define the derivative
    # to be 0 there (a large finite value would be another option).
    is_zero = (x == 0.0)
    safe_x = jnp.where(is_zero, 1.0, x)  # Avoid division by zero
    deriv = 0.5 * jax.lax.rsqrt(safe_x)
    tangent_out = jnp.where(is_zero, 0.0, deriv * t)  # Derivative set to 0 at x = 0
    return primal_out, tangent_out
# Example usage
x_val = jnp.array(0.0)
t_val = jnp.array(1.0)  # Tangent vector

# Calculate the JVP using the custom rule
primal_output, tangent_output = jax.jvp(stable_sqrt, (x_val,), (t_val,))
print(f"Input x: {x_val}")
print(f"Input tangent t: {t_val}")
print(f"Primal output (stable_sqrt(x)): {primal_output}")
print(f"Tangent output (derivative at x, times t): {tangent_output}")

x_val_pos = jnp.array(4.0)
primal_output_pos, tangent_output_pos = jax.jvp(stable_sqrt, (x_val_pos,), (t_val,))
print(f"\nInput x: {x_val_pos}")
print(f"Input tangent t: {t_val}")
print(f"Primal output (stable_sqrt(x)): {primal_output_pos}")
print(f"Tangent output (derivative at x, times t): {tangent_output_pos}")  # 0.5 * 4**(-0.5) * 1 = 0.25
The function decorated with .defjvp (here, stable_sqrt_jvp) takes two arguments:

- primals: A tuple containing the primal inputs to the original function (stable_sqrt). In our example, primals is (x,).
- tangents: A tuple containing the tangent values corresponding to the primal inputs. These represent the "direction" of differentiation. In our example, tangents is (t,), where t corresponds to x.

The JVP rule function must return a pair:

- primal_out: The result of computing the original function (the forward pass). It's often computed by calling the original function itself, as shown above: primal_out = stable_sqrt(x).
- tangent_out: The result of the Jacobian-vector product. This represents how the output changes in response to the input changing in the direction specified by the tangents. Mathematically, if f is the function and x is the input, it computes J(x)v, where J(x) is the Jacobian of f at x, and v is the input tangent vector. In the example, this is jnp.where(is_zero, 0.0, deriv * t).
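For functions with several inputs, primals and tangents simply contain one entry per argument, and the tangent output combines their contributions. A minimal sketch (the function name here is just for illustration):

@jax.custom_jvp
def my_product(x, y):
    return x * y

@my_product.defjvp
def my_product_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = my_product(x, y)
    # Product rule: d(x * y) = y * dx + x * dy
    tangent_out = y * x_dot + x * y_dot
    return primal_out, tangent_out

print(jax.jvp(my_product, (2.0, 3.0), (1.0, 0.0)))  # primal 6.0, tangent 3.0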
Relationship to jax.custom_vjp

jax.custom_jvp defines the forward-mode differentiation rule. Its counterpart, jax.custom_vjp, defines the reverse-mode rule (vector-Jacobian product), which is the basis for jax.grad.
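For comparison, a minimal custom-VJP counterpart for the same computation might look like this (the function and rule names are illustrative, not part of the example above):

@jax.custom_vjp
def stable_sqrt_rev(x):
    return jnp.sqrt(x)

def stable_sqrt_rev_fwd(x):
    # Forward pass: return the output plus residuals needed by the backward pass
    return stable_sqrt_rev(x), x

def stable_sqrt_rev_bwd(x, g):
    # Backward pass: map the output cotangent g to an input cotangent (one per input)
    is_zero = (x == 0.0)
    safe_x = jnp.where(is_zero, 1.0, x)
    return (jnp.where(is_zero, 0.0, g * 0.5 * jax.lax.rsqrt(safe_x)),)

stable_sqrt_rev.defvjp(stable_sqrt_rev_fwd, stable_sqrt_rev_bwd)

print(jax.grad(stable_sqrt_rev)(0.0))  # 0.0
print(jax.grad(stable_sqrt_rev)(4.0))  # 0.25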
When deciding which to define:

- If you only ever use the function under forward-mode differentiation (for example, calling jax.jvp directly, or computing full Jacobians via jax.jacfwd), defining only jax.custom_jvp might suffice.
- If you need explicit control over the reverse-mode computation behind jax.grad (for example, over what gets saved for the backward pass), you typically define a custom VJP rule using jax.custom_vjp.
- Consider both jax.custom_jvp and jax.custom_vjp if you need both forward- and reverse-mode capabilities for your function.

Defining custom JVP rules is particularly useful when:

- the automatically derived derivative is numerically unstable or undefined at particular inputs,
- you want to prescribe the derivative of a function with non-differentiable points or clamped regions (such as clip),
- the function wraps computations JAX cannot trace, so no rule can be derived automatically.

Like most JAX transformations, functions with custom JVP rules work correctly with jax.jit, jax.vmap, and other transformations. JAX will use your custom rule during tracing and compilation.
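For example, the custom rule composes with jit and vmap without any extra work (reusing stable_sqrt from above):

xs = jnp.array([0.0, 1.0, 4.0])

# Batch the JVP with vmap and compile it with jit; the custom rule is still applied
batched_jvp = jax.jit(jax.vmap(lambda x: jax.jvp(stable_sqrt, (x,), (jnp.ones_like(x),))))
vals, tangs = batched_jvp(xs)
print(vals)   # [0. 1. 2.]
print(tangs)  # [0.   0.5  0.25]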
By mastering jax.custom_jvp, you gain finer control over JAX's forward-mode automatic differentiation, enabling more robust and efficient implementations for specialized computational tasks. This complements the control offered by jax.custom_vjp for reverse-mode differentiation.