While jax.vmap automatically handles adding a batch dimension, you often need more control. What if only some arguments represent batches of data? What if your data isn't conveniently batched along the first axis (axis 0)? To address these scenarios, vmap accepts the in_axes and out_axes arguments, which provide fine-grained control over how it transforms your function's inputs and outputs.
in_axes

The in_axes argument tells vmap which axis of each input argument should be mapped over (vectorized). It's typically provided as a tuple or list, where the length matches the number of positional arguments of the function being vectorized.
Each element in the in_axes tuple corresponds to an argument:
- An integer i means that the i-th axis of the corresponding argument is the batch dimension; vmap will effectively iterate over slices along this axis.
- None means that the corresponding argument should not be mapped over. Instead, the entire argument will be broadcast and reused across all vectorized calls. This is useful for parameters or constants that are shared across the batch.

Let's look at an example. Suppose we have a function that adds a scalar value to each element of a vector:
import jax
import jax.numpy as jnp
def add_scalar(vector, scalar):
# Adds a scalar to every element of a vector
return vector + scalar
# Example data
vectors = jnp.arange(12).reshape(4, 3) # A batch of 4 vectors, each size 3
scalar_val = 100.0 # A single scalar value
If we want to apply add_scalar to each vector in our vectors batch, using the same scalar_val for every vector, we tell vmap to map over axis 0 of vectors but not to map over scalar_val:
# Map over axis 0 of the first argument (vectors)
# Broadcast the second argument (scalar_val)
vectorized_add_scalar = jax.vmap(add_scalar, in_axes=(0, None))
result = vectorized_add_scalar(vectors, scalar_val)
print("Input vectors (shape {}):\n{}".format(vectors.shape, vectors))
print("Input scalar:", scalar_val)
print("Result (shape {}):\n{}".format(result.shape, result))
Input vectors (shape (4, 3)):
[[ 0 1 2]
[ 3 4 5]
[ 6 7 8]
[ 9 10 11]]
Input scalar: 100.0
Result (shape (4, 3)):
[[100. 101. 102.]
[103. 104. 105.]
[106. 107. 108.]
[109. 110. 111.]]
As you can see, vmap applied add_scalar four times. In each application, it took one row (axis 0) from vectors and the entire scalar_val. The output result collects these individual results, stacked along axis 0, matching the input batch dimension.
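Conceptually, the vmapped call behaves like the manual loop below, shown purely for intuition; vmap emits a single vectorized operation rather than a Python loop:
# Equivalent manual loop over the batch axis (for intuition only)
manual_result = jnp.stack(
    [add_scalar(vectors[i], scalar_val) for i in range(vectors.shape[0])],
    axis=0,
)
print(jnp.allclose(manual_result, result))  # True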
What if we had a batch of scalars as well, and wanted to add the i-th scalar to the i-th vector? We would specify in_axes=(0, 0):
scalars = jnp.array([100., 200., 300., 400.]) # A batch of 4 scalars
# Map over axis 0 of the first argument (vectors)
# Map over axis 0 of the second argument (scalars)
vectorized_add_scalar_batch = jax.vmap(add_scalar, in_axes=(0, 0))
result_batch = vectorized_add_scalar_batch(vectors, scalars)
print("Input vectors (shape {}):\n{}".format(vectors.shape, vectors))
print("Input scalars (shape {}):\n{}".format(scalars.shape, scalars))
print("Result (shape {}):\n{}".format(result_batch.shape, result_batch))
Input vectors (shape (4, 3)):
[[ 0 1 2]
[ 3 4 5]
[ 6 7 8]
[ 9 10 11]]
Input scalars (shape (4,)):
[100. 200. 300. 400.]
Result (shape (4, 3)):
[[100. 101. 102.]
[203. 204. 205.]
[306. 307. 308.]
[409. 410. 411.]]
Notice that JAX automatically handled the broadcasting of the scalar scalars[i] across the elements of vectors[i] within each mapped function call.
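As an aside, when every argument is mapped along the same axis, in_axes also accepts a single integer that applies to all arguments, so in_axes=0 is shorthand for (0, 0) here:
# Single integer: map axis 0 of every argument
vectorized_shorthand = jax.vmap(add_scalar, in_axes=0)
print(jnp.allclose(vectorized_shorthand(vectors, scalars), result_batch))  # True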
You can also map over axes other than 0. For instance, in_axes=(1, None) would map over axis 1 of the first argument. This requires the shapes to align correctly. The size of the mapped axes across all mapped input arguments must be the same. JAX will raise an error if they don't match.
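To see that size check in action, here is a minimal sketch that deliberately mismatches the batch sizes (4 vectors versus 3 scalars):
mismatched_scalars = jnp.array([1., 2., 3.])  # batch of 3; vectors has a batch of 4
try:
    jax.vmap(add_scalar, in_axes=(0, 0))(vectors, mismatched_scalars)
except ValueError as err:
    print("vmap error:", err)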
# Example: Map over axis 1 of vectors
vectors_T = vectors.T # Shape (3, 4)
# Map over axis 1 (columns) of vectors_T, broadcast scalar_val
vectorized_add_scalar_axis1 = jax.vmap(add_scalar, in_axes=(1, None))
result_axis1 = vectorized_add_scalar_axis1(vectors_T, scalar_val)
print("Input vectors_T (shape {}):\n{}".format(vectors_T.shape, vectors_T))
print("Input scalar:", scalar_val)
# The output batch dimension (size 4) will be axis 0 by default
print("Result (shape {}):\n{}".format(result_axis1.shape, result_axis1))
Input vectors_T (shape (3, 4)):
[[ 0 3 6 9]
[ 1 4 7 10]
[ 2 5 8 11]]
Input scalar: 100.0
Result (shape (4, 3)):
[[100. 101. 102.]
[103. 104. 105.]
[106. 107. 108.]
[109. 110. 111.]]
Even though we mapped over axis 1 of the input vectors_T, the resulting batch dimension in the output is axis 0 by default. We can control this using out_axes.
out_axes

By default, vmap stacks the results along axis 0. The out_axes argument allows you to specify which axis in the output should correspond to the mapped dimension.
Let's consider a function that processes a vector and returns a transformed vector:
def process_vector(v):
# Example: Double the vector
return v * 2
input_vectors = jnp.arange(12).reshape(4, 3) # Batch of 4 vectors
Using the default out_axes=0:
# Default: map input axis 0 to output axis 0
vectorized_process_default = jax.vmap(process_vector, in_axes=0, out_axes=0)
result_default = vectorized_process_default(input_vectors)
print("Input vectors (shape {}):\n{}".format(input_vectors.shape, input_vectors))
print("Result (out_axes=0, shape {}):\n{}".format(result_default.shape, result_default))
Input vectors (shape (4, 3)):
[[ 0 1 2]
[ 3 4 5]
[ 6 7 8]
[ 9 10 11]]
Result (out_axes=0, shape (4, 3)):
[[ 0 2 4]
[ 6 8 10]
[12 14 16]
[18 20 22]]
The output shape is (4, 3), where 4 is the batch dimension placed at axis 0.
Now, let's specify out_axes=1:
# Map input axis 0 to output axis 1
vectorized_process_out1 = jax.vmap(process_vector, in_axes=0, out_axes=1)
result_out1 = vectorized_process_out1(input_vectors)
print("Input vectors (shape {}):\n{}".format(input_vectors.shape, input_vectors))
print("Result (out_axes=1, shape {}):\n{}".format(result_out1.shape, result_out1))
Input vectors (shape (4, 3)):
[[ 0 1 2]
[ 3 4 5]
[ 6 7 8]
[ 9 10 11]]
Result (out_axes=1, shape (3, 4)):
[[ 0 6 12 18]
[ 2 8 14 20]
[ 4 10 16 22]]
The output shape is now (3, 4). The original vector dimension (size 3) is now axis 0, and the mapped batch dimension (size 4) has been placed at axis 1.
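Since only the batch axis moves, the out_axes=1 result is simply the out_axes=0 result with its two axes swapped, which we can verify:
# Moving the batch axis of a 2D result is the same as transposing
print(jnp.array_equal(result_out1, result_default.T))  # True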
out_axes with Multiple Return Values (PyTrees)

If your function returns multiple values (e.g., in a tuple or dictionary, which JAX calls PyTrees), out_axes can also be a PyTree structure matching the output. This allows you to specify different output axes for different return values.
def process_vector_pytree(v):
# Returns a dictionary with sum and doubled vector
return {'sum': v.sum(), 'doubled': v * 2}
# Map input axis 0. Place 'sum' batch axis at 0, 'doubled' batch axis at 1.
vectorized_pytree = jax.vmap(
process_vector_pytree,
in_axes=0,
out_axes={'sum': 0, 'doubled': 1}
)
result_pytree = vectorized_pytree(input_vectors)
print("Input vectors (shape {}):\n{}".format(input_vectors.shape, input_vectors))
print("Result PyTree:")
print(" Sum (shape {}):\n{}".format(result_pytree['sum'].shape, result_pytree['sum']))
print(" Doubled (shape {}):\n{}".format(result_pytree['doubled'].shape, result_pytree['doubled']))
Input vectors (shape (4, 3)):
[[ 0 1 2]
[ 3 4 5]
[ 6 7 8]
[ 9 10 11]]
Result PyTree:
Sum (shape (4,)):
[ 3 12 21 30]
Doubled (shape (3, 4)):
[[ 0 6 12 18]
[ 2 8 14 20]
[ 4 10 16 22]]
Here, the batch of sums has shape (4,) (batch axis 0), while the batch of doubled vectors has shape (3, 4) (batch axis 1), exactly as specified in out_axes.
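The same pytree flexibility works on the input side: an entry of in_axes can mirror the structure of the corresponding argument. Here is a minimal sketch, assuming a function that takes a single dictionary argument (scale_and_shift and its keys are illustrative names, not part of the example above):
def scale_and_shift(inputs):
    # 'x' carries the batch; 'scale' and 'shift' are shared across it
    return inputs['x'] * inputs['scale'] + inputs['shift']
# Map axis 0 of inputs['x']; broadcast 'scale' and 'shift'
batched_scale_and_shift = jax.vmap(
    scale_and_shift,
    in_axes=({'x': 0, 'scale': None, 'shift': None},),
)
out = batched_scale_and_shift(
    {'x': jnp.arange(12.).reshape(4, 3), 'scale': 2.0, 'shift': 1.0}
)
print(out.shape)  # (4, 3)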
Combining in_axes and out_axes

You frequently use in_axes and out_axes together to precisely control the vectorization process. This combination provides the flexibility needed to adapt functions expecting single inputs to complex batching scenarios without rewriting the core logic or resorting to manual dimension shuffling. By understanding how to specify which input axes to map and where the resulting batch dimension should appear in the output, you can write cleaner and often more efficient JAX code for batched computations.
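As a final sketch combining both arguments (the affine function, its shapes, and the column-wise data layout are illustrative assumptions), suppose each example is stored as a column of the input and we want the output batch to stay on axis 1:
def affine(x, W, b):
    # Per-example computation: x has shape (3,), output has shape (2,)
    return W @ x + b
W = jnp.ones((2, 3))                     # shared weight matrix
b = jnp.zeros(2)                         # shared bias
x_batch = jnp.arange(12.).reshape(3, 4)  # 4 examples stored as columns
# Map over axis 1 of x, broadcast W and b, and place the batch on axis 1 of the output
batched_affine = jax.vmap(affine, in_axes=(1, None, None), out_axes=1)
out = batched_affine(x_batch, W, b)
print(out.shape)  # (2, 4): features on axis 0, batch on axis 1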