Forward-mode automatic differentiation computes the rate of change of a function's output with respect to a change in its input, propagated forward through the computation. The fundamental operation in forward-mode AD is the Jacobian-vector product (JVP). Where reverse-mode AD (used by jax.grad) is efficient for functions with many inputs and few outputs (like typical loss functions in ML), forward-mode shines when the number of inputs is relatively small compared to the number of outputs, or when you specifically need a directional derivative.
Consider a function f: R^n → R^m that maps an n-dimensional input vector x to an m-dimensional output vector y. The Jacobian matrix J(x) of f at a point x is an m×n matrix containing all the first-order partial derivatives:
$$
J(x) = \frac{\partial f}{\partial x} =
\begin{bmatrix}
\frac{\partial f_1}{\partial x_1} & \frac{\partial f_1}{\partial x_2} & \cdots & \frac{\partial f_1}{\partial x_n} \\
\frac{\partial f_2}{\partial x_1} & \frac{\partial f_2}{\partial x_2} & \cdots & \frac{\partial f_2}{\partial x_n} \\
\vdots & \vdots & \ddots & \vdots \\
\frac{\partial f_m}{\partial x_1} & \frac{\partial f_m}{\partial x_2} & \cdots & \frac{\partial f_m}{\partial x_n}
\end{bmatrix}
$$

The Jacobian J(x) represents the best linear approximation of the function f near the point x.
A Jacobian-vector product (JVP) computes the product of this Jacobian matrix J(x) with a "tangent" vector v ∈ R^n. This vector v represents a direction of perturbation in the input space. The JVP is given by:
$$
\mathrm{JVP}(x, v) = J(x)\,v = \left.\frac{d}{d\alpha} f(x + \alpha v)\right|_{\alpha = 0}
$$

The result, J(x)v, is an m-dimensional vector living in the output space. It tells you the rate of change of the function's output f(x) when the input x is perturbed infinitesimally in the direction specified by the tangent vector v. Crucially, forward-mode AD computes this product J(x)v without explicitly forming the potentially very large Jacobian matrix J(x). The computational cost is typically only a small constant factor more than evaluating the original function f(x).
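To make the directional-derivative view concrete, here is a small illustrative sketch; the function g, the point x, the direction v, and the step size alpha are all made up for this example. It approximates J(x)v with a finite difference; forward-mode AD, via jax.jvp introduced below, computes the same quantity exactly rather than approximately.

import jax.numpy as jnp

# A made-up function g: R^2 -> R^2, used only to illustrate the formula above.
def g(x):
    return jnp.array([jnp.sin(x[0]) * x[1], x[0] + x[1]**2])

x = jnp.array([0.5, 1.0])    # point at which we linearize
v = jnp.array([1.0, -2.0])   # direction of perturbation
alpha = 1e-3                 # small step size (arbitrary choice for this sketch)

# Finite-difference approximation of d/d(alpha) g(x + alpha*v) at alpha = 0,
# which approximates the Jacobian-vector product J(x)v.
jvp_approx = (g(x + alpha * v) - g(x)) / alpha
print(jvp_approx)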
jax.jvp

JAX provides the jax.jvp transformation to compute Jacobian-vector products. Its signature is:
jax.jvp(fun, primals, tangents)
- fun: The Python callable (function) to be differentiated.
- primals: A tuple or list containing the primal inputs at which to evaluate the function and its JVP. These are the points x in our notation. If fun takes multiple arguments, primals should contain values for all of them.
- tangents: A tuple or list containing the tangent vectors corresponding to the primals. These are the vectors v. The structure and types/shapes of tangents must match primals. To hold some of the primals constant, pass zero-valued tangents for them (for example, jnp.zeros_like(primal)).

jax.jvp returns a pair:
- primal_out: The result of calling fun(*primals). This is f(x).
- tangent_out: The result of the JVP computation, J(x)v. This has the same structure and types/shapes as primal_out.

Let's see a simple example:
import jax
import jax.numpy as jnp
# Define a function from R^2 -> R^2
def f(x):
    return jnp.array([x[0]**2, x[0] * x[1]])
# Define the primal input point
x_primal = jnp.array([2.0, 3.0])
# Define the tangent vector (direction of perturbation)
v_tangent = jnp.array([1.0, 0.5])
# Compute the function output and the JVP
y_primal_out, tangent_out = jax.jvp(f, (x_primal,), (v_tangent,))
print(f"Primal input (x): {x_primal}")
print(f"Tangent vector (v): {v_tangent}")
print(f"Primal output f(x): {y_primal_out}")
print(f"Tangent output (J(x)v): {tangent_out}")
# Let's manually verify the Jacobian and JVP for this case:
# f(x) = [x_0^2, x_0 * x_1]
# J(x) = [[df1/dx0, df1/dx1], [df2/dx0, df2/dx1]]
# J(x) = [[2*x0, 0], [x1, x0]]
# At x = [2.0, 3.0], J(x) = [[4.0, 0.0], [3.0, 2.0]]
# J(x)v = [[4.0, 0.0], [3.0, 2.0]] @ [1.0, 0.5]
# = [4.0*1.0 + 0.0*0.5, 3.0*1.0 + 2.0*0.5]
# = [4.0, 3.0 + 1.0]
# = [4.0, 4.0]
# This matches the tangent_out from jax.jvp!
Output:
Primal input (x): [2. 3.]
Tangent vector (v): [1. 0.5]
Primal output f(x): [4. 6.]
Tangent output (J(x)v): [4. 4.]
As expected, jax.jvp computed both the original function's output f(x) and the directional derivative J(x)v.
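As an additional cross-check (purely illustrative, restating the same f, x_primal, and v_tangent), we can materialize the full Jacobian with jax.jacfwd and multiply it by the tangent vector; the product matches the tangent output reported by jax.jvp.

import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0]**2, x[0] * x[1]])

x_primal = jnp.array([2.0, 3.0])
v_tangent = jnp.array([1.0, 0.5])

# Materialize the 2x2 Jacobian explicitly, then multiply by the tangent vector.
J = jax.jacfwd(f)(x_primal)   # [[4. 0.] [3. 2.]]
print(J @ v_tangent)          # [4. 4.], matching the tangent output from jax.jvp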
While jax.grad (based on reverse-mode VJPs) is the workhorse for training most neural networks, JVPs computed via jax.jvp have their own strengths. One of them is extracting Jacobian columns: if you choose the tangent v to be a one-hot basis vector (e.g., [1.0, 0.0, ..., 0.0]), the resulting JVP J(x)v is exactly the first column of the Jacobian matrix J(x). Repeating this for every basis vector recovers the full Jacobian, although jax.jacfwd (which essentially uses vmap over jax.jvp) is usually a more direct way to build dense Jacobians.

Like other JAX transformations, jax.jvp integrates smoothly with PyTrees (nested lists, tuples, dicts) for both primals and tangents. The structure of the tangents argument must mirror the structure of the primals argument.
import jax
import jax.numpy as jnp
def predict(params, inputs):
    # A simple linear model
    return jnp.dot(inputs, params['w']) + params['b']
params = {
    'w': jnp.array([[1.0, 2.0], [3.0, 4.0]]),  # Shape (2, 2)
    'b': jnp.array([0.1, -0.1])                # Shape (2,)
}
inputs = jnp.array([10.0, 20.0]) # Shape (2,)
# Define tangents matching the structure of params
# Perturb only weights 'w', keep bias 'b' constant (tangent is zero-like)
tangents = {
    'w': jnp.ones_like(params['w']),
    'b': jnp.zeros_like(params['b'])
}
# Compute the JVP with respect to params. inputs is held constant by giving it a zero tangent.
# jax.jvp expects tuples for primals and tangents, with matching structure.
primal_out, tangent_out = jax.jvp(predict, (params, inputs), (tangents, jnp.zeros_like(inputs)))
print(f"Inputs: {inputs}")
print(f"Params: {params}")
print(f"Tangents (perturbation): {tangents}")
print(f"Primal output (prediction): {primal_out}")
print(f"Tangent output (change in prediction due to perturbation): {tangent_out}")
# Manual check:
# Output = inputs @ w + b
# d(Output) / d(w_ij) * tangent_w_ij summed over i,j
# d(Output) / d(w) contribution: inputs @ tangent['w']
# d(Output) / d(b) contribution: tangent['b']
# Expected tangent_out = inputs @ tangent['w'] + tangent['b']
# = [10., 20.] @ [[1., 1.], [1., 1.]] + [0., 0.]
# = [10*1+20*1, 10*1+20*1] + [0., 0.]
# = [30., 30.]
Output:
Inputs: [10. 20.]
Params: {'w': Array([[1., 2.],
[3., 4.]], dtype=float32), 'b': Array([ 0.1, -0.1], dtype=float32)}
Tangents (perturbation): {'w': Array([[1., 1.],
[1., 1.]], dtype=float32), 'b': Array([0., 0.], dtype=float32)}
Primal output (prediction): [70.1 99.9]
Tangent output (change in prediction due to perturbation): [30. 30.]
Here, we computed how the output of predict changes when the parameters params are perturbed in the direction specified by tangents. Note that we passed jnp.zeros_like(inputs) as the tangent for inputs: jax.jvp requires a tangent for every primal (with matching structure and dtype), and a zero tangent indicates that we are not perturbing that argument.
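An alternative, sketched below and continuing with the same predict, params, inputs, and tangents as above, is to close over inputs so that params is the only primal; then no placeholder zero tangent is needed.

# Alternative: close over `inputs` so that `params` is the only differentiated argument.
def predict_on_params(p):
    return predict(p, inputs)

primal_out2, tangent_out2 = jax.jvp(predict_on_params, (params,), (tangents,))
print(tangent_out2)  # [30. 30.], the same tangent output as before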
jax.jvp can also be composed with other transformations like jax.jit for performance or jax.vmap to compute JVPs with multiple different tangent vectors simultaneously.
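For example, here is a sketch (reusing the earlier f from R^2 to R^2) that maps a jitted JVP over the one-hot basis vectors; stacking the results recovers the columns of the Jacobian, which is essentially what jax.jacfwd does internally.

import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0]**2, x[0] * x[1]])

x_primal = jnp.array([2.0, 3.0])
basis = jnp.eye(2)  # rows are the one-hot tangent vectors e_1, e_2

@jax.jit
def jvp_in_direction(v):
    # J(x_primal) @ v for a single tangent direction v
    _, tangent_out = jax.jvp(f, (x_primal,), (v,))
    return tangent_out

# vmap evaluates the JVP for every basis vector in one batched call.
# Row i of the result is J(x_primal) @ e_i, i.e. the i-th column of the Jacobian.
jacobian_columns = jax.vmap(jvp_in_direction)(basis)
print(jacobian_columns.T)  # [[4. 0.] [3. 2.]], the full Jacobian at x_primal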
Understanding jax.jvp provides insight into forward-mode automatic differentiation and equips you with a tool for calculating directional derivatives efficiently, complementing the reverse-mode capabilities offered by jax.grad and jax.vjp.