While lax.scan excels at executing a fixed number of sequential steps, many algorithms require loops that continue until a specific condition is met. Standard Python while loops are a problem for JAX's compilation process. Because JAX traces a function to build a computation graph before executing it, it needs to know the structure of the computation up front. A Python while loop whose termination condition depends on values computed inside the loop needs a concrete Python bool at every check. Under jax.jit those values are abstract tracers, so such a loop cannot be traced and JIT-compiled.
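A minimal sketch of the problem, using a hypothetical helper named double_until_python. The plain Python loop runs fine on concrete integers. Under jax.jit, however, the comparison x < limit produces an abstract tracer, and Python's while cannot branch on it. In recent JAX versions this surfaces as TracerBoolConversionError, a subclass of jax.errors.ConcretizationTypeError, which is what the except clause catches.

```python
import jax

def double_until_python(x, limit):
    # Plain Python `while`: each check needs a concrete Python bool
    while x < limit:
        x = x * 2
    return x

# Eager call with concrete Python ints works
print(double_until_python(1, 100))  # 128

jit_failed = False
try:
    # Under jit, x and limit are tracers, so `x < limit` has no concrete value
    jax.jit(double_until_python)(1, 100)
except jax.errors.ConcretizationTypeError as err:
    jit_failed = True
    print("jit failed:", type(err).__name__)
```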
To handle such dynamic looping structures within compiled code, JAX provides jax.lax.while_loop. This function lets you express loops whose number of iterations is determined at runtime by computed values, while still supporting JIT compilation and execution on accelerators.
lax.while_loop
The lax.while_loop function takes three main arguments:

- cond_fun: A Python callable that takes the current loop state (the "carry") and returns a boolean scalar JAX value. The loop continues as long as cond_fun returns True.
- body_fun: A Python callable that defines the operations performed in a single iteration. It takes the current loop state (carry) as input and must return an updated loop state with the same structure, shapes, and dtypes.
- init_val: The initial state, or "carry" value, that the loop starts from.

The basic signature looks like this:
final_val = jax.lax.while_loop(cond_fun, body_fun, init_val)
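As a mental model, the call behaves like the plain-Python loop below, which mirrors the reference semantics given in the lax.while_loop docstring. The name while_loop_reference is hypothetical; unlike the real primitive, it runs eagerly in Python and is neither traced nor compiled.

```python
from jax import lax

def while_loop_reference(cond_fun, body_fun, init_val):
    # Plain-Python model of lax.while_loop: same control flow, no tracing
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val

def cond_fun(x):
    return x < 100

def body_fun(x):
    return x * 2

print(while_loop_reference(cond_fun, body_fun, 1))  # 128
print(lax.while_loop(cond_fun, body_fun, 1))        # 128
# If cond_fun is False for init_val, body_fun never runs
print(lax.while_loop(cond_fun, body_fun, 500))      # 500
```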
The loop proceeds as follows:

1. cond_fun is called with init_val.
2. If it returns True, body_fun is called with init_val and returns next_val.
3. cond_fun is called with next_val.
4. If it returns True, body_fun is called with next_val and returns another_val.
5. This process repeats until cond_fun returns False.
6. The last value returned by body_fun (the one passed to cond_fun that produced False) is the final result returned by lax.while_loop. If cond_fun(init_val) is already False, body_fun never runs and init_val itself is returned.

Crucially, like JAX's other control flow primitives, lax.while_loop operates functionally. The loop state is not modified in place; instead, body_fun must explicitly return the new state for the next iteration. This state (init_val and subsequent outputs of body_fun) can be any JAX-compatible type, including scalars, arrays, or pytrees (nested tuples, lists, dictionaries) of arrays. The pytree structure of the state, and the shapes and dtypes of its leaves, must remain the same across iterations.
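A sketch of a pytree carry, assuming Newton's method for square roots as the algorithm (newton_sqrt and its tol and max_iters parameters are hypothetical, chosen for illustration). The state is a tuple of (iteration count, current estimate, size of the last update), and every iteration returns a tuple with the same structure and dtypes. The iteration counter doubles as a safety cap, so the loop stops even if the tolerance is never reached.

```python
import jax.numpy as jnp
from jax import lax

def newton_sqrt(a, tol=1e-6, max_iters=50):
    """Approximate sqrt(a) for a > 0 with Newton's method inside lax.while_loop."""
    a = jnp.asarray(a, dtype=jnp.float32)

    def cond_fun(state):
        i, _, delta = state
        # Continue while the last update was large AND the iteration budget remains
        return (delta > tol) & (i < max_iters)

    def body_fun(state):
        i, x, _ = state
        x_new = 0.5 * (x + a / x)
        # Same structure (3-tuple) and dtypes (int32, float32, float32) as the input
        return i + 1, x_new, jnp.abs(x_new - x)

    # (iterations so far, current estimate, size of last update)
    init_val = (jnp.int32(0), a, jnp.float32(jnp.inf))
    i, x, _ = lax.while_loop(cond_fun, body_fun, init_val)
    return x, i

x, iters = newton_sqrt(2.0)
print(x, iters)

x_capped, iters_capped = newton_sqrt(2.0, tol=0.0, max_iters=3)
print(iters_capped)  # 3: stopped by the iteration budget
```

Passing tol=0.0 makes the tolerance check very hard to satisfy, so in the second call it is the max_iters bound that ends the loop after three iterations.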
Let's illustrate with a simple example: finding the smallest power of 2 that is greater than or equal to a given number n.
```python
import jax
import jax.numpy as jnp

def find_power_of_two(n):
    # Condition function: continue while current_power < n
    def cond_fun(loop_state):
        current_power = loop_state
        return current_power < n

    # Body function: double the current power
    def body_fun(loop_state):
        current_power = loop_state
        return current_power * 2

    # Initial state: start with power = 1
    init_val = 1

    # Run the while loop
    final_power = jax.lax.while_loop(cond_fun, body_fun, init_val)
    return final_power

# Example usage
target_number = 100
result = find_power_of_two(target_number)
print(f"Smallest power of 2 >= {target_number}: {result}")
# Output: Smallest power of 2 >= 100: 128

# We can JIT-compile the function
jit_find_power_of_two = jax.jit(find_power_of_two)
result_jit = jit_find_power_of_two(target_number)
print(f"JIT result: {result_jit}")
# Output: JIT result: 128
```
In this example:

- init_val is 1.
- cond_fun checks whether the current power (loop_state) is less than n (100).
- body_fun takes the current power and returns the doubled value as the new state.
- The loop runs as follows: cond_fun(1) -> True, body_fun(1) -> 2; cond_fun(2) -> True, body_fun(2) -> 4; ...; cond_fun(64) -> True, body_fun(64) -> 128; cond_fun(128) -> False, so the loop terminates.
- lax.while_loop returns the last state passed to cond_fun, the one that caused it to return False, which is 128.

When JAX encounters lax.while_loop inside a function being JIT-compiled, it traces both cond_fun and body_fun once to determine the operations they perform. The shapes and dtypes of the loop state (init_val and the return value of body_fun) must be static and determinable during tracing. The values can change dynamically, but the structure cannot.

The entire loop is then compiled into a single operation (a while operation in the underlying XLA representation). This is fundamentally different from a Python while loop: if the loop's condition depends only on concrete values available during tracing, JAX simply unrolls it into a flat sequence of operations, and if the condition depends on traced values, tracing fails with an error. lax.while_loop explicitly tells JAX that this is a dynamic loop structure intended for compilation.
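To see both points in practice, here is a sketch that extends the power-of-two example (find_power_of_two_counted and trace_counts are hypothetical names). The carry gains an iteration counter, and Python-side counters record how often cond_fun and body_fun are traced. Because tracing happens per compilation rather than per iteration, the trace counts stay small even though the loop body executes seven times, and the jaxpr contains a single while primitive rather than an unrolled chain of multiplications. Exact trace counts may vary between JAX versions, so the code only prints them.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Python-side counters; they only change while JAX traces the functions
trace_counts = {"cond": 0, "body": 0}

def find_power_of_two_counted(n):
    def cond_fun(state):
        trace_counts["cond"] += 1
        power, _ = state
        return power < n

    def body_fun(state):
        trace_counts["body"] += 1
        power, steps = state
        return power * 2, steps + 1

    # Carry: (current power, number of body_fun executions so far), both int32
    init_val = (jnp.int32(1), jnp.int32(0))
    return lax.while_loop(cond_fun, body_fun, init_val)

power, steps = jax.jit(find_power_of_two_counted)(100)
print(power, steps)   # 128 7
print(trace_counts)   # small, independent of the 7 iterations

# The traced program holds one `while` primitive instead of 7 unrolled multiplies
jaxpr_text = str(jax.make_jaxpr(find_power_of_two_counted)(100))
print("while" in jaxpr_text)
```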
Comparison with lax.scan

It's helpful to contrast lax.while_loop with lax.scan:

- lax.scan executes a fixed number of iterations, determined by the length of the input sequence(s) or its length argument. lax.while_loop executes a variable number of iterations, until cond_fun returns False.
- lax.scan typically accumulates outputs from each step. lax.while_loop returns only the final state of the loop carry.
- lax.scan is ideal for processing sequences (as in RNNs) or applying an operation repeatedly a known number of times. lax.while_loop suits iterative algorithms with convergence criteria, or simulations that run until a condition is met.

A few practical considerations apply when using lax.while_loop:

- Termination: You must ensure that cond_fun eventually returns False. An infinite loop in lax.while_loop will cause the compiled program to hang, and JAX cannot statically verify termination for arbitrary loops. A common safeguard is to carry an iteration counter and also stop when it reaches a maximum.
- Performance: While lax.while_loop enables dynamic iteration counts, loops with very large or highly variable iteration counts may perform worse than statically sized operations or lax.scan, where the total amount of work is known at compile time. There is some per-iteration loop overhead, though XLA handles it efficiently.
- Differentiation: lax.while_loop supports forward-mode automatic differentiation (jax.jvp, jax.jacfwd) but not reverse-mode differentiation (jax.grad, jax.vjp). Reverse mode would need to store the intermediate state of every iteration, and the number of iterations is not known at compile time. For reverse mode, common alternatives are lax.scan with a fixed maximum number of steps, lax.fori_loop with static bounds, or a custom derivative rule via jax.custom_vjp. This interaction is explored further in Chapter 4.

lax.while_loop adds another important tool for expressing complex computations in JAX, allowing you to handle algorithms whose control flow depends on computed values while keeping the benefits of JAX's compilation and hardware acceleration.
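The differentiation behavior of lax.while_loop can be checked with a small sketch, assuming Newton's method for square roots as the loop body (sqrt_while and sqrt_scan are hypothetical names). Forward mode through lax.while_loop succeeds, jax.grad through it raises a ValueError stating that reverse-mode differentiation does not work for lax.while_loop, and the same update written as a fixed-length lax.scan supports reverse mode. Both derivatives should approximate d/da sqrt(a) = 1 / (2 sqrt(a)), which is 0.25 at a = 4.

```python
import jax
import jax.numpy as jnp
from jax import lax

def sqrt_while(a, tol=1e-6):
    """Newton iterations for sqrt(a), stopping once the update drops below tol."""
    a = jnp.asarray(a, dtype=jnp.float32)

    def cond_fun(state):
        x_prev, x = state
        return jnp.abs(x - x_prev) > tol

    def body_fun(state):
        _, x = state
        return x, 0.5 * (x + a / x)

    # Start x_prev at inf so the first convergence check is True
    _, x = lax.while_loop(cond_fun, body_fun, (jnp.full_like(a, jnp.inf), a))
    return x

def sqrt_scan(a, num_steps=20):
    """Same Newton update, but a fixed number of steps via lax.scan."""
    a = jnp.asarray(a, dtype=jnp.float32)

    def step(x, _):
        return 0.5 * (x + a / x), None

    x, _ = lax.scan(step, a, xs=None, length=num_steps)
    return x

# Forward mode through lax.while_loop works
_, tangent = jax.jvp(sqrt_while, (4.0,), (1.0,))
print(tangent)  # approximately 0.25

# Reverse mode through lax.while_loop raises an error
reverse_mode_failed = False
try:
    jax.grad(sqrt_while)(4.0)
except ValueError as err:
    reverse_mode_failed = True
    print("jax.grad failed:", err)

# Reverse mode through the fixed-length scan works
grad_scan = jax.grad(sqrt_scan)(4.0)
print(grad_scan)  # approximately 0.25
```

The trade-off is that sqrt_scan always runs num_steps iterations, even after the estimate has converged; that fixed iteration count is what makes its reverse pass well defined.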
© 2025 ApX Machine Learning