JAX's strength lies not just in its individual transformations like jit, grad, and vmap, but in their ability to compose seamlessly. Having explored the functional control flow primitives lax.scan, lax.cond, and lax.while_loop, it's essential to understand how they interact with these core transformations. This composability is what allows you to build complex, high-performance models with features like recurrence and conditional logic, all within a unified, differentiable, and vectorizable framework.
Interaction with jit
The primary motivation for using lax.scan, lax.cond, and lax.while_loop is that they are designed specifically to be compatible with jit. Unlike standard Python for loops or if statements, which can cause tracing issues or repeated recompilations if their behavior depends on input values, these lax primitives present a static structure to the JAX tracer.
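As a quick, hedged illustration of that difference (the function names here are just for this sketch): a Python if on a traced value fails under jit, while the equivalent lax.cond traces and compiles cleanly.

import jax
import jax.numpy as jnp
import jax.lax as lax

@jax.jit
def python_branch(x):
    # Python control flow on a traced value: the tracer cannot evaluate the boolean
    if jnp.sum(x) > 0:
        return x * 2.0
    return x / 2.0

try:
    python_branch(jnp.array([1.0, 2.0]))
except jax.errors.ConcretizationTypeError as e:
    print("Python `if` under jit failed:", type(e).__name__)

@jax.jit
def lax_branch(x):
    # The same logic expressed with lax.cond presents a static structure to the tracer
    return lax.cond(jnp.sum(x) > 0, lambda v: v * 2.0, lambda v: v / 2.0, x)

print(lax_branch(jnp.array([1.0, 2.0])))  # [2. 4.]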
When you apply jit to a function containing these primitives:

1. Tracing: JAX traces the function and records the control flow primitive as part of the jaxpr intermediate representation.
2. Compilation: XLA takes the jaxpr and compiles the entire computational graph, including the control flow logic, into optimized low-level code (like HLO) for the target accelerator (GPU/TPU/CPU).

This means the control flow itself is compiled, avoiding Python interpreter overhead during execution. For lax.cond, XLA typically compiles both branches. While only one branch's result is selected based on the predicate during execution, the compilation ensures the code is ready for either outcome. Similarly, the body of lax.while_loop and lax.scan is compiled once, and the compiled code is executed iteratively.
Consider this simple example using lax.cond:
import jax
import jax.numpy as jnp
import jax.lax as lax

def conditional_computation(x, threshold=5.0):
    """Applies different functions based on the sum of x."""
    total = jnp.sum(x)

    def true_fun(operand):
        # Branch taken if total >= threshold
        return operand * 2.0

    def false_fun(operand):
        # Branch taken if total < threshold
        return operand / 2.0

    # lax.cond selects which function to apply to x
    return lax.cond(total >= threshold, true_fun, false_fun, x)

# JIT-compile the function
jitted_conditional_computation = jax.jit(conditional_computation)

# Example usage
data1 = jnp.array([1.0, 2.0, 3.0])  # sum = 6.0 >= 5.0 -> true_fun
data2 = jnp.array([1.0, 1.0, 1.0])  # sum = 3.0 < 5.0 -> false_fun

print("Result 1:", jitted_conditional_computation(data1))
print("Result 2:", jitted_conditional_computation(data2))
# Both branches were compiled, execution selects the appropriate one.
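To see the static structure that jit and XLA work with, you can inspect the traced program with jax.make_jaxpr; the conditional shows up as a single cond primitive that carries both branches, matching the point above that both are compiled.

# Inspect the intermediate representation of the traced function
print(jax.make_jaxpr(conditional_computation)(data1))
# The printed jaxpr contains a `cond` primitive whose branches correspond
# to true_fun and false_fun.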
The key takeaway is that lax control flow primitives allow dynamic runtime behavior based on values, while presenting a static structure for compile-time analysis by jit and XLA.
Interaction with grad
Automatic differentiation (grad, vjp, jvp) composes naturally with lax control flow. JAX tracks operations through the executed paths to compute gradients correctly.
grad and lax.scan: This combination is fundamental for training recurrent models. When you differentiate a function containing lax.scan, JAX effectively applies the chain rule back through the sequential steps performed by the scan. This is analogous to Backpropagation Through Time (BPTT) in traditional RNN frameworks. Keep in mind that storing intermediate activations for the backward pass can consume significant memory, especially for long sequences. Techniques like gradient checkpointing (covered later) can mitigate this.
import jax
import jax.numpy as jnp
import jax.lax as lax

def simple_rnn_step(carry, x_t):
    """A very basic RNN step."""
    prev_hidden = carry
    # Simple update: new_hidden = tanh(W*x_t + U*prev_hidden + b)
    # For simplicity, assume W, U, b are fixed scalars here
    W, U, b = 0.5, 0.8, 0.1
    new_hidden = jnp.tanh(W * x_t + U * prev_hidden + b)
    return new_hidden, new_hidden  # carry = new_hidden, output = new_hidden

def run_rnn(initial_hidden, inputs):
    """Runs the RNN over a sequence of inputs."""
    final_hidden, outputs = lax.scan(simple_rnn_step, initial_hidden, inputs)
    return jnp.sum(outputs)  # Example objective: sum of outputs

# Compute the gradient of the objective wrt the initial hidden state
grad_run_rnn = jax.grad(run_rnn, argnums=0)

# Example data
hidden_init = 0.0
input_sequence = jnp.array([1.0, -0.5, 2.0])

gradient_wrt_h0 = grad_run_rnn(hidden_init, input_sequence)
print(f"Gradient w.r.t. initial hidden state: {gradient_wrt_h0}")
# Output: Gradient w.r.t. initial hidden state: approximately 1.075
Here, jax.grad computes how the final sum of outputs changes with respect to the initial_hidden state by propagating gradients back through the lax.scan steps.
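As a brief preview of the checkpointing idea mentioned above, here is a hedged sketch using jax.checkpoint (also available as jax.remat): wrapping the step function makes the backward pass recompute intermediates instead of storing one set per scan step, trading extra compute for lower memory without changing the gradient value.

# Reuse simple_rnn_step from above, but rematerialize its intermediates
# during the backward pass instead of storing them for every step.
checkpointed_step = jax.checkpoint(simple_rnn_step)

def run_rnn_checkpointed(initial_hidden, inputs):
    final_hidden, outputs = lax.scan(checkpointed_step, initial_hidden, inputs)
    return jnp.sum(outputs)

grad_ckpt = jax.grad(run_rnn_checkpointed, argnums=0)
print(grad_ckpt(0.0, jnp.array([1.0, -0.5, 2.0])))  # same value as the gradient above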
grad and lax.cond: Differentiation proceeds through the branch that was executed during the forward pass. The computations within the branch not taken do not contribute to the gradient of the output with respect to the input for that specific execution. JAX handles the selection process correctly. Note, however, that no gradient flows through the boolean predicate itself: comparisons like total >= threshold are piecewise constant, so their derivative is zero, and gradients flow only through the operands of the selected branch.
def conditional_loss(params, x):
    # Condition depends on input x
    pred = jnp.sum(x) > 0

    def loss1(p):  # If pred is True
        return jnp.sum(p * x)

    def loss2(p):  # If pred is False
        return jnp.sum(p / (x + 1e-5))  # Avoid division by zero

    return lax.cond(pred, loss1, loss2, params)

grad_conditional_loss = jax.grad(conditional_loss)

params = jnp.array([0.5, -0.5])
data_pos = jnp.array([1.0, 1.0])    # sum > 0 -> loss1
data_neg = jnp.array([-1.0, -1.0])  # sum <= 0 -> loss2

print("Grad (pos):", grad_conditional_loss(params, data_pos))
# Output: Grad (pos): [1. 1.] (gradient comes from loss1: d(p*x)/dp = x)
print("Grad (neg):", grad_conditional_loss(params, data_neg))
# Output: Grad (neg): approximately [-1.00001 -1.00001] (from loss2: d(p/(x+1e-5))/dp = 1/(x+1e-5))
grad and lax.while_loop: Differentiation here is more restricted than for lax.scan. Because the number of iterations depends on runtime values, reverse-mode differentiation (jax.grad, jax.vjp) is not supported through lax.while_loop and raises an error, while forward-mode differentiation (jax.jvp) does work. If you need reverse-mode gradients through an iterative computation, express it with lax.scan (or lax.fori_loop with static bounds, which lowers to a scan) so the trip count is fixed, or supply a custom VJP.
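A minimal sketch of this asymmetry (iterate_until is a hypothetical helper, not defined elsewhere in this section): forward-mode jax.jvp passes through a value-dependent lax.while_loop, while jax.grad of the same function would be rejected.

import jax
import jax.numpy as jnp
import jax.lax as lax

def iterate_until(x):
    """Halve x until it drops below 1.0; the trip count depends on the value of x."""
    return lax.while_loop(lambda v: v >= 1.0, lambda v: v * 0.5, x)

x0 = jnp.array(5.0)  # three halvings: 5.0 -> 2.5 -> 1.25 -> 0.625

# Forward mode works: propagate a tangent alongside the primal value
primal_out, tangent_out = jax.jvp(iterate_until, (x0,), (jnp.array(1.0),))
print(primal_out, tangent_out)  # 0.625 0.125, i.e. d/dx of x * 0.5**3

# Reverse mode is rejected because the trip count is value-dependent:
# jax.grad(iterate_until)(x0)  # raises an error suggesting lax.scan instead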
Interaction with vmap
Vectorizing functions that contain control flow using vmap is powerful but requires careful consideration. vmap pushes the mapped axis into the computation.
vmap and lax.scan: This is a common pattern, for example, when processing a batch of sequences simultaneously with an RNN. vmap typically transforms the lax.scan to operate over the batch dimension specified by in_axes. Each element in the batch undergoes its own independent scan.
# Using simple_rnn_step and run_rnn from the grad example
# Vectorize run_rnn over a batch of initial hidden states and input sequences
# Assume hidden_init is shape (batch,) and inputs is shape (batch, seq_len)
# We map over axis 0 for both arguments
batched_run_rnn = jax.vmap(run_rnn, in_axes=(0, 0))
# Example batched data
batch_size = 4
seq_len = 3
batch_hidden_init = jnp.zeros(batch_size)
batch_input_sequence = jnp.arange(batch_size * seq_len, dtype=jnp.float32).reshape((batch_size, seq_len))
# Run the batched RNN
batched_output_sum = batched_run_rnn(batch_hidden_init, batch_input_sequence)
print("Batched RNN output sum shape:", batched_output_sum.shape)
# Output: Batched RNN output sum shape: (4,)
# Each element corresponds to the sum of outputs for one sequence in the batch
vmap and lax.cond: This interaction can be more complex if the condition predicate depends on the mapped axis. If the condition evaluates differently for different elements along the mapped dimension, vmap needs to handle executing potentially different branches for different "lanes" of the vectorized computation. JAX achieves this by effectively evaluating both branches across the mapped axis and then selecting the appropriate result for each lane based on its corresponding predicate value. This implies that the computational cost might be higher than if all elements took the same branch, as both paths are processed.
# Using conditional_computation from the jit example
# Vectorize over x, keeping threshold scalar
# map over axis 0 of x
vmapped_conditional = jax.vmap(conditional_computation, in_axes=(0, None))
# Batch of data where condition differs
batch_data = jnp.array([
    [1.0, 2.0, 3.0],   # sum = 6.0 >= 5.0 -> true_fun
    [1.0, 1.0, 1.0],   # sum = 3.0 < 5.0 -> false_fun
    [10.0, 1.0, 1.0],  # sum = 12.0 >= 5.0 -> true_fun
])
threshold = 5.0
results = vmapped_conditional(batch_data, threshold)
print("Vmapped conditional results:\n", results)
# Output:
# Vmapped conditional results:
# [[ 2. 4. 6.] <- true_fun applied
# [ 0.5 0.5 0.5] <- false_fun applied
# [20. 2. 2.]] <- true_fun applied
Even though different rows took different paths, vmap + lax.cond handled it.
vmap and lax.while_loop: Vectorizing a while_loop behaves similarly to lax.cond. If the loop condition or the loop body's effect on the state depends on the mapped axis, different lanes might execute for a different number of iterations. JAX/XLA handles this, often involving mechanisms like masking inactive lanes or running all lanes for the maximum number of iterations observed across the batch. This can lead to computational work being done on lanes that have already met their exit condition, impacting performance compared to a scenario where all lanes iterate the same number of times.
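Here is a small sketch of that behavior (count_halvings is a hypothetical helper): each lane of the batch needs a different number of iterations, and the vectorized loop keeps stepping until the slowest lane finishes while holding finished lanes' state fixed.

import jax
import jax.numpy as jnp
import jax.lax as lax

def count_halvings(x):
    """Count how many halvings are needed until x drops below 1.0."""
    def cond_fun(state):
        val, _ = state
        return val >= 1.0
    def body_fun(state):
        val, count = state
        return val * 0.5, count + 1
    _, count = lax.while_loop(cond_fun, body_fun, (x, 0))
    return count

xs = jnp.array([0.5, 3.0, 40.0])  # needs 0, 2, and 6 halvings respectively
print(jax.vmap(count_halvings)(xs))
# Expected output: [0 2 6]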
You can freely compose these transformations. For example, you can vectorize the gradient of a function that uses a scan: jax.vmap(jax.grad(run_rnn)). Or you can JIT-compile the gradient of a function with conditionals: jax.jit(jax.grad(conditional_loss)). JAX's functional nature and the design of its primitives ensure these compositions are well-defined.
# Example: JIT-compiled gradient of a vmapped conditional computation
# Uses conditional_computation from above; params affect only the threshold
def threshold_from_params(params, x):
    # threshold = average of params
    threshold = jnp.mean(params)
    return conditional_computation(x, threshold)

# Target: a scalar objective over a batch of x, differentiated w.r.t. params
# 1. Vectorize over x (axis 0)
# 2. Reduce to a scalar (grad requires scalar-output functions)
# 3. Compute the gradient w.r.t. params (arg 0) and JIT the result
def batched_objective(params, batch_x):
    per_example = jax.vmap(threshold_from_params, in_axes=(None, 0))(params, batch_x)
    return jnp.sum(per_example)

jitted_grad_vmap_fn = jax.jit(jax.grad(batched_objective, argnums=0))

# Example data
batch_data = jnp.array([
    [1.0, 2.0, 3.0],  # sum = 6.0
    [1.0, 1.0, 1.0],  # sum = 3.0
])
params = jnp.array([4.0, 6.0])  # mean = 5.0

gradient = jitted_grad_vmap_fn(params, batch_data)
print("Jitted gradient of vmapped conditional function w.r.t params:\n", gradient)
# Output: [0. 0.] -- params only influence the boolean predicate (via the threshold),
# and no gradient flows through the comparison itself, so the gradient is zero.
Understanding how jit, grad, and vmap interact with lax control flow primitives is fundamental to writing efficient and correct advanced JAX code. It allows you to build sophisticated models that require sequential processing, conditional logic, or dynamic iteration counts, while still benefiting from compilation, automatic differentiation, and vectorization. Be mindful of potential performance or memory implications, particularly when vectorizing diverging control flow or differentiating through long scans or loops.