As mentioned in the chapter introduction, applying the same operation to many different inputs is a frequent requirement in scientific computing and machine learning. Think about processing a batch of images, running simulations with different parameters, or calculating loss contributions for multiple examples in a training set.
A direct way to handle this in Python is to use a standard for loop. For instance, if you have a function predict(image) that works on a single image, you might process a batch like this:
def predict_batch_loop(images):
    results = []
    for image in images:
        results.append(predict(image))
    return results  # Or stack them into an array
While straightforward, this approach often suffers from performance bottlenecks. Each iteration of the Python loop carries overhead, and the operations inside might not efficiently utilize the parallel processing capabilities of modern hardware like GPUs or TPUs. Python's interpreted nature can make loops significantly slower than operations executed by optimized, compiled code.
Vectorization offers a more efficient alternative. At its core, vectorization is the technique of adapting scalar operations (or operations on small, fixed-size inputs) to work concurrently on entire arrays or batches of data. Instead of processing elements one by one in a loop, a vectorized operation performs the computation across multiple data points simultaneously, at least from the perspective of the high-level code.
Consider adding a constant value to a list of numbers. The looping approach iterates:
# Looping approach
numbers = [1, 2, 3, 4, 5]
results = []
for x in numbers:
    results.append(x + 10)
# results: [11, 12, 13, 14, 15]
A vectorized approach, typical in libraries like NumPy, looks like this:
# Vectorized approach (NumPy example)
import numpy as np
numbers = np.array([1, 2, 3, 4, 5])
results = numbers + 10
# results: array([11, 12, 13, 14, 15])
Here, numbers + 10 applies the addition across the entire NumPy array without an explicit Python loop. Under the hood, NumPy (and similar libraries) delegates this operation to highly optimized, pre-compiled code (often written in C or Fortran) that can process large chunks of data much faster than the Python interpreter.
Diagram: the difference between processing items sequentially in a loop versus simultaneously in a vectorized manner.
The primary motivation for vectorization is performance: delegating the work to optimized, compiled routines avoids per-element interpreter overhead and makes better use of parallel hardware.
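To get a feel for the difference, here is a minimal timing sketch; the exact numbers depend on your machine, and time.perf_counter is used here only for a rough comparison, not rigorous benchmarking:

# Rough comparison of the loop and vectorized approaches on a larger array
import time
import numpy as np

data = np.arange(1_000_000)

# Python loop: every element passes through the interpreter individually
start = time.perf_counter()
loop_result = [x + 10 for x in data]
loop_time = time.perf_counter() - start

# Vectorized: the whole array is handled by NumPy's compiled addition kernel
start = time.perf_counter()
vec_result = data + 10
vec_time = time.perf_counter() - start

print(f"loop: {loop_time:.4f}s  vectorized: {vec_time:.4f}s")

On a typical machine the vectorized line finishes far faster than the loop, which is exactly the interpreter-overhead gap described above.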
Beyond speed, vectorized code can often be more concise and readable. The numbers + 10 example is arguably simpler and expresses the intent more directly than the explicit for loop.
While libraries like NumPy provide functions that are vectorized (np.add, np.sin, etc.), they often require you to structure your input data and write your code in a way that naturally aligns with these functions (e.g., ensuring array dimensions are compatible for broadcasting).
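As a small illustration of what aligning with broadcasting means, the sketch below adds a per-feature bias vector to a batch of feature vectors; the shapes are made up for the example, and the addition works only because the trailing dimensions match:

import numpy as np

batch = np.ones((4, 3))            # 4 examples, 3 features each
bias = np.array([0.1, 0.2, 0.3])   # one value per feature, shape (3,)

# Broadcasting stretches bias across the batch dimension: result shape (4, 3)
shifted = batch + bias

# By contrast, a shape-(4,) array would not align with the last axis of batch,
# so this addition would raise a broadcasting error:
# batch + np.array([1.0, 2.0, 3.0, 4.0])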
JAX takes this a step further with jax.vmap. The idea behind vmap is to let you write your function logic as if it operates on a single data point (like predict(image)). Then you can use vmap to automatically transform that function into a new function that efficiently handles batches of data, without manually rewriting the core logic for batch processing or managing loops. It essentially adds the batch dimension automatically, mapping the computation across slices of your input arrays.
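As a brief sketch of this idea, the example below defines a toy single-example function (a hypothetical stand-in for predict(image), not a real model) and lets jax.vmap map it over a batch; by default, vmap maps over the leading axis of each input:

import jax
import jax.numpy as jnp

# A function written for a single input vector (hypothetical example)
def predict_one(x):
    weights = jnp.array([0.5, -1.0, 2.0])  # fixed toy weights
    return jnp.dot(weights, x)             # scalar output for one example

# vmap turns it into a function that handles a whole batch at once
predict_batch = jax.vmap(predict_one)

batch = jnp.ones((8, 3))          # 8 examples, each of shape (3,)
outputs = predict_batch(batch)    # shape (8,): one prediction per example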
Understanding this concept of vectorization is fundamental to appreciating how vmap works and why it's such a useful tool in the JAX ecosystem for simplifying code and boosting performance. In the following sections, we'll see how to put jax.vmap into practice.