While jit, grad, and vmap provide powerful tools for accelerating and transforming numerical functions, they primarily operate on entire arrays or batches at once. Many important algorithms, particularly in sequence modeling, signal processing, and optimization, involve sequential dependencies where the result of one step depends on the output of the previous one. Implementing these efficiently in JAX requires a specialized tool: jax.lax.scan.
Think of processing a time series, simulating a physical system step-by-step, or implementing the forward pass of a Recurrent Neural Network (RNN). A naive Python for loop iterating through the steps works, but it poses challenges for JAX's compilation model. When JAX traces a function containing a Python loop, it unrolls the loop during compilation, replicating the loop body's operations for each iteration in the compiled graph. For long sequences, this unrolling leads to very large computation graphs, significantly increasing compile times and potentially exceeding memory limits.
lax.scan provides a functional alternative designed specifically for these scenarios. It allows you to express sequential computations in a way that JAX can compile into highly efficient loop primitives on accelerators like GPUs and TPUs, avoiding the pitfalls of explicit Python loops within jit-compiled functions.
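To make the contrast concrete, here is a minimal sketch, not from the original text, comparing the two approaches on a toy exponential-moving-average recurrence (the ema_* names and the smoothing constant are illustrative). Both versions compute the same values, but the Python loop is unrolled at trace time while the scan version compiles to a single loop:

import jax
import jax.numpy as jnp
import jax.lax as lax

def ema_python_loop(xs, alpha=0.1):
    # Each iteration of this Python loop is traced separately, so the
    # compiled graph grows linearly with len(xs).
    state = 0.0
    outputs = []
    for x in xs:  # unrolled at trace time
        state = (1 - alpha) * state + alpha * x
        outputs.append(state)
    return jnp.stack(outputs)

def ema_scan(xs, alpha=0.1):
    # The same recurrence expressed with lax.scan compiles to a single
    # loop primitive whose size does not depend on len(xs).
    def step(state, x):
        state = (1 - alpha) * state + alpha * x
        return state, state
    _, outputs = lax.scan(step, 0.0, xs)
    return outputs

xs = jnp.linspace(0.0, 1.0, 5)
print(jax.jit(ema_python_loop)(xs))  # same values...
print(jax.jit(ema_scan)(xs))         # ...different compiled form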
At its heart, lax.scan works by repeatedly applying a function you define. This function takes two arguments at each step:

- A carry state: this holds information passed from the previous step to the current one. It is how state is maintained throughout the sequence processing.
- An element x from an input sequence xs: this is the specific input for the current step (optional).

The function must return a tuple containing:

- A new carry state: this will be passed to the next step.
- An output y: this is collected across all steps to form the final output sequence.

The overall signature of lax.scan is:

jax.lax.scan(f, init, xs, length=None, reverse=False, unroll=1)
Let's break down the essential arguments:

- f: the function to be applied repeatedly. It must have the signature f(carry, x) -> (new_carry, y). If xs is None, f is still called with two arguments, but the per-step input x is simply None.
- init: the initial value of the carry state before the first step. Its structure must match the carry part of the output of f.
- xs: an optional PyTree (e.g., array, tuple/list/dict of arrays) representing the input sequence(s). lax.scan iterates over the leading axis of the arrays in xs. Each slice x passed to f has the structure of xs with the leading axis removed. If xs is None, scan iterates length times without per-step inputs.
- length: an optional integer specifying the number of steps. Usually inferred from the leading axis dimension of xs. Required if xs is None.
- reverse: an optional boolean. If True, scans from the end of xs to the beginning. Defaults to False. A short sketch after the cumulative sum walkthrough below demonstrates this.
- unroll: an optional integer for performance tuning that suggests to the compiler how many loop iterations to unroll. Defaults to 1. Larger values might sometimes improve performance on certain hardware by reducing loop overhead, but they can increase compile time and code size; use with caution and profiling.

lax.scan returns a tuple (final_carry, ys), where final_carry is the carry state after the last step and ys is a PyTree containing the stacked outputs y from each step. The structure of ys matches the y part of the output of f, but with an added leading dimension corresponding to the sequence length.
Let's see how to implement a cumulative sum, where each output element is the sum of all input elements up to and including the current position.
import jax
import jax.numpy as jnp
import jax.lax as lax
# Define the scan function: f(carry, x) -> (new_carry, y)
# carry: the sum up to the previous element
# x: the current element
# new_carry: the sum including the current element (carry + x)
# y: the output for this step (also carry + x)
def cumulative_sum_step(carry_sum, current_x):
    new_sum = carry_sum + current_x
    return new_sum, new_sum  # Return the new carry and the output for this step
# Input array
input_array = jnp.array([1, 2, 3, 4, 5])
# Initial carry is 0 (sum before the first element)
initial_carry = 0
# Apply lax.scan
final_carry, result_sequence = lax.scan(cumulative_sum_step, initial_carry, input_array)
print("Input Array:", input_array)
print("Initial Carry:", initial_carry)
print("Final Carry (Total Sum):", final_carry)
print("Result Sequence (Cumulative Sum):", result_sequence)
# Expected output:
# Input Array: [1 2 3 4 5]
# Initial Carry: 0
# Final Carry (Total Sum): 15
# Result Sequence (Cumulative Sum): [ 1 3 6 10 15]
In this example:

- initial_carry starts at 0.
- f(0, 1) returns (1, 1). The carry becomes 1, and ys starts collecting [1].
- f(1, 2) returns (3, 3). The carry becomes 3, ys is [1, 3].
- f(3, 3) returns (6, 6). The carry becomes 6, ys is [1, 3, 6].
- f(6, 4) returns (10, 10). The carry becomes 10, ys is [1, 3, 6, 10].
- f(10, 5) returns (15, 15). The carry becomes 15, ys is [1, 3, 6, 10, 15].
- lax.scan returns the final_carry (15) and the collected ys sequence [1, 3, 6, 10, 15].
A more realistic application is implementing an RNN. A simple RNN updates its hidden state $h_t$ based on the previous hidden state $h_{t-1}$ and the current input $x_t$:
$$h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b_h)$$

We can use lax.scan to iterate this update step over an input sequence. The carry will be the hidden state $h$, and xs will be the sequence of inputs $x$.
import jax
import jax.numpy as jnp
import jax.lax as lax
import jax.random as random
key = random.PRNGKey(0)
# Define dimensions
input_features = 3
hidden_features = 5
sequence_length = 4
# Initialize parameters (as part of the carry or closed over)
key, whh_key, wxh_key, b_key = random.split(key, 4)
W_hh = random.normal(whh_key, (hidden_features, hidden_features)) * 0.1
W_xh = random.normal(wxh_key, (hidden_features, input_features)) * 0.1
b_h = random.normal(b_key, (hidden_features,)) * 0.1
# Generate a dummy input sequence (sequence_length, input_features)
key, x_key = random.split(key)
input_sequence = random.normal(x_key, (sequence_length, input_features))
# Initial hidden state
h_initial = jnp.zeros((hidden_features,))
# Define the RNN step function: f(h_prev, x_t) -> (h_new, h_new)
# We output h_new as the per-step result 'y' as well
def rnn_step(h_prev, x_t):
    h_new = jnp.tanh(jnp.dot(W_hh, h_prev) + jnp.dot(W_xh, x_t) + b_h)
    # Return the new state as both the next carry and the output for this step
    return h_new, h_new
# Apply lax.scan
final_hidden_state, hidden_states_sequence = lax.scan(rnn_step, h_initial, input_sequence)
print("Input sequence shape:", input_sequence.shape)
print("Initial hidden state shape:", h_initial.shape)
print("Final hidden state shape:", final_hidden_state.shape)
print("Sequence of hidden states shape:", hidden_states_sequence.shape)
# Expected output shapes:
# Input sequence shape: (4, 3)
# Initial hidden state shape: (5,)
# Final hidden state shape: (5,)
# Sequence of hidden states shape: (4, 5)
Here, rnn_step encapsulates the core recurrence relation. lax.scan efficiently applies this function across the input_sequence, starting with h_initial. The carry represents the evolving hidden state $h_t$, and the collected ys (here, hidden_states_sequence) gives us the hidden state at each time step.
Notice how the parameters W_hh, W_xh, and b_h are defined outside rnn_step. JAX handles these closures automatically when tracing and compiling.
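Relying on closures is convenient, but for training you usually want gradients with respect to the parameters. A common pattern is to pass the parameters in explicitly and differentiate the wrapping function. The following is a minimal sketch, assuming the imports, weights, h_initial, and input_sequence from the RNN example above; the params dictionary and the toy loss are illustrative, not part of the original example:

def rnn_forward(params, h0, xs):
    # Same recurrence as rnn_step, but reading the weights from an explicit
    # params PyTree so we can differentiate with respect to them.
    def step(h_prev, x_t):
        h_new = jnp.tanh(params["W_hh"] @ h_prev + params["W_xh"] @ x_t + params["b_h"])
        return h_new, h_new
    return lax.scan(step, h0, xs)

def loss_fn(params, h0, xs):
    final_h, _ = rnn_forward(params, h0, xs)
    return jnp.mean(final_h ** 2)  # toy loss, for illustration only

params = {"W_hh": W_hh, "W_xh": W_xh, "b_h": b_h}
grads = jax.grad(loss_fn)(params, h_initial, input_sequence)
print(jax.tree_util.tree_map(jnp.shape, grads))
# Gradients have the same shapes as the parameters, e.g. {'W_hh': (5, 5), ...}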
Where lax.scan Excels: Performance and Compilation

The primary advantage of lax.scan over a Python for loop inside @jit is performance, especially on accelerators.

- lax.scan is translated by JAX into specialized XLA HLO (High Level Optimizer) loop operations (like While). XLA can then compile this loop representation very efficiently for the target hardware (GPU/TPU) without needing to duplicate the computation graph for every iteration. This keeps compile times manageable and reduces the size of the compiled code.
- Compared with a fully unrolled loop, lax.scan's compiled representation is typically much more compact. Furthermore, during execution, XLA can optimize memory usage within the loop, potentially reusing buffers more effectively than a fully unrolled graph might allow.
- XLA can also apply loop-level optimizations to the compiled scan, leading to faster execution.
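You can observe this difference without profiling by inspecting the traced programs. In the small sketch below, the sum_loop/sum_scan functions are defined just for this comparison: the jaxpr of the Python-loop version grows with the input length, while the scan version contains a single scan primitive regardless of length.

import jax
import jax.numpy as jnp
import jax.lax as lax

def sum_loop(xs):
    total = 0.0
    for x in xs:  # unrolled: one add per element appears in the jaxpr
        total = total + x
    return total

def sum_scan(xs):
    # The whole loop is represented by one scan primitive in the jaxpr.
    return lax.scan(lambda c, x: (c + x, None), 0.0, xs)[0]

xs = jnp.ones(100)
print(jax.make_jaxpr(sum_loop)(xs))  # long: grows with the length of xs
print(jax.make_jaxpr(sum_scan)(xs))  # short: a single scan, whatever the length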
lax.scan and Other Transformations

lax.scan is designed to compose seamlessly with other JAX transformations:

- jit: As discussed, lax.scan is ideal for use inside @jit. It provides the structure XLA needs for efficient compilation.
- grad: JAX can automatically differentiate through lax.scan. It computes gradients with respect to the init state, the parameters closed over by f, and the input sequence xs. This is fundamental for training RNNs and other sequence models. We will explore differentiation through control flow in more detail in Chapter 4.
- vmap: You can use vmap to run multiple independent scans in parallel, for example, processing a batch of sequences simultaneously. If you have batch_input_sequences with shape (batch_size, sequence_length, input_features), you could wrap lax.scan with vmap: jax.vmap(lambda seq: lax.scan(rnn_step, h_initial, seq))(batch_input_sequences). (Note: handling batched initial states correctly might require vmapping over h_initial as well.) A runnable sketch follows this list.
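Here is a minimal sketch of batching the RNN scan with vmap, assuming rnn_step, h_initial, key, and the dimension variables from the RNN example above are in scope; batch_size and batch_input_sequences are names introduced here for illustration:

batch_size = 8
key, batch_key = random.split(key)
batch_input_sequences = random.normal(
    batch_key, (batch_size, sequence_length, input_features))

# Map over the leading (batch) axis of the sequences while broadcasting the
# same initial hidden state to every sequence (in_axes=None for h_initial).
batched_scan = jax.vmap(lambda h0, seq: lax.scan(rnn_step, h0, seq),
                        in_axes=(None, 0))
final_states, state_sequences = batched_scan(h_initial, batch_input_sequences)

print(final_states.shape)     # (8, 5): (batch_size, hidden_features)
print(state_sequences.shape)  # (8, 4, 5): (batch_size, sequence_length, hidden_features)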
Scanning Without Inputs (xs=None)

Sometimes, you might want to iterate a function based purely on the carry state, without consuming an external input sequence at each step. This is useful for generating sequences or running iterative processes for a fixed number of steps. You can achieve this by passing xs=None and specifying the length argument.
import jax
import jax.numpy as jnp
import jax.lax as lax
# Generate the first 10 powers of 2: 1, 2, 4, ...
# carry: the previous power of 2
# x: is None, ignored
# new_carry: the next power of 2 (carry * 2)
# y: the current power of 2 (carry)
def generate_powers_of_2(carry, _):  # Use _ to indicate x is unused
    next_val = carry * 2
    return next_val, carry  # Return next carry, output current value
initial_value = 1
num_steps = 10
final_val, powers_of_2 = lax.scan(generate_powers_of_2,
                                  initial_value,
                                  xs=None,  # No input sequence
                                  length=num_steps)
print("Final Value (2^10):", final_val)
print("Generated Powers of 2:", powers_of_2)
# Expected output:
# Final Value (2^10): 1024
# Generated Powers of 2: [ 1 2 4 8 16 32 64 128 256 512]
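The carry is not limited to a single array; it can be any PyTree, as long as its structure and dtypes stay the same from step to step. As a small illustrative sketch, not from the original text, here is a Fibonacci generator whose carry is a tuple of the two most recent values, again with xs=None:

import jax.numpy as jnp
import jax.lax as lax

def fib_step(carry, _):
    prev, curr = carry                 # carry is a (prev, curr) tuple
    return (curr, prev + curr), curr   # new carry keeps the same structure

init = (jnp.int32(0), jnp.int32(1))    # F(0), F(1)
_, fibs = lax.scan(fib_step, init, xs=None, length=10)
print(fibs)  # [ 1  1  2  3  5  8 13 21 34 55]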
In summary, lax.scan is an indispensable tool in your advanced JAX toolkit. It provides the means to express complex sequential and recurrent computations in a way that is both functionally elegant and highly performant when compiled with jit, forming the basis for implementing many sophisticated models and algorithms on accelerated hardware.