While jax.grad and jax.vjp provide the foundation for reverse-mode automatic differentiation, essential for training most neural networks, there are scenarios where direct control over forward-mode differentiation (Jacobian-vector products) is needed. Forward-mode calculates how a change in the input, represented by a "tangent" vector, propagates forward through the function to affect the output. The jax.jvp function computes this directly.
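As a quick orientation, here is a minimal sketch of calling jax.jvp on an ordinary function with no custom rule involved; for jnp.sin, the tangent output is simply cos(x) * t:

import jax
import jax.numpy as jnp

x = jnp.array(1.0)
t = jnp.array(1.0)  # the tangent: the direction of differentiation
# jax.jvp returns the pair (f(x), J(x) @ t)
y, y_dot = jax.jvp(jnp.sin, (x,), (t,))
print(y)      # sin(1.0) ~= 0.8415
print(y_dot)  # cos(1.0) * 1.0 ~= 0.5403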
Sometimes, JAX's automatic derivation of the JVP rule might not be optimal or even possible. This can happen for several reasons:
- The automatically derived rule can be numerically unstable near certain inputs (for example, the derivatives of sqrt(x) or log(x) blow up as x approaches zero).
- The function may contain piecewise or non-smooth operations (such as jnp.clip) whose default derivative is not what you want.
- The function may wrap an external computation (for example, through jax.pure_callback), so JAX has no way to see inside it to compute the derivative.

In these situations, you can define your own JVP rule using the jax.custom_jvp decorator. This allows you to tell JAX precisely how to compute the forward-mode derivative for your function.
To define a custom JVP, you decorate your function with @jax.custom_jvp and then register a separate JVP rule function for it using the @your_function.defjvp decorator.
Let's illustrate with a simple example: implementing a numerically stable version of jnp.sqrt with a custom JVP rule that handles the derivative at x=0.
import jax
import jax.numpy as jnp
@jax.custom_jvp
def stable_sqrt(x):
    """Computes sqrt(x), defined to return 0.0 at x == 0."""
    is_zero = (x == 0.0)
    # Replace zeros before taking the square root, then mask the result back to 0.0
    safe_x = jnp.where(is_zero, 1.0, x)
    return jnp.where(is_zero, 0.0, jnp.sqrt(safe_x))
# Define the custom JVP rule for stable_sqrt
@stable_sqrt.defjvp
def stable_sqrt_jvp(primals, tangents):
    """Custom JVP rule for stable_sqrt."""
    x, = primals
    t, = tangents  # t is the tangent vector dx
    # Primal computation (the forward pass result)
    primal_out = stable_sqrt(x)
    # Tangent computation (the derivative part)
    # The derivative of sqrt(x) is 0.5 * x**(-0.5)
    # Handle x=0 explicitly to avoid NaN/inf
    is_zero = (x == 0.0)
    safe_x = jnp.where(is_zero, 1.0, x)  # Avoid division by zero
    deriv = 0.5 * jax.lax.rsqrt(safe_x)
    # Use 0.0 for the derivative at x=0 (a large finite value is another option)
    tangent_out = jnp.where(is_zero, 0.0, deriv * t)
    return primal_out, tangent_out
# Example usage
x_val = jnp.array(0.0)
t_val = jnp.array(1.0) # Tangent vector
# Calculate JVP using the custom rule
primal_output, tangent_output = jax.jvp(stable_sqrt, (x_val,), (t_val,))
print(f"Input x: {x_val}")
print(f"Input tangent t: {t_val}")
print(f"Primal output (stable_sqrt(x)): {primal_output}")
print(f"Tangent output (derivative at x * t): {tangent_output}")
x_val_pos = jnp.array(4.0)
primal_output_pos, tangent_output_pos = jax.jvp(stable_sqrt, (x_val_pos,), (t_val,))
print(f"\nInput x: {x_val_pos}")
print(f"Input tangent t: {t_val}")
print(f"Primal output (stable_sqrt(x)): {primal_output_pos}")
print(f"Tangent output (derivative at x * t): {tangent_output_pos}") # Should be 0.5 * 4**(-0.5) * 1 = 0.25
The function decorated with .defjvp (here, stable_sqrt_jvp) takes two arguments:
- primals: A tuple containing the primal inputs to the original function (stable_sqrt). In our example, primals is (x,).
- tangents: A tuple containing the tangent values corresponding to the primal inputs. These represent the "direction" of differentiation. In our example, tangents is (t,), where t corresponds to x.

The JVP rule function must return a pair:
- primal_out: The result of computing the original function (the forward pass). It's often computed by calling the original function itself, as shown above: primal_out = stable_sqrt(x).
- tangent_out: The result of the Jacobian-vector product. This represents how the output changes in response to the input changing in the direction specified by the tangents. Mathematically, if f is the function and x is the input, it computes J(x)v, where J(x) is the Jacobian of f at x, and v is the input tangent vector. In the example, this is jnp.where(is_zero, 0.0, deriv * t).

jax.custom_jvp defines the forward-mode differentiation rule. Its counterpart, jax.custom_vjp, defines the reverse-mode rule (vector-Jacobian product), which is the basis for jax.grad.
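For comparison, here is a minimal sketch of how a reverse-mode rule for the same computation could be registered with jax.custom_vjp (the names stable_sqrt_v, stable_sqrt_v_fwd, and stable_sqrt_v_bwd are illustrative, not part of the example above):

@jax.custom_vjp
def stable_sqrt_v(x):
    is_zero = (x == 0.0)
    return jnp.where(is_zero, 0.0, jnp.sqrt(jnp.where(is_zero, 1.0, x)))

def stable_sqrt_v_fwd(x):
    # Forward pass: return the output along with residuals needed by the backward pass
    return stable_sqrt_v(x), x

def stable_sqrt_v_bwd(x, g):
    # Backward pass: map the output cotangent g to a cotangent for the input x
    is_zero = (x == 0.0)
    safe_x = jnp.where(is_zero, 1.0, x)
    grad_x = jnp.where(is_zero, 0.0, 0.5 * jax.lax.rsqrt(safe_x)) * g
    return (grad_x,)

stable_sqrt_v.defvjp(stable_sqrt_v_fwd, stable_sqrt_v_bwd)

print(jax.grad(stable_sqrt_v)(0.0))  # 0.0 instead of inf
print(jax.grad(stable_sqrt_v)(4.0))  # 0.25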
- If you only need forward-mode differentiation (for example, calling jax.jvp directly, or computing full Jacobians via jax.jacfwd), defining only jax.custom_jvp might suffice.
- If you need explicit control over reverse-mode differentiation (the mode behind jax.grad), you typically need to define a custom VJP rule using jax.custom_vjp.
- Consider carefully which of jax.custom_jvp and jax.custom_vjp to use if you need both forward and reverse mode capabilities for your function.

Defining custom JVP rules is particularly useful when:
- The standard derivative is numerically unstable or undefined at certain inputs (as with sqrt at zero).
- You want to assign a specific derivative to a piecewise or non-smooth operation (like clip).
- A hand-written rule is more efficient or more accurate than the one JAX derives automatically.

Like most JAX transformations, functions with custom JVP rules work correctly with jax.jit, jax.vmap, and other transformations. JAX will use your custom rule during tracing and compilation.
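For instance, a small sketch (not part of the original example) that batches the JVP of stable_sqrt with jax.vmap and compiles it with jax.jit:

xs = jnp.array([0.0, 1.0, 4.0])
ts = jnp.ones_like(xs)

# vmap the JVP computation over a batch of inputs, then jit-compile the whole thing
batched_jvp = jax.jit(jax.vmap(lambda x, t: jax.jvp(stable_sqrt, (x,), (t,))))
primals_out, tangents_out = batched_jvp(xs, ts)
print(primals_out)   # [0.  1.  2.]
print(tangents_out)  # [0.    0.5   0.25]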
By mastering jax.custom_jvp, you gain finer control over JAX's forward-mode automatic differentiation, enabling more efficient implementations for specialized computational tasks. This complements the control offered by jax.custom_vjp for reverse-mode differentiation.