Often, the functions you want to vectorize operate on more than one input. For instance, you might need to compute the dot product between corresponding vectors in two different batches, or multiply every vector in a batch by the same matrix. jax.vmap provides fine-grained control over how batching is applied to each argument through its in_axes parameter.
The in_axes Argument: Specifying What to Map

By default (in_axes=0), vmap maps over the first axis (axis 0) of every argument passed to the function. If some arguments should not be batched, or if you want to map over an axis other than 0, you need to set in_axes explicitly.
For a function with several positional arguments, in_axes is usually given as a tuple (or list) whose length matches the number of positional arguments of the function being vectorized. A single integer is also accepted and applies to every argument. Each element of the in_axes tuple specifies how vmap should handle the corresponding positional argument:

- An integer i: This tells vmap to map over axis i of the corresponding argument. The size of this axis becomes the batch dimension. If multiple arguments have integer in_axes entries, the sizes of their specified axes must match.
- None: This tells vmap not to map over the corresponding argument. Instead, the argument is treated as a constant and the same value is passed to every mapped call.

Let's explore common scenarios.
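As a quick check of these rules, the sketch below compares the default behavior with an explicit in_axes=(0, 0) and shows what happens when the mapped axes disagree in size. The helper add_pair and the array names are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def add_pair(a, b):
    """Adds two vectors (hypothetical helper for illustration)."""
    return a + b

xs = jnp.ones((4, 3))
ys = jnp.arange(12.0).reshape((4, 3))

# With no in_axes given, vmap uses in_axes=0 for every argument.
default_out = jax.vmap(add_pair)(xs, ys)
explicit_out = jax.vmap(add_pair, in_axes=(0, 0))(xs, ys)
same = bool(jnp.array_equal(default_out, explicit_out))
print(f"Default matches in_axes=(0, 0): {same}")

# Mapped axes must agree in size: 4 rows against 5 rows is rejected.
try:
    jax.vmap(add_pair, in_axes=(0, 0))(xs, jnp.ones((5, 3)))
    mismatch_rejected = False
except ValueError:
    mismatch_rejected = True
print(f"Mismatched batch sizes rejected: {mismatch_rejected}")
```

Passing the tuple explicitly is optional here, but it documents the intent and makes it easy to switch one entry to None later.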
Imagine a function that scales a vector by a scalar factor:
import jax
import jax.numpy as jnp
def scale_vector(vector, scalar):
    """Scales a vector by a scalar."""
    return vector * scalar
# Example single inputs
vector = jnp.arange(3.)
scalar = 2.0
print(f"Single scale: {scale_vector(vector, scalar)}")
# Expected Output: Single scale: [0. 2. 4.]
Now, suppose you have a batch of vectors and you want to scale each vector by the same scalar. You want to map over the first argument (the batch of vectors) but keep the second argument (the scalar) constant. This is achieved using in_axes=(0, None).
# Batch of vectors (3 vectors of size 3)
batch_of_vectors = jnp.arange(9.).reshape((3, 3))
single_scalar = 2.0
# Vectorize the function: map over axis 0 of vectors, broadcast scalar
vectorized_scale = jax.vmap(scale_vector, in_axes=(0, None))
# Apply the vectorized function
batched_result = vectorized_scale(batch_of_vectors, single_scalar)
print("\nBatch of vectors:")
print(batch_of_vectors)
print(f"\nSingle scalar: {single_scalar}")
print("\nResult after vmap(scale_vector, in_axes=(0, None)):")
print(batched_result)
# Expected Output:
# Batch of vectors:
# [[0. 1. 2.]
# [3. 4. 5.]
# [6. 7. 8.]]
#
# Single scalar: 2.0
#
# Result after vmap(scale_vector, in_axes=(0, None)):
# [[ 0. 2. 4.]
# [ 6. 8. 10.]
# [12. 14. 16.]]
Here, in_axes=(0, None) instructs vmap as follows:

- For the first argument (batch_of_vectors), map over its axis 0. The size of this axis (3) determines the batch size.
- For the second argument (single_scalar), do not map (None). Treat it as fixed and reuse the same value in every mapped call.

A very common use case is applying an operation element-wise between two batches of data, for example adding corresponding vectors from two batches.
def add_vectors(vec1, vec2):
    """Adds two vectors."""
    return vec1 + vec2
# Two batches of vectors (each has 3 vectors of size 3)
batch1 = jnp.arange(9.).reshape((3, 3))
batch2 = jnp.ones((3, 3)) * 10
print("Batch 1:")
print(batch1)
print("\nBatch 2:")
print(batch2)
# Vectorize: map over axis 0 of BOTH arguments
vectorized_add = jax.vmap(add_vectors, in_axes=(0, 0))
# Apply the vectorized function
batched_sum = vectorized_add(batch1, batch2)
print("\nResult after vmap(add_vectors, in_axes=(0, 0)):")
print(batched_sum)
# Expected Output:
# Batch 1:
# [[0. 1. 2.]
# [3. 4. 5.]
# [6. 7. 8.]]
#
# Batch 2:
# [[10. 10. 10.]
# [10. 10. 10.]
# [10. 10. 10.]]
#
# Result after vmap(add_vectors, in_axes=(0, 0)):
# [[10. 11. 12.]
# [13. 14. 15.]
# [16. 17. 18.]]
With in_axes=(0, 0), vmap takes the k-th slice along axis 0 from batch1 and the k-th slice along axis 0 from batch2 and passes them as vec1 and vec2 to the original add_vectors function, for each k up to the batch size. The batch size is determined by the size of axis 0 in both inputs, which must be identical (in this case, 3).
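The slice-by-slice description above can be checked against an explicit Python loop. The following sketch uses the batched dot product mentioned at the start of this section; the function pair_dot and the arrays batch_a and batch_b are hypothetical names introduced only for this example.

```python
import jax
import jax.numpy as jnp

def pair_dot(v1, v2):
    """Dot product of two 1D vectors (hypothetical helper)."""
    return jnp.dot(v1, v2)

batch_a = jnp.arange(6.0).reshape((3, 2))  # rows: [0, 1], [2, 3], [4, 5]
batch_b = jnp.ones((3, 2)) * 2.0           # every row is [2, 2]

# Map over axis 0 of both arguments: one dot product per pair of rows.
batched = jax.vmap(pair_dot, in_axes=(0, 0))(batch_a, batch_b)

# Reference: what vmap computes conceptually, written as a loop over k.
looped = jnp.stack(
    [pair_dot(batch_a[k], batch_b[k]) for k in range(batch_a.shape[0])]
)

print(batched)                              # values: 2, 10, 18
print(bool(jnp.allclose(batched, looped)))  # the two approaches agree
```

The loop is only a mental model. vmap does not iterate in Python; it rewrites the function to operate on the whole batch at once.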
Visualizing in_axes

We can visualize how in_axes directs the mapping. Consider vmap(f, in_axes=(0, None, 0)) applied to f(x, y, z) with inputs xs, y_val, and zs.
Figure: vmap takes slices x₀, x₁, ... from xs (axis 0) and z₀, z₁, ... from zs (axis 0), while passing the same y_val to each parallel invocation of the function f. The results are then stacked to form the output batch.
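The same three-argument mapping can be written out as a small sketch. The body of f below is a hypothetical stand-in (the text does not define it), and the input values are arbitrary; only the in_axes pattern is the point.

```python
import jax
import jax.numpy as jnp

def f(x, y, z):
    """Hypothetical stand-in for f: scales x by y, then adds z."""
    return x * y + z

xs = jnp.arange(6.0).reshape((3, 2))  # batched along axis 0
y_val = 10.0                          # shared by every call
zs = jnp.ones((3, 2))                 # batched along axis 0

# Slice xs and zs along axis 0; pass the same y_val to each invocation.
out = jax.vmap(f, in_axes=(0, None, 0))(xs, y_val, zs)

# Manual equivalent: call f once per slice and stack the results.
manual = jnp.stack([f(xs[k], y_val, zs[k]) for k in range(xs.shape[0])])

print(out.shape)                        # (3, 2)
print(bool(jnp.allclose(out, manual)))  # True
```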
While less common, you can map different arguments along different axes by providing different integers in in_axes. For example, in_axes=(0, 1) would map over axis 0 of the first argument and axis 1 of the second argument. This requires careful consideration of the function's logic and the shapes of the input arrays.
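A compact way to see in_axes=(0, 1) in action is to pair each row of one matrix with the matching column of another. In this sketch the arrays A and B are hypothetical, and the comparison against the diagonal of A @ B is included only as a sanity check.

```python
import jax
import jax.numpy as jnp

A = jnp.arange(6.0).reshape((3, 2))  # 3 rows, each of length 2
B = jnp.arange(6.0).reshape((2, 3))  # 3 columns, each of length 2

# Row k of A (axis 0) is paired with column k of B (axis 1).
row_dot_col = jax.vmap(jnp.dot, in_axes=(0, 1))(A, B)

# Sanity check: these are exactly the diagonal entries of the matrix product.
reference = jnp.diag(A @ B)

print(row_dot_col)                                # values: 3, 14, 33
print(bool(jnp.allclose(row_dot_col, reference)))
```

Both mapped axes have size 3 here, which is what allows the rows and columns to be paired one-to-one.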
Consider a function applying a filter (1D vector) to each row of a matrix:
def apply_filter(row_vector, filter_vector):
    """Applies a filter (element-wise multiply). Assumes shapes match."""
    return row_vector * filter_vector
# A matrix (e.g., 3 rows, 4 columns)
matrix = jnp.arange(12.).reshape((3, 4))
# A single filter vector (size 4)
filter_v = jnp.array([1., 10., 100., 1000.])
print("Matrix:")
print(matrix)
print(f"\nFilter: {filter_v}")
# Vectorize: map over axis 0 (rows) of matrix, broadcast the filter
# This is the same pattern as the scale_vector example: in_axes=(0, None)
vectorized_apply_rows = jax.vmap(apply_filter, in_axes=(0, None))
result_rows = vectorized_apply_rows(matrix, filter_v)
print("\nApplying filter to each row (in_axes=(0, None)):")
print(result_rows)
# Expected Output:
# Applying filter to each row (in_axes=(0, None)):
# [[ 0. 10. 200. 3000.]
# [ 4. 50. 600. 7000.]
# [ 8. 90. 1000. 11000.]]
# Now, suppose we want to apply a filter to each *column*.
# The function expects a vector, so we need to think about which axis is sliced.
# Let's define filters for columns (size 3)
column_filters = jnp.array([1., 10., 100.])
# If we mapped over axis 1 (columns) of the matrix and axis 0 of the filters:
# `matrix` shape (3, 4), map axis 1 -> slices are shape (3,) - the columns
# `column_filters` shape (3,), map axis 0 -> slices are shape () - the scalars in the filter
# The mapped sizes (4 and 3) do not match, and each filter slice is a scalar
# rather than a vector, so this is not an element-wise product of column and filter.
# Let's redefine slightly: assume we have a batch of *column operations*,
# where each operation uses a different scalar multiplier for the whole column.
def scale_column(column_vector, scalar):
    """Scales a column vector by a scalar."""
    return column_vector * scalar

# Map over axis 1 (columns) of matrix, map over axis 0 of scalars
# Note: Number of columns (4) must match number of scalars (size of axis 0)
scalars = jnp.array([1., 10., 100., 1000.])
if matrix.shape[1] == scalars.shape[0]:
    vectorized_apply_cols = jax.vmap(scale_column, in_axes=(1, 0))
    result_cols = vectorized_apply_cols(matrix, scalars)
    print("\nScaling each column by a different scalar (in_axes=(1, 0)):")
    print(result_cols)
    # Expected Output:
    # Scaling each column by a different scalar (in_axes=(1, 0)):
    # [[    0.     4.     8.]
    #  [   10.    50.    90.]
    #  [  200.   600.  1000.]
    #  [ 3000.  7000. 11000.]]
    # Note: The result has shape (4, 3), one row per scaled column, because vmap
    # stacks the outputs along axis 0 by default. It is the transpose of the row example.
else:
    print("\nSkipping column example due to shape mismatch.")
In this last example, in_axes=(1, 0) tells vmap:

- For the matrix argument, iterate through slices along axis 1 (columns). Each slice has shape (3,).
- For the scalars argument, iterate through slices along axis 0. Each slice is a scalar.
- Pass each column and its corresponding scalar to scale_column. The per-column results are stacked along axis 0 of the output, giving shape (4, 3).

Mastering in_axes is essential for using vmap effectively. It lets you vectorize existing functions across batches of data with minimal code modification, even when multiple inputs require different batching or broadcasting treatment.
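One detail worth checking when mapping over a non-leading axis is where the batch dimension ends up in the output. The sketch below reuses the column-scaling setup and, as an addition not covered above, uses the out_axes parameter of jax.vmap to place the mapped axis back at position 1. The variable names stacked_first and stacked_as_columns are hypothetical.

```python
import jax
import jax.numpy as jnp

def scale_column(column_vector, scalar):
    """Scales a column vector by a scalar."""
    return column_vector * scalar

matrix = jnp.arange(12.0).reshape((3, 4))
scalars = jnp.array([1.0, 10.0, 100.0, 1000.0])

# Default out_axes=0: one output row per mapped column.
stacked_first = jax.vmap(scale_column, in_axes=(1, 0))(matrix, scalars)

# out_axes=1: the mapped dimension is placed back as axis 1 of the output.
stacked_as_columns = jax.vmap(
    scale_column, in_axes=(1, 0), out_axes=1
)(matrix, scalars)

print(stacked_first.shape)       # (4, 3)
print(stacked_as_columns.shape)  # (3, 4)
```

With out_axes=1 the result keeps the layout of the input matrix, which is often what is wanted when the mapped function transforms columns in place.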
© 2025 ApX Machine Learning