While JAX provides a familiar NumPy-like API, standard Python execution can be slow for the intensive numerical computations common in scientific computing and machine learning, particularly when targeting accelerators like GPUs or TPUs. To achieve high performance, JAX relies on compilation.
This chapter focuses on `jax.jit`, the Just-In-Time compilation transformation. You will learn how `jit` converts Python functions into executables optimized for your hardware (CPU, GPU, or TPU), including how to apply `jax.jit` either as a decorator or as a function.
By the end of this chapter, you will be able to use `jax.jit` to speed up your JAX computations, often substantially.
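As a minimal sketch of the two usage styles, the snippet below applies `jax.jit` once as a decorator and once as a plain function call. The function names `selu` and `sum_of_squares` are hypothetical examples chosen for illustration, not part of any library. In both cases the compiled version produces the same values as the original Python function; compilation happens on the first call for a given input shape and dtype.

```python
import jax
import jax.numpy as jnp

# Style 1: jax.jit as a decorator (selu is a hypothetical example function).
@jax.jit
def selu(x, alpha=1.67, lmbda=1.05):
    # Elementwise: x for positive inputs, alpha * (exp(x) - 1) otherwise, then scaled.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# Style 2: jax.jit as a function wrapping an existing Python function.
def sum_of_squares(x):
    return jnp.sum(x ** 2)

sum_of_squares_jit = jax.jit(sum_of_squares)

x = jnp.arange(5.0)           # [0., 1., 2., 3., 4.]
print(selu(x))                # first call traces and compiles selu for float32[5]
print(sum_of_squares_jit(x))  # same result as sum_of_squares(x), via compiled code
```

The decorator form is convenient when you own the function definition; the function-call form is useful when you want to keep both the original and the compiled version around, for example to compare them.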
2.1 The Need for Speed: Why Compile?
2.2 Introducing `jax.jit`
2.3 How JIT Works: Tracing and Compilation
2.4 Python Control Flow and `jit`
2.5 Static vs Traced Values
2.6 Common Challenges with `jit`
2.7 Hands-on Practical: Applying `jit`
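As a small preview of the tracing and static-value topics listed above, the sketch below shows two behaviors under stated assumptions: a Python `print` inside a jitted function runs only while JAX traces it (the first call for a given input shape and dtype), and Python control flow that depends on an argument's value needs that argument marked static via `static_argnums`. The functions `scaled_add` and `power` are hypothetical illustrations.

```python
from functools import partial

import jax
import jax.numpy as jnp

# During tracing, jit runs the Python body once with abstract tracer values.
@jax.jit
def scaled_add(x, y):
    print("tracing with x =", x)  # prints a tracer, and only at trace time
    return 2.0 * x + y

# A Python loop over range(n) needs n as a concrete int, so mark it static.
@partial(jax.jit, static_argnums=1)
def power(x, n):
    result = jnp.ones_like(x)
    for _ in range(n):  # unrolled at trace time because n is static
        result = result * x
    return result

x = jnp.array([1.0, 2.0, 3.0])
scaled_add(x, x)    # first call: traces and compiles, so the print runs
scaled_add(x, x)    # same shape/dtype: cached executable is reused, no print
print(power(x, 3))  # each distinct value of n triggers a separate compilation
```

Marking `n` static trades flexibility for recompilation: every new value of `n` produces a new compiled executable, which is cheap for a few values but costly if `n` changes on every call.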
© 2026 ApX Machine Learning