While jit, grad, and vmap provide effective tools for accelerating and transforming numerical functions, they primarily operate on entire arrays or batches at once. Many important algorithms, particularly in areas like sequence modeling, signal processing, and optimization, involve sequential dependencies where the result of one step depends on the output of the previous one. Implementing these efficiently in JAX requires a specialized tool: jax.lax.scan.
Think of processing a time series, simulating a physical system step by step, or implementing the forward pass of a Recurrent Neural Network (RNN). A naive Python for loop iterating through the steps works in pure Python, but it poses challenges for JAX's compilation model. When JAX traces a function containing a Python loop, the loop runs during tracing and the loop body's operations are replicated ("unrolled") for each iteration in the resulting computation graph. For long sequences, this unrolling produces very large graphs, significantly increasing compile times and potentially exceeding memory limits.
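You can observe this unrolling directly by inspecting the traced program. The sketch below (the loop lengths of 10 and 1000 are arbitrary illustrations) counts the operations in the jaxpr of a function containing a Python loop; the count grows with the number of iterations:

import jax
import jax.numpy as jnp

def summed(x):
    total = 0.0
    # A Python loop is unrolled during tracing: the body is replicated per iteration
    for i in range(x.shape[0]):
        total = total + x[i]
    return total

small = jax.make_jaxpr(summed)(jnp.ones(10))
large = jax.make_jaxpr(summed)(jnp.ones(1000))
print(len(small.jaxpr.eqns), len(large.jaxpr.eqns))  # the second is roughly 100x larger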
lax.scan provides a functional alternative designed specifically for these scenarios. It allows you to express sequential computations in a way that JAX can compile into highly efficient loop primitives on accelerators like GPUs and TPUs, avoiding the pitfalls of explicit Python loops within jit-compiled functions.
Fundamentally, lax.scan works by repeatedly applying a function you define. This function takes two arguments at each step:
- carry state: This holds information passed from the previous step to the current one. It's how state is maintained throughout the sequence processing.
- x, an element from an input sequence xs: This is the specific input for the current step (optional).

The function must return a tuple containing:
- carry state: This will be passed to the next step.
- y: This is collected across all steps to form the final output sequence.

The overall signature of lax.scan is:
jax.lax.scan(f, init, xs, length=None, reverse=False, unroll=1)
Let's break down the essential arguments:
- f: The function to be applied repeatedly. It must have the signature f(carry, x) -> (new_carry, y). If xs is None, f is still called with two arguments, with None passed as x at each step.
- init: The initial value of the carry state before the first step. Its structure must match the carry part of the output of f.
- xs: An optional PyTree (e.g., an array, or a tuple/list/dict of arrays) representing the input sequence(s). lax.scan iterates over the leading axis of the arrays in xs. Each slice x passed to f has a structure matching xs. If xs is None, scan iterates length times without per-step inputs.
- length: Optional integer specifying the number of steps. Usually inferred from the leading axis dimension of xs. Required if xs is None.
- reverse: Optional boolean. If True, scans from the end of xs to the beginning. Defaults to False.
- unroll: Optional integer for performance tuning; it suggests to the compiler how many loop iterations to unroll. Defaults to 1. Larger values can sometimes improve performance on certain hardware by reducing loop overhead, but they increase compile time and code size. Use with caution and profiling.

lax.scan returns a tuple (final_carry, ys), where final_carry is the carry state after the last step, and ys is a PyTree containing the stacked outputs y from each step. The structure of ys matches the y part of the output of f, but with an added leading dimension corresponding to the sequence length. The sketch after this list illustrates the PyTree handling for xs and ys.
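To make the PyTree handling concrete, here is a small sketch (the dict keys "a" and "b" and the step logic are arbitrary illustrations): xs is a dict of arrays, and each per-step output y is a tuple, so ys comes back as a tuple of stacked arrays.

import jax.numpy as jnp
import jax.lax as lax

xs = {"a": jnp.arange(4.0), "b": jnp.ones(4)}   # both arrays share the leading axis

def step(carry, x):
    # x is a dict of scalars: one slice of each array in xs
    new_carry = carry + x["a"] * x["b"]
    y = (new_carry, x["a"] - carry)             # per-step output is a tuple
    return new_carry, y

final_carry, ys = lax.scan(step, 0.0, xs)
print(final_carry)                  # 6.0
print(ys[0].shape, ys[1].shape)     # (4,) (4,): each output stacked along a new leading axis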
Let's see how to implement a cumulative sum, where each output element is the sum of all input elements up to and including the current one.
import jax
import jax.numpy as jnp
import jax.lax as lax
# Define the scan function: f(carry, x) -> (new_carry, y)
# carry: the sum up to the previous element
# x: the current element
# new_carry: the sum including the current element (carry + x)
# y: the output for this step (also carry + x)
def cumulative_sum_step(carry_sum, current_x):
    new_sum = carry_sum + current_x
    return new_sum, new_sum  # Return new carry and the output for this step
# Input array
input_array = jnp.array([1, 2, 3, 4, 5])
# Initial carry is 0 (sum before the first element)
initial_carry = 0
# Apply lax.scan
final_carry, result_sequence = lax.scan(cumulative_sum_step, initial_carry, input_array)
print("Input Array:", input_array)
print("Initial Carry:", initial_carry)
print("Final Carry (Total Sum):", final_carry)
print("Result Sequence (Cumulative Sum):", result_sequence)
# Expected output:
# Input Array: [1 2 3 4 5]
# Initial Carry: 0
# Final Carry (Total Sum): 15
# Result Sequence (Cumulative Sum): [ 1 3 6 10 15]
In this example:
- initial_carry starts at 0.
- f(0, 1) returns (1, 1). carry becomes 1, and ys starts collecting [1].
- f(1, 2) returns (3, 3). carry becomes 3, ys is [1, 3].
- f(3, 3) returns (6, 6). carry becomes 6, ys is [1, 3, 6].
- f(6, 4) returns (10, 10). carry becomes 10, ys is [1, 3, 6, 10].
- f(10, 5) returns (15, 15). carry becomes 15, ys is [1, 3, 6, 10, 15].
- lax.scan returns the final_carry (15) and the collected ys sequence ([1, 3, 6, 10, 15]).
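As a quick sanity check (a small sketch reusing cumulative_sum_step and input_array from above), the result matches jnp.cumsum, and passing reverse=True runs the same step function from the last element to the first, producing suffix sums:

# Forward scan reproduces jnp.cumsum
_, forward = lax.scan(cumulative_sum_step, 0, input_array)
print(jnp.array_equal(forward, jnp.cumsum(input_array)))   # True

# reverse=True scans from the end; outputs are still stacked in the original order,
# so each entry is the sum of that element and everything after it
_, suffix_sums = lax.scan(cumulative_sum_step, 0, input_array, reverse=True)
print(suffix_sums)   # [15 14 12  9  5]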
A more realistic application is implementing an RNN. A simple RNN updates its hidden state $h_t$ based on the previous hidden state $h_{t-1}$ and the current input $x_t$:

$$h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b_h)$$

We can use lax.scan to iterate this update step over an input sequence. The carry will be the hidden state $h$, and xs will be the sequence of inputs $x$.
import jax
import jax.numpy as jnp
import jax.lax as lax
import jax.random as random
key = random.PRNGKey(0)
# Define dimensions
input_features = 3
hidden_features = 5
sequence_length = 4
# Initialize parameters (as part of the carry or closed over)
key, whh_key, wxh_key, b_key = random.split(key, 4)
W_hh = random.normal(whh_key, (hidden_features, hidden_features)) * 0.1
W_xh = random.normal(wxh_key, (hidden_features, input_features)) * 0.1
b_h = random.normal(b_key, (hidden_features,)) * 0.1
# Generate a dummy input sequence (sequence_length, input_features)
key, x_key = random.split(key)
input_sequence = random.normal(x_key, (sequence_length, input_features))
# Initial hidden state
h_initial = jnp.zeros((hidden_features,))
# Define the RNN step function: f(h_prev, x_t) -> (h_new, h_new)
# We output h_new as the per-step result 'y' as well
def rnn_step(h_prev, x_t):
    h_new = jnp.tanh(jnp.dot(W_hh, h_prev) + jnp.dot(W_xh, x_t) + b_h)
    # Return the new state as both the next carry and the output for this step
    return h_new, h_new
# Apply lax.scan
final_hidden_state, hidden_states_sequence = lax.scan(rnn_step, h_initial, input_sequence)
print("Input sequence shape:", input_sequence.shape)
print("Initial hidden state shape:", h_initial.shape)
print("Final hidden state shape:", final_hidden_state.shape)
print("Sequence of hidden states shape:", hidden_states_sequence.shape)
# Expected output shapes:
# Input sequence shape: (4, 3)
# Initial hidden state shape: (5,)
# Final hidden state shape: (5,)
# Sequence of hidden states shape: (4, 5)
Here, rnn_step encapsulates the core recurrence relation. lax.scan efficiently applies this function across the input_sequence, starting with h_initial. The carry perfectly represents the evolving hidden state $h_t$, and the collected ys (here, hidden_states_sequence) gives us the hidden state at each time step.
Notice how the parameters W_hh, W_xh, and b_h are defined outside rnn_step. JAX handles these closures automatically when tracing and compiling.
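Closures are convenient, but it is often useful to make the parameters explicit arguments, for instance so you can differentiate with respect to them. The sketch below shows one common pattern, binding the parameters with functools.partial before handing the step function to lax.scan; it reuses the arrays defined above, and the helper names rnn_step_explicit, run_rnn, and loss are illustrative:

import functools

def rnn_step_explicit(params, h_prev, x_t):
    W_hh, W_xh, b_h = params
    h_new = jnp.tanh(jnp.dot(W_hh, h_prev) + jnp.dot(W_xh, x_t) + b_h)
    return h_new, h_new

def run_rnn(params, h0, xs):
    # Bind the parameters first so the scanned function has the (carry, x) signature
    step = functools.partial(rnn_step_explicit, params)
    return lax.scan(step, h0, xs)

# Differentiate a scalar loss with respect to all three parameter arrays
def loss(params):
    final_h, _ = run_rnn(params, h_initial, input_sequence)
    return jnp.sum(final_h ** 2)

grads = jax.grad(loss)((W_hh, W_xh, b_h))
print([g.shape for g in grads])   # [(5, 5), (5, 3), (5,)]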
Where lax.scan Excels: Performance and Compilation

The primary advantage of lax.scan over a Python for loop inside @jit is performance, especially on accelerators.
- Compilation: lax.scan is translated by JAX into specialized XLA HLO (High Level Optimizer) loop operations (like While). XLA can then compile this loop representation very efficiently for the target hardware (GPU/TPU) without needing to duplicate the computation graph for every iteration. This keeps compile times manageable and reduces the size of the compiled code.
- Memory: lax.scan's compiled representation is typically much more compact. Furthermore, during execution, XLA can optimize memory usage within the loop, potentially reusing buffers more effectively than a fully unrolled graph might allow.
- Execution: XLA can apply loop-level optimizations to the compiled scan, leading to faster execution.
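You can see the difference in the traced program itself. Unlike the unrolled Python loop shown earlier, a scan-based version produces a jaxpr containing a single scan operation whose size is independent of the sequence length (a small sketch):

import jax
import jax.numpy as jnp
import jax.lax as lax

def cumsum_scan(x):
    def step(carry, xi):
        total = carry + xi
        return total, total
    return lax.scan(step, jnp.zeros((), x.dtype), x)

short = jax.make_jaxpr(cumsum_scan)(jnp.ones(10))
long_version = jax.make_jaxpr(cumsum_scan)(jnp.ones(10_000))
print(len(short.jaxpr.eqns), len(long_version.jaxpr.eqns))  # small and identical for both lengths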
lax.scan and Other Transformations

lax.scan is designed to compose with other JAX transformations:

- jit: As discussed, lax.scan is ideal for use inside @jit. It provides the structure XLA needs for efficient compilation.
- grad: JAX can automatically differentiate through lax.scan. It computes gradients with respect to the init state, the parameters closed over by f, and the input sequence xs. This is fundamental for training RNNs and other sequence models. We will explore differentiation through control flow in more detail in Chapter 4.
- vmap: You can use vmap to run multiple independent scans in parallel, for example processing a batch of sequences simultaneously. If you have batch_input_sequences with shape (batch_size, sequence_length, input_features), you could wrap lax.scan with vmap: jax.vmap(lambda seq: lax.scan(rnn_step, h_initial, seq))(batch_input_sequences). (Handling batched initial states correctly might require vmapping over h_initial as well.) A batched version of the RNN example is sketched below.
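For example, here is a batched version of the RNN scan (a sketch reusing rnn_step, h_initial, key, and the dimensions defined above; batch_size is an arbitrary choice). vmap maps the single-sequence scan over the leading batch axis:

batch_size = 8
key, batch_key = random.split(key)
batch_input_sequences = random.normal(
    batch_key, (batch_size, sequence_length, input_features))

# One scan per sequence; every sequence starts from the same h_initial here
run_single_sequence = lambda seq: lax.scan(rnn_step, h_initial, seq)
batch_final_h, batch_hidden_states = jax.vmap(run_single_sequence)(batch_input_sequences)

print(batch_final_h.shape)        # (8, 5): one final hidden state per sequence
print(batch_hidden_states.shape)  # (8, 4, 5): hidden states at every step for every sequence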
Scanning Without Inputs (xs=None)

Sometimes you might want to iterate a function based purely on the carry state, without consuming an external input sequence at each step. This is useful for generating sequences or running iterative processes for a fixed number of steps. You can achieve this by passing xs=None and specifying the length argument.

import jax
import jax.numpy as jnp
import jax.lax as lax
# Generate the first 10 powers of 2: 1, 2, 4, ...
# carry: the previous power of 2
# x: is None, ignored
# new_carry: the next power of 2 (carry * 2)
# y: the current power of 2 (carry)
def generate_powers_of_2(carry, _): # Use _ to indicate x is unused
    next_val = carry * 2
    return next_val, carry  # Return next carry, output current value
initial_value = 1
num_steps = 10
final_val, powers_of_2 = lax.scan(generate_powers_of_2,
                                  initial_value,
                                  xs=None,  # No input sequence
                                  length=num_steps)
print("Final Value (2^10):", final_val)
print("Generated Powers of 2:", powers_of_2)
# Expected output:
# Final Value (2^10): 1024
# Generated Powers of 2: [ 1 2 4 8 16 32 64 128 256 512]
In summary, lax.scan is an indispensable tool in your advanced JAX toolkit. It provides the means to express complex sequential and recurrent computations in a way that is both functionally elegant and highly performant when compiled with jit, forming the basis for implementing many sophisticated models and algorithms on accelerated hardware.