You've successfully defined a custom primitive, given it an abstract evaluation rule for shape and dtype inference, and implemented backend-specific lowering rules. However, for this primitive to be truly useful within the JAX ecosystem, it needs to integrate with JAX's automatic differentiation system. Without differentiation rules, applying jax.grad, jax.jvp, or jax.vjp to a function containing your primitive will result in an error, because JAX won't know how to propagate derivatives through it.
This section explains how to define the necessary differentiation rules for your custom primitives, enabling them to participate fully in gradient-based optimization and other automatic differentiation tasks.
JAX's automatic differentiation machinery relies on knowing the differentiation rules for every primitive operation it encounters during tracing. For built-in primitives like jax.lax.add or jax.lax.sin, these rules are already defined internally. But for your custom primitive, JAX has no prior knowledge of its mathematical behavior with respect to differentiation.
Therefore, you must explicitly provide these rules. JAX uses two fundamental modes of automatic differentiation:

- Forward mode computes Jacobian-vector products (JVPs) and underlies jax.jvp. You define this rule with primitive.def_jvp.
- Reverse mode computes vector-Jacobian products (VJPs) and underlies jax.vjp and jax.grad. You define this rule with primitive.def_vjp.

To make your primitive fully differentiable and usable with all of JAX's differentiation transformations, you typically need to define both JVP and VJP rules.
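For orientation, here is how the two modes look from the user's side, using only built-in JAX functions on jnp.sin (an illustrative sketch; once your primitive's rules are defined, the same calls work on functions that use it):
import jax
import jax.numpy as jnp

# Forward mode: jax.jvp pushes a tangent through the function.
y, y_dot = jax.jvp(jnp.sin, (1.0,), (1.0,))   # y = sin(1.0), y_dot = cos(1.0)

# Reverse mode: jax.vjp returns the output and a pullback for cotangents.
y, pullback = jax.vjp(jnp.sin, 1.0)
(x_bar,) = pullback(1.0)                      # x_bar = cos(1.0)

# jax.grad is built on reverse mode.
g = jax.grad(jnp.sin)(1.0)                    # also cos(1.0)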
The JVP rule describes how a perturbation (tangent) in the input of your primitive affects its output. Mathematically, if your primitive computes y=f(x), the JVP rule computes Jv, where J is the Jacobian of f at x, and v is the input tangent vector.
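As a quick numerical illustration (using a small, arbitrary example function f rather than a custom primitive), jax.jvp produces exactly the product of the Jacobian with the input tangent:
import jax
import jax.numpy as jnp

def f(x):
    # Arbitrary example function from R^2 to R^2
    return jnp.array([x[0] * x[1], jnp.sin(x[0])])

x = jnp.array([1.0, 2.0])
v = jnp.array([0.1, -0.3])          # input tangent

_, jv = jax.jvp(f, (x,), (v,))      # J @ v, computed without building J
J = jax.jacfwd(f)(x)                # full Jacobian, for comparison only
print(jnp.allclose(jv, J @ v))      # True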
You define the JVP rule using the def_jvp method of your primitive object. The function you provide to def_jvp should have a specific signature:
def primitive_jvp(primals, tangents):
    # primals: A tuple of the primal inputs to the primitive.
    # tangents: A tuple of tangent values corresponding to the primals.
    #   Tangents have the same shape and dtype as the corresponding primal.
    #   A tangent value may be ad.Zero (from jax.interpreters.ad) if the
    #   corresponding primal is considered constant w.r.t. differentiation.

    # 1. Compute the primal outputs (same as the standard implementation)
    primal_outputs = ...  # Compute f(primals)

    # 2. Compute the tangent outputs from the primals and input tangents
    tangent_outputs = ...  # Compute J @ tangents

    return primal_outputs, tangent_outputs
The JVP rule function receives the primal inputs and their corresponding tangents. It must return a pair: the primal outputs (which should match the result of the primitive's standard evaluation) and the corresponding output tangents.
Example: JVP for a Custom Scaling Primitive
Let's revisit a simple custom primitive custom_scale_p that scales an input x by a factor: y = x × factor.
import jax
import jax.numpy as jnp
from jax.core import Primitive
from jax.interpreters import ad
from jax.interpreters import mlir

# Assume custom_scale_p is already defined with abstract eval and lowering.
# For example:
custom_scale_p = Primitive('custom_scale')

@custom_scale_p.def_impl
def _custom_scale_impl(x, factor):
    # Example CPU implementation
    return x * factor

@custom_scale_p.def_abstract_eval
def _custom_scale_abstract_eval(x_abs, factor_abs):
    # Example abstract evaluation: the scale factor must be a scalar,
    # and the output has the same shape and dtype as x.
    assert factor_abs.shape == ()
    return jax.core.ShapedArray(x_abs.shape, x_abs.dtype)

# Assume the MLIR lowering is also defined...
# Now, define the JVP rule
@custom_scale_p.def_jvp
def _custom_scale_jvp(primals, tangents):
    x, factor = primals
    x_dot, factor_dot = tangents  # Input tangents (may be symbolic ad.Zero)

    # Primal output computation (same as the implementation)
    y = x * factor

    # Output tangent via the product rule:
    # d(x * factor) = x_dot * factor + x * factor_dot
    # Skip any term whose input tangent is a symbolic zero.
    y_dot = jnp.zeros_like(y)
    if not isinstance(x_dot, ad.Zero):
        y_dot = y_dot + x_dot * factor
    if not isinstance(factor_dot, ad.Zero):
        y_dot = y_dot + x * factor_dot

    print(f"Custom JVP: x={x}, factor={factor}, x_dot={x_dot}, "
          f"factor_dot={factor_dot}, y={y}, y_dot={y_dot}")
    return y, y_dot
# Example usage with jax.jvp
x_val = jnp.array([1.0, 2.0, 3.0])
factor_val = 2.0
x_tangent = jnp.array([0.1, 0.2, 0.3])
factor_tangent = 0.5  # Tangent for the scalar factor

# Define a function that uses the primitive
def apply_scale(x, factor):
    # Bind both operands positionally so each can be differentiated
    return custom_scale_p.bind(x, factor)

# Compute the JVP
primal_out, tangent_out = jax.jvp(apply_scale, (x_val, factor_val), (x_tangent, factor_tangent))
print(f"Primal output: {primal_out}")
print(f"Tangent output: {tangent_out}")

# Expected output tangent:
# y_dot = x_dot * factor + x * factor_dot
#       = [0.1, 0.2, 0.3] * 2.0 + [1.0, 2.0, 3.0] * 0.5
#       = [0.2, 0.4, 0.6] + [0.5, 1.0, 1.5]
#       = [0.7, 1.4, 2.1]
Notice how we handle ad.Zero to avoid unnecessary computation when an input tangent is a symbolic zero. The JVP rule correctly applies the product rule of differentiation.
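As a sanity check, and assuming the custom_scale_p definitions above are registered, the custom JVP should match the JVP of the same computation written with ordinary jax.numpy operations, which uses JAX's built-in rules:
# Reference implementation with built-in primitives only
def scale_ref(x, factor):
    return x * factor

ref_primal, ref_tangent = jax.jvp(scale_ref, (x_val, factor_val), (x_tangent, factor_tangent))
print(jnp.allclose(primal_out, ref_primal))    # True
print(jnp.allclose(tangent_out, ref_tangent))  # True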
The VJP rule is central to reverse-mode automatic differentiation, which is how jax.grad works. It describes how a cotangent vector at the output of the primitive propagates backward to produce cotangents for the inputs. Mathematically, if y = f(x), the VJP rule computes v^T J, where v^T is the output cotangent vector (a row vector) and J is the Jacobian.
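To make the v^T J product concrete with built-in JAX functions (again using an arbitrary example function f, not a custom primitive):
import jax
import jax.numpy as jnp

def f(x):
    # Arbitrary example function from R^2 to R^2
    return jnp.array([x[0] * x[1], jnp.sin(x[0])])

x = jnp.array([1.0, 2.0])
y, pullback = jax.vjp(f, x)

y_bar = jnp.array([1.0, -2.0])         # output cotangent v
(x_bar,) = pullback(y_bar)             # v^T @ J, computed without building J
J = jax.jacrev(f)(x)                   # full Jacobian, for comparison only
print(jnp.allclose(x_bar, y_bar @ J))  # True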
Defining a VJP rule with primitive.def_vjp is slightly more complex because reverse mode needs information from the forward pass in order to compute the backward pass. The function you provide to def_vjp performs the forward pass, saves any intermediate values (residuals) needed for the gradient calculation, and returns the primal outputs together with a backward-pass function (usually defined locally) that closes over those residuals.
The structure looks like this:
def primitive_vjp(primals):
    # primals: A tuple of the primal inputs.

    # 1. Compute the primal outputs
    primal_outputs = ...  # Compute f(primals)

    # 2. Save any residuals needed for the backward pass
    residuals = ...  # e.g., primal inputs, intermediate values

    # 3. Define the backward pass, which closes over the residuals
    def backward_pass(output_cotangents):
        # output_cotangents: the cotangent vector(s) corresponding to primal_outputs.
        # Compute the input cotangents from the residuals and output cotangents.
        input_cotangents = ...  # Compute output_cotangents^T @ J
        return input_cotangents  # A tuple matching the structure of primals

    return primal_outputs, backward_pass
The outer function primitive_vjp takes the primal inputs, computes the primal outputs, and captures whatever data the backward pass will need as residuals. It returns the primal outputs together with the backward_pass function. JAX later calls backward_pass with the incoming output cotangents, and the closed-over residuals are used to compute the input cotangents.
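This output-plus-backward-function structure mirrors what jax.vjp itself returns for any differentiable function. Here is the same pattern on a plain jax.numpy implementation of the scaling operation (no custom primitive involved):
def scale_ref(x, factor):
    return x * factor

y, pullback = jax.vjp(scale_ref, jnp.array([1.0, 2.0, 3.0]), 2.0)
x_bar, factor_bar = pullback(jnp.ones(3))  # feed an all-ones output cotangent
print(x_bar)       # [2. 2. 2.]  (= cotangent * factor)
print(factor_bar)  # 6.0         (= sum(cotangent * x))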
Example: VJP for the Custom Scaling Primitive
Continuing the custom_scale_p example (y = x × factor):
# Define the VJP rule
@custom_scale_p.def_vjp
def _custom_scale_vjp(primals):
    x, factor = primals
    # Forward pass: compute the output and keep the inputs as residuals
    y = custom_scale_p.bind(x, factor)  # reuse the primitive for the actual computation
    residuals = (x, factor)  # needed again in the backward pass

    # Backward pass: closes over the residuals
    def backward_pass(y_bar):
        # y_bar is the cotangent of the output y
        x_res, factor_res = residuals
        # Gradient w.r.t. x: dy/dx = factor  =>  x_bar = y_bar * factor
        x_bar = y_bar * factor_res
        # Gradient w.r.t. factor: dy/dfactor = x  =>  factor_bar = sum(y_bar * x)
        # The sum is needed because factor is a scalar while x is an array.
        factor_bar = jnp.sum(y_bar * x_res)
        print(f"Custom VJP backward: y_bar={y_bar}, x_bar={x_bar}, factor_bar={factor_bar}")
        # Return cotangents in the same order as the primals (x, factor)
        return (x_bar, factor_bar)

    return y, backward_pass
# Example usage with jax.grad
x_val = jnp.array([1.0, 2.0, 3.0])
factor_val = 2.0

# Define a function to differentiate
def loss_fn(x, factor):
    y = apply_scale(x, factor)  # uses our primitive via apply_scale
    return jnp.sum(y * y)       # example loss: sum of squares

# Compute gradients using jax.grad
# grad takes derivatives w.r.t. the specified argnums (0 for x, 1 for factor)
grad_x = jax.grad(loss_fn, argnums=0)(x_val, factor_val)
grad_factor = jax.grad(loss_fn, argnums=1)(x_val, factor_val)
print(f"Gradient w.r.t. x: {grad_x}")
print(f"Gradient w.r.t. factor: {grad_factor}")

# Expected gradients:
# L = sum( (x*factor)^2 )
# dL/dy = 2*y = 2*x*factor
# dL/dx = dL/dy * dy/dx = (2*x*factor) * factor = 2*x*factor^2
#       = 2 * [1, 2, 3] * 2^2 = [8, 16, 24]
# dL/dfactor = sum( dL/dy * dy/dfactor ) = sum( (2*x*factor) * x ) = sum( 2*x^2*factor )
#            = 2 * factor * sum(x^2) = 2 * 2.0 * (1^2 + 2^2 + 3^2)
#            = 4 * (1 + 4 + 9) = 4 * 14 = 56.0
In this VJP rule, the forward part computes the result y and saves the original inputs x and factor as residuals. The backward_pass uses these residuals together with the incoming output cotangent y_bar (which represents ∂Loss/∂y) to compute the input cotangents x_bar (∂Loss/∂x) and factor_bar (∂Loss/∂factor) via the chain rule.
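A quick way to confirm these values, assuming the definitions above, is to compare against the same loss written entirely with built-in jax.numpy operations:
# Reference loss using only built-in primitives
def loss_ref(x, factor):
    y = x * factor
    return jnp.sum(y * y)

print(jnp.allclose(grad_x, jax.grad(loss_ref, argnums=0)(x_val, factor_val)))       # True
print(jnp.allclose(grad_factor, jax.grad(loss_ref, argnums=1)(x_val, factor_val)))  # True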
Defining custom differentiation rules can be error-prone, so it's highly recommended to verify their correctness. Common techniques include:

- Numerical gradient checking: compare your analytical gradients against finite-difference estimates. jax.test_util.check_grads automates this comparison for VJPs (and implicitly for JVPs via JVP-VJP consistency checks).
- JVP-VJP consistency: the two rules describe the same Jacobian and must agree with each other; jax.test_util.check_grads also helps verify this relationship.

# Example using jax.test_util for verification
from jax.test_util import check_grads
# Check gradients for apply_scale function which uses the primitive
# check_grads compares analytical gradients (from VJP) with numerical estimates
check_grads(apply_scale, (x_val, factor_val), order=2, modes=['fwd', 'rev'], eps=1e-3)
# order=2 checks both first and second derivatives (if applicable)
# modes=['fwd', 'rev'] checks both JVP and VJP consistency
print("Gradient checks passed!")
Running check_grads provides confidence that your differentiation rules are implemented correctly.
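If you want an explicit version of the same idea, a central finite difference can be compared against the analytical gradient directly (a minimal sketch assuming the loss_fn defined above; the step size and tolerance may need tuning for your dtype):
def numerical_grad_factor(f, x, factor, eps=1e-3):
    # Central finite difference with respect to the scalar factor argument
    return (f(x, factor + eps) - f(x, factor - eps)) / (2.0 * eps)

num_grad = numerical_grad_factor(loss_fn, x_val, factor_val)
print(jnp.allclose(grad_factor, num_grad, rtol=1e-2))  # True, within finite-difference error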
Once you have defined the JVP (def_jvp) and VJP (def_vjp) rules for your custom primitive, it integrates seamlessly with JAX's automatic differentiation system. You can now apply jax.grad, jax.jvp, and jax.vjp to any JAX function that uses your primitive, just as you would with built-in operations, and even compose these transformations for higher-order derivatives, as sketched below. This completes the process of making your custom operation a fully integrated part of the JAX ecosystem.