Often in numerical computation and machine learning, you need to apply the same function to multiple data points simultaneously, a process commonly referred to as batching. While you could write explicit loops or rely on functions designed for batch operations, these approaches can sometimes be verbose or require careful manual management of array dimensions.
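To make the two manual options concrete, here is a minimal sketch. It assumes a hypothetical per-example function `predict(w, b, x)` (not from this chapter) that computes a dot product plus a bias for one input vector. The Python loop keeps `predict` unchanged but makes one call per example. The hand-batched version runs as a single array operation, but it requires rewriting the function around the batch axis.

```python
import jax.numpy as jnp

# Hypothetical per-example function: weight vector (3,), scalar bias, input (3,) -> scalar.
def predict(w, b, x):
    return jnp.dot(w, x) + b

w = jnp.array([1.0, 2.0, 3.0])
b = 0.5
xs = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 inputs with 3 features each

# Option 1: explicit Python loop, one call per example, then stack the results.
looped = jnp.stack([predict(w, b, x) for x in xs])

# Option 2: rewrite the function by hand for a batch. The batch axis is now
# baked into the code (xs @ w instead of dot(w, x)).
def predict_batched(w, b, xs):
    return xs @ w + b

manual = predict_batched(w, b, xs)
print(looped)
print(manual)
```

Both produce the same four results. The cost of Option 2 is that every such function needs a second, batch-aware version, which is exactly the bookkeeping `jax.vmap` is meant to remove.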
JAX offers a transformation, `jax.vmap`, designed specifically for automatic vectorization. It takes a function written to operate on a single data point and applies it efficiently across an entire batch (or multiple batches) of data, often without needing to rewrite the original function's logic. In effect, `vmap` adds a batch dimension to your computation automatically. Rather than calling the function once per example in a Python loop, it pushes the batch axis through each operation inside the function, so the whole batch is processed as one vectorized computation.
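As a minimal sketch of the idea, assuming a hypothetical function `squared_norm` written for one point of shape `(3,)`, wrapping it in `jax.vmap` is enough to apply it to a `(4, 3)` batch. By default, `vmap` maps over axis 0 of the input, and the function body never mentions the batch dimension.

```python
import jax
import jax.numpy as jnp

# Written for ONE data point: x has shape (3,), the result is a scalar.
def squared_norm(x):
    return jnp.sum(x ** 2)

xs = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 points

# vmap maps squared_norm over axis 0 of xs (the default, in_axes=0).
batched = jax.vmap(squared_norm)
out = batched(xs)
print(out.shape)  # (4,)

# Same values as an explicit loop, without writing one.
looped = jnp.stack([squared_norm(x) for x in xs])
```

The vectorized and looped results agree. The difference is that `batched` is a single function you can pass around, compile, or transform further, while the loop has to be rewritten at every call site.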
In this chapter, you will learn:
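- How to use `jax.vmap` to vectorize functions operating on single or multiple arguments.
- How to control which axes are mapped using the `in_axes` and `out_axes` arguments.
- How to handle multiple batched arguments and nest `vmap` calls.
- How `vmap` interacts with other JAX transformations like `jit` and `grad`.
- Performance considerations for using `vmap` effectively.

By the end of this chapter, you'll be able to use `vmap` to simplify and often accelerate your batch processing code in JAX.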
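As a compact preview of the features listed above, here is a minimal sketch built around a toy per-example squared-error loss, `loss(w, x, y)`. This is a hypothetical function used only for illustration. The sketch shows four things. `in_axes=(None, 0, 0)` shares the weights while mapping over examples. `jax.jit(jax.vmap(jax.grad(loss), ...))` composes three transformations to produce per-example gradients. `out_axes=1` moves the batch axis to the end of the output. Two nested `vmap` calls compute every pairwise dot product between two sets of vectors.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example loss for a linear model (illustrative only).
def loss(w, x, y):
    pred = jnp.dot(w, x)
    return (pred - y) ** 2

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # 3 examples, 2 features
ys = jnp.array([0.0, 1.0, 2.0])

# in_axes: do not map w (None); map over axis 0 of xs and ys.
per_example_loss = jax.vmap(loss, in_axes=(None, 0, 0))
losses = per_example_loss(w, xs, ys)  # shape (3,)

# Composition: grad w.r.t. w for one example, vmapped over the batch, then jitted.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))
g = per_example_grads(w, xs, ys)  # shape (3, 2): one gradient per example

# out_axes=1 stacks the per-example results along the last axis instead.
g_t = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0), out_axes=1)(w, xs, ys)  # (2, 3)

# Nesting: the inner vmap maps over rows of B, the outer over rows of A,
# so pairwise(A, B)[i, j] == dot(A[i], B[j]), i.e. A @ B.T.
inner = jax.vmap(jnp.dot, in_axes=(None, 0))
pairwise = jax.vmap(inner, in_axes=(0, None))

A = jnp.arange(6.0).reshape(3, 2)
B = jnp.arange(8.0).reshape(4, 2)
P = pairwise(A, B)  # shape (3, 4)
print(losses.shape, g.shape, g_t.shape, P.shape)
```

The order of composition matters here: `grad` is applied to the single-example `loss` first, so `vmap` batches a gradient computation rather than differentiating a batched sum.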
4.1 The Concept of Vectorization
4.2 Introducing `jax.vmap`
4.3 Mapping over Specific Arguments (`in_axes`, `out_axes`)
4.4 Handling Multiple Batched Arguments
4.5 Nesting `vmap`
4.6 Combining `vmap` with `jit` and `grad`
4.7 Performance Considerations for `vmap`
4.8 Hands-on Practical: Vectorizing Functions
© 2026 ApX Machine Learning