Now that we've explored the concept of vectorization and the basic usage of jax.vmap, let's put this knowledge into practice. This hands-on section will guide you through several examples, demonstrating how to apply vmap in different scenarios, from simple function vectorization to combining it with other JAX transformations.
First, ensure you have JAX installed and imported:
import jax
import jax.numpy as jnp
import timeit
# Check available devices (CPU/GPU/TPU)
print(f"JAX devices: {jax.devices()}")
Let's start with a simple function designed to operate on a single input, say, a scalar value.
# Define a function that operates on a single scalar
def process_scalar(x):
    return x * 2.0 + 1.0
# Test with a single scalar input
scalar_input = 5.0
scalar_output = process_scalar(scalar_input)
print(f"Input: {scalar_input}, Output: {scalar_output}")
# Now, let's create a batch of inputs
batch_inputs = jnp.arange(1.0, 6.0) # Array([1., 2., 3., 4., 5.])
print(f"Batch Inputs:\n{batch_inputs}")
Without vmap, you might manually loop through the batch:
# Manual loop (less efficient)
manual_outputs = []
for item in batch_inputs:
    manual_outputs.append(process_scalar(item))
manual_outputs = jnp.array(manual_outputs)
print(f"Manual Loop Outputs:\n{manual_outputs}")
This works, but Python loops are often slow, especially for large batches or complex functions running on accelerators. Now, let's use vmap.
# Use vmap to create a vectorized version of the function
vectorized_process = jax.vmap(process_scalar)
# Apply the vectorized function to the batch
vmap_outputs = vectorized_process(batch_inputs)
print(f"vmap Outputs:\n{vmap_outputs}")
# Verify the shapes
print(f"Batch Input Shape: {batch_inputs.shape}")
print(f"vmap Output Shape: {vmap_outputs.shape}")
Notice how vmap automatically handled the batch dimension. We wrote process_scalar to work on a single value, but vmap(process_scalar) works seamlessly on an array of values. By default, vmap assumes the function should be mapped over the first axis (axis 0) of the input(s) and produces outputs where the first axis corresponds to the mapped dimension.
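To see the speed difference for yourself, you can time the manual loop against a jitted, vectorized version. The sketch below is illustrative only (the batch size and helper names are made up, and absolute timings depend on your hardware and backend):
# Rough timing sketch (illustrative; numbers depend on your hardware and backend)
big_batch = jnp.arange(1_000, dtype=jnp.float32)

def loop_version(xs):
    # Dispatches one small operation per element
    return jnp.array([process_scalar(x) for x in xs])

jit_vectorized = jax.jit(jax.vmap(process_scalar))
_ = jit_vectorized(big_batch).block_until_ready()  # Warm-up so compilation is not timed

loop_time = timeit.timeit(lambda: loop_version(big_batch), number=1)
vmap_time = timeit.timeit(lambda: jit_vectorized(big_batch).block_until_ready(), number=1)
print(f"Python loop: {loop_time:.4f}s | jit(vmap(...)): {vmap_time:.6f}s")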
in_axes
What if your function takes multiple arguments, and you only want to vectorize over some of them? This is where the in_axes argument comes in. in_axes is typically a tuple specifying which axis of each input argument should be mapped over. A value of None means the corresponding argument should be broadcast, not mapped.
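As a minimal sketch before the fuller example below (add_offset is a made-up helper, not part of the earlier code), mapping over one argument while broadcasting another looks like this:
# Minimal sketch with a made-up helper: map over axis 0 of 'x', broadcast 'offset'
def add_offset(x, offset):
    return x + offset

xs = jnp.arange(3.0)  # Shape (3,)
print(jax.vmap(add_offset, in_axes=(0, None))(xs, 10.0))  # [10. 11. 12.]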
Let's define a function that scales and shifts a vector:
# Function operating on a vector, a scalar scale, and a scalar shift
def scale_and_shift(vector, scale, shift):
    return vector * scale + shift
# Example single inputs
single_vector = jnp.array([1.0, 2.0, 3.0])
single_scale = 2.0
single_shift = 10.0
# Apply the function to single inputs
single_output = scale_and_shift(single_vector, single_scale, single_shift)
print(f"Single Output:\n{single_output}")
# Now, create batches
batch_vectors = jnp.array([[1.0, 2.0, 3.0],
                           [4.0, 5.0, 6.0],
                           [7.0, 8.0, 9.0]])  # Shape (3, 3)
batch_scales = jnp.array([0.5, 1.0, 1.5]) # Shape (3,)
# We want to use the *same* shift for all items in the batch
fixed_shift = 100.0
We want to apply scale_and_shift such that:

- each row of batch_vectors is processed with the corresponding scale in batch_scales, and
- every row uses the same fixed_shift.

We specify this using in_axes=(0, 0, None). This tells vmap:

- map over axis 0 of the first argument (batch_vectors),
- map over axis 0 of the second argument (batch_scales),
- do not map over the third argument (fixed_shift); broadcast it instead.
# Vectorize using in_axes
# Map over axis 0 for 'vector' and 'scale', broadcast 'shift'
vectorized_scale_shift = jax.vmap(scale_and_shift, in_axes=(0, 0, None))
# Apply the vectorized function
batch_outputs = vectorized_scale_shift(batch_vectors, batch_scales, fixed_shift)
print(f"\nBatch Vectors Shape: {batch_vectors.shape}")
print(f"Batch Scales Shape: {batch_scales.shape}")
# fixed_shift is a scalar
print(f"Batch Outputs:\n{batch_outputs}")
print(f"Batch Output Shape: {batch_outputs.shape}")
The output has shape (3, 3), where the first 3 is the batch dimension introduced by vmap, and the second 3 comes from the original shape of the vector processed by scale_and_shift.
You can also specify different axes. For example, if the batch dimension were the last axis of an input array, you would pass -1 for that argument's position in in_axes.
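For instance, a short sketch along these lines (column_norm and cols are hypothetical, introduced only for illustration) maps over the last axis of its input:
# Hypothetical example: the batch dimension is the last axis of the input
def column_norm(col):
    # 'col' has shape (5,) inside vmap
    return jnp.sqrt(jnp.sum(col ** 2))

cols = jnp.ones((5, 3))  # Three columns of length 5; batch axis is axis -1
norms = jax.vmap(column_norm, in_axes=-1)(cols)
print(norms.shape)  # (3,)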
vmap with jit and grad
One of the significant advantages of JAX is the composability of its transformations. Let's see how vmap works with jit (for compilation) and grad (for differentiation).
Consider a simple function representing a basic computation, perhaps part of a machine learning model update:
# Function: Computes element-wise tanh activation
def simple_activation(params, x):
    # params could be weights or biases, here just a scalar scale
    return jnp.tanh(params * x)
# Example single inputs
single_x = jnp.linspace(-2.0, 2.0, 5) # Shape (5,)
single_param = 1.5
# Calculate output for single input
single_output = simple_activation(single_param, single_x)
print(f"Single Input x:\n{single_x}")
print(f"Single Output:\n{single_output}")
# --- Using jit ---
# Compile the function
jit_activation = jax.jit(simple_activation)
jit_output = jit_activation(single_param, single_x)
# Ensure results match (ignoring potential float precision differences)
assert jnp.allclose(single_output, jit_output)
print("\nJIT output matches original.")
# --- Using grad ---
# Get the gradient function w.r.t the first argument (params)
grad_activation = jax.grad(simple_activation, argnums=0)
# Compute the gradient for the single input
single_grad = grad_activation(single_param, single_x)
print(f"Gradient w.r.t params (single input):\n{single_grad}")
# --- Combining vmap with jit and grad ---
# Now, create batched inputs for x, keeping params fixed
batch_x = jnp.stack([single_x, single_x * 0.5, single_x * -1.0]) # Shape (3, 5)
print(f"\nBatch Input x Shape: {batch_x.shape}")
# Option 1: vmap the original function, then jit
vectorized_activation = jax.vmap(simple_activation, in_axes=(None, 0)) # Broadcast param, map x
jit_vectorized_activation = jax.jit(vectorized_activation)
batch_output_1 = jit_vectorized_activation(single_param, batch_x)
print(f"Batch Output Shape (vmap -> jit): {batch_output_1.shape}")
# Option 2: jit the original function, then vmap
# (Both orders produce the same result; in practice, jit is often applied
#  outermost so the whole batched computation is compiled as one unit)
vectorized_jit_activation = jax.vmap(jit_activation, in_axes=(None, 0))
batch_output_2 = vectorized_jit_activation(single_param, batch_x)
print(f"Batch Output Shape (jit -> vmap): {batch_output_2.shape}")
assert jnp.allclose(batch_output_1, batch_output_2)
print("Outputs from both composition orders match.")
# Now, let's get batched gradients
# We want the gradient w.r.t 'params' for each item in the 'batch_x'
# We apply vmap *after* grad
# grad returns the gradient structure matching 'params' (a scalar here)
# vmap adds a batch dimension corresponding to the mapped input 'x'
vectorized_grad = jax.vmap(grad_activation, in_axes=(None, 0)) # Broadcast param, map x
batch_grads = vectorized_grad(single_param, batch_x)
print(f"\nBatch Gradients w.r.t params:\n{batch_grads}")
print(f"Batch Gradients Shape: {batch_grads.shape}")
# We can also JIT the vectorized gradient function for performance
jit_vectorized_grad = jax.jit(vectorized_grad)
# Warm up once so the timing below measures execution, not compilation
_ = jit_vectorized_grad(single_param, batch_x).block_until_ready()
start_time = timeit.default_timer()
batch_grads_jit = jit_vectorized_grad(single_param, batch_x).block_until_ready()
duration = timeit.default_timer() - start_time
print(f"JITted Batch Gradient Calculation Time: {duration:.6f}s")
assert jnp.allclose(batch_grads, batch_grads_jit)
This example shows how smoothly vmap integrates. You can vmap a function and jit the result, or vmap a grad-transformed function. This composition allows you to write clear, single-instance logic and then efficiently apply it to batches using vmap, differentiate it using grad, and compile it using jit.
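As a compact recap, the per-example gradient pipeline we built step by step can be expressed in a single composed line (this reuses simple_activation, single_param, and batch_x from above):
# The same per-example gradients, composed in a single expression
per_example_grads = jax.jit(jax.vmap(jax.grad(simple_activation), in_axes=(None, 0)))
assert jnp.allclose(per_example_grads(single_param, batch_x), batch_grads)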
Try applying vmap to solve the following:
Given two sets of 2D points:
points_a = jnp.array([[0, 0], [1, 1]])
points_b = jnp.array([[2, 2], [3, 3], [4, 4]])
Write a function pairwise_distance(a, b) that calculates the Euclidean distance between a single point a and a single point b. Then, use vmap (possibly nested) to compute a matrix where element (i, j) is the distance between points_a[i] and points_b[j]. The expected output shape is (2, 3).
Hint: You might need one vmap to handle iterating through points_a and another (nested) vmap to handle iterating through points_b for each point in a. Think about the in_axes carefully for each level.
# --- Solution Sketch ---
def euclidean_distance(p1, p2):
    # Calculates distance between two points (shape (2,))
    return jnp.sqrt(jnp.sum((p1 - p2)**2))
points_a = jnp.array([[0., 0.], [1., 1.]]) # Shape (2, 2)
points_b = jnp.array([[2., 2.], [3., 3.], [4., 4.]]) # Shape (3, 2)
# Goal: Compute a (2, 3) matrix of distances dist(points_a[i], points_b[j])
# Hint 1: Vectorize over points_b first for a fixed point_a
# vmap_over_b = jax.vmap(euclidean_distance, in_axes=(None, 0))
# Try calling vmap_over_b(points_a[0], points_b) -> Expected shape (3,)
# Hint 2: Now vectorize the previous step over points_a
# vmap_over_a_and_b = jax.vmap(vmap_over_b, in_axes=(0, None))
# Try calling vmap_over_a_and_b(points_a, points_b) -> Expected shape (2, 3)
# --- Complete Solution ---
# Compute pairwise distances using nested vmap
@jax.jit # JIT the final computation for efficiency
def compute_pairwise_distances(arr_a, arr_b):
    # vmap for inner loop (iterate through b for a fixed a)
    vmap_dist_b = jax.vmap(euclidean_distance, in_axes=(None, 0))
    # vmap for outer loop (iterate through a, applying vmap_dist_b to each)
    vmap_dist_a_b = jax.vmap(vmap_dist_b, in_axes=(0, None))
    return vmap_dist_a_b(arr_a, arr_b)
distance_matrix = compute_pairwise_distances(points_a, points_b)
print("\n--- Pairwise Distance Problem ---")
print(f"Points A:\n{points_a}")
print(f"Points B:\n{points_b}")
print(f"Pairwise Distance Matrix:\n{distance_matrix}")
print(f"Output shape: {distance_matrix.shape}")
# Expected output:
# [[2.828427 4.2426405 5.656854 ]
# [1.4142135 2.828427 4.2426405]]
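If you want an extra check (not required by the exercise), the same matrix can be computed with plain broadcasting and compared against the nested-vmap result:
# Optional sanity check: compute the same matrix with direct broadcasting
diff = points_a[:, None, :] - points_b[None, :, :]           # Shape (2, 3, 2)
broadcast_distances = jnp.sqrt(jnp.sum(diff ** 2, axis=-1))   # Shape (2, 3)
assert jnp.allclose(distance_matrix, broadcast_distances)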
These practical examples illustrate the flexibility and utility of jax.vmap. By transforming functions written for single data points into functions that operate efficiently on batches, vmap simplifies code, often improves performance, and integrates well with JAX's other transformations like jit and grad. As you work more with JAX, especially in machine learning contexts, vmap will become an indispensable tool for handling batched data.