Before we examine JAX's advanced control flow mechanisms like lax.scan and lax.cond, let's briefly revisit the foundational transformations that form the bedrock of JAX programming: jax.jit, jax.grad, and jax.vmap. As this course targets developers with a solid JAX background, this review focuses on reinforcing the core concepts and operational characteristics crucial for understanding their interaction with more complex constructs. These transformations operate on functions, taking Python functions as input and returning new, transformed Python functions.
The jax.jit transformation accelerates your Python functions, particularly those involving numerical computations typical in machine learning, by compiling them using Google's Accelerated Linear Algebra (XLA) compiler.
The first time a jit-compiled function is called with specific input shapes and types, JAX performs tracing. During tracing, JAX executes the Python code not with actual numerical values, but with abstract tracer objects. These tracers record the sequence of primitive operations performed. This recorded sequence, known as the jaxpr (JAX expression), captures the computation graph. XLA then takes this jaxpr and compiles it into highly optimized machine code tailored for the target hardware (CPU, GPU, or TPU). Subsequent calls with matching input shapes and types directly execute this pre-compiled code, bypassing the Python interpreter overhead and benefiting from XLA optimizations like operator fusion.
A simplified view of the jax.jit compilation process.
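To make the tracing step concrete, jax.make_jaxpr exposes the jaxpr recorded for a function. A minimal sketch follows; the exact printed form may differ slightly between JAX versions.

import jax
import jax.numpy as jnp

# Inspect the jaxpr that tracing records for a tiny function
print(jax.make_jaxpr(lambda x: jnp.sin(x) * 2.0)(1.0))
# Roughly: { lambda ; a:f32[]. let b:f32[] = sin a; c:f32[] = mul b 2.0 in (c,) }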
This compilation process imposes a significant constraint: functions targeted by jit must be functionally pure with respect to the traced operations. This means they should not have side effects (like printing or modifying external state) that depend on the tracer values, as these side effects occur only once, during tracing. The function's output must depend solely on its explicit inputs.
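A minimal sketch of this trace-time behavior (the function g below is illustrative, not part of the running example):

@jax.jit
def g(x):
    print("tracing g")  # runs only while tracing, not on later calls
    return x + 1

g(jnp.array(1.0))  # prints "tracing g", then compiles and executes
g(jnp.array(2.0))  # no print: the cached compiled version is reused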
import jax
import jax.numpy as jnp
# A simple function
def slow_f(x):
    # Simulate some computation; abs keeps the log argument positive
    return jnp.sin(x) * x + jnp.log(jnp.abs(x) + 1)
# Apply jit
fast_f = jax.jit(slow_f)
# First call: Tracing and compilation happen
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1000, 1000))
result1 = fast_f(x)
result1.block_until_ready() # Ensure computation finishes for timing
# Second call: Uses compiled kernel (much faster)
result2 = fast_f(x)
result2.block_until_ready()
print(f"Results are close: {jnp.allclose(result1, result2)}")
# Output: Results are close: True
While extremely powerful, be mindful that changes in input array shapes or dtypes, or changes in the values of arguments marked as static, trigger recompilation. Python constants closed over by the function are baked in during tracing, so modifying them later does not automatically trigger a retrace. We'll explore strategies for minimizing recompilation in Chapter 2.
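A small, hedged sketch of what drives recompilation, reusing fast_f from above (the scale function is illustrative):

from functools import partial

fast_f(jnp.ones((10, 10)))  # new input shape -> fresh trace and compile
fast_f(jnp.ones((10, 10)))  # same shape and dtype -> cached executable is reused

# Values of static arguments are also part of the compilation cache key
@partial(jax.jit, static_argnums=1)
def scale(x, factor):
    return x * factor

scale(jnp.ones(3), 2.0)  # compiles for factor=2.0
scale(jnp.ones(3), 3.0)  # new static value -> recompiles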
Automatic differentiation is central to modern machine learning for optimizing model parameters. jax.grad provides JAX's primary reverse-mode automatic differentiation capability. Given a Python function that computes a scalar output, jax.grad returns a new function that computes the gradient of the original function with respect to one of its arguments (by default, the first).
def predict(params, x):
    w, b = params
    return jnp.dot(w, x) + b

def loss_fn(params, x, y_true):
    y_pred = predict(params, x)
    # Simple squared error
    return jnp.mean((y_pred - y_true)**2)
# Get the function that computes the gradient w.r.t. 'params' (arg 0)
grad_loss_fn = jax.grad(loss_fn, argnums=0)
# Example data
w_init = jnp.array([1.5, -0.5])
b_init = jnp.array(0.3)
params_init = (w_init, b_init)
x_data = jnp.array([0.2, 0.8])
y_target = jnp.array(2.5)
# Compute the gradients
gradients = grad_loss_fn(params_init, x_data, y_target)
print(f"Gradient w.r.t w: {gradients[0]}")
print(f"Gradient w.r.t b: {gradients[1]}")
# Output (approximately):
# Gradient w.r.t w: [-0.92 -3.68]
# Gradient w.r.t b: -4.6
Under the hood, jax.grad is built upon vector-Jacobian products (VJPs), which we will examine in detail in Chapter 4. Its compositional nature allows for easy computation of higher-order derivatives by applying grad multiple times. Remember that grad expects the differentiated function to return a single scalar value. For functions returning multiple values or non-scalar arrays, different techniques or combinations with vmap are needed to compute full Jacobians or Hessians.
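As a brief illustration of both points, here is a hedged sketch; the functions f and g are made up for demonstration:

# Higher-order derivatives: compose grad with itself
f = lambda x: jnp.sin(x) * x**2
df = jax.grad(f)    # first derivative
d2f = jax.grad(df)  # second derivative
print(df(1.0), d2f(1.0))

# Non-scalar outputs: use jax.jacrev (or jax.jacfwd) instead of grad
g = lambda x: jnp.array([x[0]**2, x[0] * x[1], jnp.sin(x[1])])
print(jax.jacrev(g)(jnp.array([1.0, 2.0])).shape)  # (3, 2)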
jax.vmap is a vectorizing map. It transforms a function written to operate on single data points into one that operates efficiently on batches or axes of data, without requiring manual loops in Python. This is achieved by adding a batch dimension and "mapping" the function over that dimension.
The core idea is to specify which input arguments have a batch dimension and how that dimension should be mapped. jax.vmap(fun, in_axes, out_axes) takes the function fun and arguments specifying the mapped axes:

in_axes: A tuple/list/pytree indicating which axis of each input argument should be mapped over. None means the argument is broadcast; 0 typically means mapping over the first axis.
out_axes: Specifies where the mapped axis should appear in the output.

# Function works on single vectors
def simple_affine(w, b, x):
    # w: matrix [out_dim, in_dim]
    # b: vector [out_dim]
    # x: vector [in_dim]
    return jnp.dot(w, x) + b
# Example parameters
w = jnp.array([[1., 2.], [3., 4.], [5., 6.]]) # Shape (3, 2)
b = jnp.array([0.1, 0.2, 0.3]) # Shape (3,)
# Batch of inputs (4 vectors of size 2)
x_batch = jnp.arange(8.).reshape(4, 2) # Shape (4, 2)
# Vectorize `simple_affine` over the batch dimension (axis 0) of `x_batch`.
# `w` and `b` are not mapped (broadcasted).
batched_affine = jax.vmap(simple_affine, in_axes=(None, None, 0))
# Apply the vectorized function
y_batch = batched_affine(w, b, x_batch)
print(f"Input batch shape: {x_batch.shape}")
print(f"Output batch shape: {y_batch.shape}")
# Output:
# Input batch shape: (4, 2)
# Output batch shape: (4, 3)
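The example above leaves out_axes at its default of 0. As a small sketch of the alternative, the same call with out_axes=1 places the batch dimension second in the output:

# Move the mapped axis to position 1 of the output
batched_affine_t = jax.vmap(simple_affine, in_axes=(None, None, 0), out_axes=1)
print(batched_affine_t(w, b, x_batch).shape)  # (3, 4) instead of (4, 3)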
vmap is essential for processing batches of data in ML, avoiding slow Python for loops and leveraging parallel hardware capabilities. It can be arbitrarily nested and composed with jit and grad, enabling sophisticated batched computations and gradient calculations.
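For example, nesting vmap maps over two batch dimensions at once; the pairwise_dots helper below is a made-up illustration:

# Nested vmap: all pairwise dot products between two batches of vectors
def dot(a, b):
    return jnp.dot(a, b)

# Inner vmap maps over rows of b; outer vmap maps over rows of a
pairwise_dots = jax.vmap(jax.vmap(dot, in_axes=(None, 0)), in_axes=(0, None))

a = jnp.ones((5, 3))
b = jnp.ones((7, 3))
print(pairwise_dots(a, b).shape)  # (5, 7)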
The true utility of JAX emerges when these transformations are composed. You can jit a gradient function (jit(grad(f))), vectorize a compiled function (vmap(jit(f))), or compute batched gradients (jit(vmap(grad(f)))), and so on. The order of composition matters and affects the resulting computation.
# Example: JIT-compiled, batched gradient computation
batched_grad_loss_fn = jax.jit(jax.vmap(grad_loss_fn, in_axes=(None, 0, 0)))
# Generate batch data
key, subkey = jax.random.split(key)
x_batch_data = jax.random.normal(subkey, (16, 2)) # Batch of 16
y_batch_target = jax.random.normal(key, (16,)) # Batch of 16
# Compute gradients for the entire batch efficiently
batch_gradients = batched_grad_loss_fn(params_init, x_batch_data, y_batch_target)
print(f"Batch Gradient w.r.t w shape: {batch_gradients[0].shape}")
print(f"Batch Gradient w.r.t b shape: {batch_gradients[1].shape}")
# Output:
# Batch Gradient w.r.t w shape: (16, 2)
# Batch Gradient w.r.t b shape: (16,)
Understanding how jit traces code, how grad propagates derivatives, and how vmap manipulates batch dimensions is fundamental. This understanding becomes even more significant when introducing control flow primitives like lax.scan, lax.cond, and lax.while_loop, as these primitives interact with the tracing and transformation mechanisms in specific ways, which we will explore in the subsequent sections of this chapter.