Forward-mode automatic differentiation computes the rate of change of a function's output with respect to a change in its input, propagated forward through the computation. The fundamental operation in forward-mode AD is the Jacobian-vector product (JVP). Where reverse-mode AD (used by jax.grad) is efficient for functions with many inputs and few outputs (like typical loss functions in ML), forward-mode shines when the number of inputs is relatively small compared to the number of outputs, or when you specifically need a directional derivative.
Consider a function $f: \mathbb{R}^n \to \mathbb{R}^m$ that maps an n-dimensional input vector $x$ to an m-dimensional output vector $y$. The Jacobian matrix $J(x)$ of $f$ at a point $x$ is an $m \times n$ matrix containing all the first-order partial derivatives:
$$
J(x) = \frac{\partial f}{\partial x} =
\begin{bmatrix}
\frac{\partial f_1}{\partial x_1} & \frac{\partial f_1}{\partial x_2} & \cdots & \frac{\partial f_1}{\partial x_n} \\
\frac{\partial f_2}{\partial x_1} & \frac{\partial f_2}{\partial x_2} & \cdots & \frac{\partial f_2}{\partial x_n} \\
\vdots & \vdots & \ddots & \vdots \\
\frac{\partial f_m}{\partial x_1} & \frac{\partial f_m}{\partial x_2} & \cdots & \frac{\partial f_m}{\partial x_n}
\end{bmatrix}
$$
The Jacobian $J(x)$ represents the best linear approximation of the function $f$ near the point $x$.
A Jacobian-vector product (JVP) computes the product of this Jacobian matrix $J(x)$ with a "tangent" vector $v \in \mathbb{R}^n$. This vector $v$ represents a direction of perturbation in the input space. The JVP is given by:
$$
\mathrm{JVP}(x, v) = J(x)\,v = \left.\frac{d}{d\alpha} f(x + \alpha v)\right|_{\alpha = 0}
$$
The result, $J(x)v$, is an m-dimensional vector living in the output space. It tells you the rate of change of the function's output $f(x)$ when the input $x$ is perturbed infinitesimally in the direction specified by the tangent vector $v$. Crucially, forward-mode AD allows us to compute this product $J(x)v$ without explicitly forming the potentially very large Jacobian matrix $J(x)$. The computational cost is typically only a small constant factor more than evaluating the original function $f(x)$.
jax.jvp
JAX provides the jax.jvp transformation to compute Jacobian-vector products. Its signature is:
jax.jvp(fun, primals, tangents)
- fun: The Python callable (function) to be differentiated.
- primals: A tuple or list containing the primal inputs at which to evaluate the function and its JVP. These are the points $x$ in our notation. If fun takes multiple arguments, primals should contain values for all of them.
- tangents: A tuple or list containing the tangent vectors corresponding to the primals. These are the vectors $v$. The structure and types/shapes of tangents must match primals. To differentiate with respect to only a subset of the primals, pass zero-valued tangents (for example, jnp.zeros_like(...)) for the others; the tangent structure must still mirror the primal structure exactly.

jax.jvp returns a pair:

- primal_out: The result of calling fun(*primals). This is $f(x)$.
- tangent_out: The result of the JVP computation, $J(x)v$. This has the same structure and types/shapes as primal_out.

Let's see a simple example:
import jax
import jax.numpy as jnp
# Define a function from R^2 -> R^2
def f(x):
return jnp.array([x[0]**2, x[0] * x[1]])
# Define the primal input point
x_primal = jnp.array([2.0, 3.0])
# Define the tangent vector (direction of perturbation)
v_tangent = jnp.array([1.0, 0.5])
# Compute the function output and the JVP
y_primal_out, tangent_out = jax.jvp(f, (x_primal,), (v_tangent,))
print(f"Primal input (x): {x_primal}")
print(f"Tangent vector (v): {v_tangent}")
print(f"Primal output f(x): {y_primal_out}")
print(f"Tangent output (J(x)v): {tangent_out}")
# Let's manually verify the Jacobian and JVP for this case:
# f(x) = [x_0^2, x_0 * x_1]
# J(x) = [[df1/dx0, df1/dx1], [df2/dx0, df2/dx1]]
# J(x) = [[2*x0, 0], [x1, x0]]
# At x = [2.0, 3.0], J(x) = [[4.0, 0.0], [3.0, 2.0]]
# J(x)v = [[4.0, 0.0], [3.0, 2.0]] @ [1.0, 0.5]
# = [4.0*1.0 + 0.0*0.5, 3.0*1.0 + 2.0*0.5]
# = [4.0, 3.0 + 1.0]
# = [4.0, 4.0]
# This matches the tangent_out from jax.jvp!
Output:
Primal input (x): [2. 3.]
Tangent vector (v): [1. 0.5]
Primal output f(x): [4. 6.]
Tangent output (J(x)v): [4. 4.]
As expected, jax.jvp computed both the original function's output $f(x)$ and the directional derivative $J(x)v$.
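To connect this back to the directional-derivative definition of the JVP, here is a small sketch (the step size eps is an arbitrary illustrative choice, not part of the example above) comparing the jax.jvp result with a central finite difference along v:
import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0]**2, x[0] * x[1]])

x = jnp.array([2.0, 3.0])
v = jnp.array([1.0, 0.5])

# Central finite difference of f along the direction v.
eps = 1e-3
fd_estimate = (f(x + eps * v) - f(x - eps * v)) / (2 * eps)

# The same directional derivative computed by forward-mode AD.
_, jvp_result = jax.jvp(f, (x,), (v,))

print(f"Finite-difference estimate: {fd_estimate}")  # approximately [4. 4.]
print(f"JVP result:                 {jvp_result}")   # [4. 4.]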
While jax.grad (based on reverse-mode VJPs) is the workhorse for training most neural networks, JVPs computed via jax.jvp have their own strengths. For example, if you set the tangent vector $v$ to a one-hot basis vector (e.g., [1.0, 0.0, ..., 0.0]), the resulting JVP $J(x)v$ gives you exactly the first column of the Jacobian matrix $J(x)$. Repeating this for all one-hot basis vectors recovers the full Jacobian, although using jax.jacfwd (which essentially uses vmap over jax.jvp) is often a more direct way to achieve this for dense Jacobians.
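As a minimal sketch of that column-by-column approach (reusing the f and x_primal values from the earlier example; the variable names here are illustrative), the following builds the full Jacobian with one jax.jvp call per basis vector and checks it against jax.jacfwd:
import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0]**2, x[0] * x[1]])

x = jnp.array([2.0, 3.0])

# One JVP per one-hot basis vector gives one column of the Jacobian.
basis = jnp.eye(x.shape[0])
columns = [jax.jvp(f, (x,), (e,))[1] for e in basis]
jacobian_from_jvps = jnp.stack(columns, axis=1)

# jax.jacfwd builds the same matrix more directly.
jacobian_ref = jax.jacfwd(f)(x)

print(jacobian_from_jvps)                              # [[4. 0.]
                                                       #  [3. 2.]]
print(jnp.allclose(jacobian_from_jvps, jacobian_ref))  # True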
Like other JAX transformations, jax.jvp integrates smoothly with PyTrees (nested lists, tuples, dicts) for both primals and tangents. The structure of the tangents argument must mirror the structure of the primals argument.
import jax
import jax.numpy as jnp
def predict(params, inputs):
# A simple linear model
return jnp.dot(inputs, params['w']) + params['b']
params = {
'w': jnp.array([[1.0, 2.0], [3.0, 4.0]]), # Shape (2, 2)
'b': jnp.array([0.1, -0.1]) # Shape (2,)
}
inputs = jnp.array([10.0, 20.0]) # Shape (2,)
# Define tangents matching the structure of params
# Perturb only weights 'w', keep bias 'b' constant (tangent is zero-like)
tangents = {
'w': jnp.ones_like(params['w']),
'b': jnp.zeros_like(params['b'])
}
# Compute the JVP w.r.t. params. A zero tangent for inputs treats it as constant.
# jax.jvp expects tuples for primals and tangents, with matching structures.
primal_out, tangent_out = jax.jvp(predict, (params, inputs), (tangents, jnp.zeros_like(inputs)))
print(f"Inputs: {inputs}")
print(f"Params: {params}")
print(f"Tangents (perturbation): {tangents}")
print(f"Primal output (prediction): {primal_out}")
print(f"Tangent output (change in prediction due to perturbation): {tangent_out}")
# Manual check:
# Output = inputs @ w + b
# d(Output) / d(w_ij) * tangent_w_ij summed over i,j
# d(Output) / d(w) contribution: inputs @ tangent['w']
# d(Output) / d(b) contribution: tangent['b']
# Expected tangent_out = inputs @ tangent['w'] + tangent['b']
# = [10., 20.] @ [[1., 1.], [1., 1.]] + [0., 0.]
# = [10*1+20*1, 10*1+20*1] + [0., 0.]
# = [30., 30.]
Output:
Inputs: [10. 20.]
Params: {'w': Array([[1., 2.],
[3., 4.]], dtype=float32), 'b': Array([ 0.1, -0.1], dtype=float32)}
Tangents (perturbation): {'w': Array([[1., 1.],
[1., 1.]], dtype=float32), 'b': Array([0., 0.], dtype=float32)}
Primal output (prediction): [ 70.1 100.9]
Tangent output (change in prediction due to perturbation): [30. 30.]
Here, we computed how the output of predict changes when the parameters params are perturbed in the direction specified by tangents. Note that we passed jnp.zeros_like(inputs) as the tangent for inputs to indicate we are not differentiating with respect to it; because jax.jvp requires the tangents to have the same structure and shapes as the primals, a zero-valued tangent (rather than None) is the way to hold an argument fixed.
jax.jvp can also be composed with other transformations like jax.jit for performance or jax.vmap to compute JVPs with multiple different tangent vectors simultaneously.
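As a minimal sketch of that composition (reusing the earlier f and primal point, with an arbitrary illustrative batch of tangent directions), one can jit the JVP and vmap it over a batch of tangents:
import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0]**2, x[0] * x[1]])

x = jnp.array([2.0, 3.0])

# A batch of tangent directions, one per row (illustrative values).
tangent_batch = jnp.array([[1.0, 0.0],
                           [0.0, 1.0],
                           [1.0, 0.5]])

# Compile the JVP with jit, then vmap it over the batch of tangents.
@jax.jit
def jvp_in_direction(v):
    return jax.jvp(f, (x,), (v,))[1]

batched_jvps = jax.vmap(jvp_in_direction)(tangent_batch)
print(batched_jvps)
# [[4.  3. ]
#  [0.  2. ]
#  [4.  4. ]]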
Understanding jax.jvp provides insight into forward-mode automatic differentiation and equips you with a tool for calculating directional derivatives efficiently, complementing the reverse-mode capabilities offered by jax.grad and jax.vjp.