While JAX provides a familiar NumPy-like API, standard Python execution can be slow for the intensive numerical computations common in scientific computing and machine learning, particularly when targeting accelerators like GPUs or TPUs. To achieve high performance, JAX relies on compilation.
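To make "JAX relies on compilation" concrete, the sketch below uses the ahead-of-time API on a jitted function (`jax.jit(f).lower(...)` and `.compile()`, available in recent JAX releases). It traces a small NumPy-style function, lowers it to the program text JAX hands to the XLA compiler, and compiles it into an executable. The function `scaled_sum` is a hypothetical example, not part of the original text.

```python
import jax
import jax.numpy as jnp

def scaled_sum(x):
    # Hypothetical NumPy-style computation written with jax.numpy.
    return jnp.sum(jnp.sin(x) * 2.0)

x = jnp.arange(4.0)

# Trace the function for this input's shape and dtype, and lower it
# to the compiler's input representation.
lowered = jax.jit(scaled_sum).lower(x)

# Compile the lowered program into an executable for the default backend.
compiled = lowered.compile()

# Inspect the beginning of the lowered program text.
print(lowered.as_text()[:200])

# The compiled executable computes the same value as eager execution.
print(compiled(x), scaled_sum(x))
```

You rarely need to call `lower` and `compile` yourself; calling a `jax.jit`-wrapped function performs these steps automatically on first use.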
This chapter focuses on `jax.jit`, the Just-In-Time (JIT) compilation transformation. You will learn how `jit` converts Python functions into optimized executables tailored to your hardware. We will cover how to apply `jax.jit` as a decorator or as a function, how it traces and compiles your code, and the common challenges that arise when using `jit`.

By the end of this chapter, you will be able to use `jax.jit` to significantly speed up your JAX computations.
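The two ways of applying `jax.jit` can be sketched as follows. This is a minimal illustration, assuming hypothetical functions `selu` and `square_plus_one` and a hypothetical `trace_count` counter used only to observe when JAX traces the Python function. The first call with a given input shape and dtype traces and compiles; later calls with the same shape and dtype reuse the cached executable without re-running the Python body.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, used only to observe tracing

def selu(x, alpha=1.67, lmbda=1.05):
    global trace_count
    # Python side effect: executes only while JAX traces the function,
    # not when the cached compiled executable runs.
    trace_count += 1
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# Function form: wrap an existing Python function.
selu_jit = jax.jit(selu)

# Decorator form: equivalent to square_plus_one = jax.jit(square_plus_one).
@jax.jit
def square_plus_one(x):
    return x * x + 1.0

x = jnp.arange(5.0)
y1 = selu_jit(x)  # first call: trace, compile, run
y2 = selu_jit(x)  # same shape and dtype: reuse the cached executable

print(trace_count)          # 1: the Python body ran only during tracing
print(square_plus_one(x))   # values 1, 2, 5, 10, 17
```

Both forms produce the same kind of compiled callable; the decorator is convenient when you own the function definition, while the function form is useful for wrapping code you import from elsewhere.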
2.1 The Need for Speed: Why Compile?
2.2 Introducing `jax.jit`
2.3 How JIT Works: Tracing and Compilation
2.4 Python Control Flow and `jit`
2.5 Static vs Traced Values
2.6 Common Challenges with `jit`
2.7 Hands-on Practical: Applying `jit`
© 2025 ApX Machine Learning