While lax.scan excels at executing a fixed number of sequential steps, many algorithms require loops that continue until a specific condition is met. Standard Python while loops present a challenge for JAX's compilation process. Because JAX traces the function to generate a computation graph before execution, it needs to know the structure of the computation upfront. A Python while loop whose termination condition depends on intermediate values computed within the loop cannot be traced: under jax.jit those values are abstract tracers with no concrete truth value, so Python cannot decide whether to run another iteration and JAX raises a concretization error.
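The failure mode described above can be reproduced in a few lines. This is a minimal sketch; the function name halve_until_small is hypothetical and chosen only for illustration. It runs the same Python while loop eagerly and under jax.jit, and catches the error that tracing produces.

```python
import jax

def halve_until_small(x):
    # Plain Python loop: the condition must evaluate to a concrete Python bool.
    while x > 1.0:
        x = x / 2.0
    return x

# Eagerly, x is an ordinary Python float, so the loop runs normally.
eager_result = halve_until_small(10.0)

# Under jit, x is a tracer; `while x > 1.0` asks for a concrete bool and fails.
try:
    jax.jit(halve_until_small)(10.0)
    jit_error = None
except jax.errors.ConcretizationTypeError as err:
    jit_error = err

print(eager_result)
print(type(jit_error).__name__)
```

The exact error class name printed can differ between JAX versions, which is why the sketch catches the broader jax.errors.ConcretizationTypeError.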
To handle such dynamic looping structures within compiled code, JAX provides jax.lax.while_loop. This function allows you to express loops where the number of iterations is determined dynamically based on runtime values, while still enabling JIT compilation and execution on accelerators.
lax.while_loop
The lax.while_loop function takes three main arguments:
- cond_fun: A Python callable that takes the current loop state (the "carry") and returns a boolean scalar JAX value. The loop continues as long as cond_fun returns True.
- body_fun: A Python callable that defines the operations performed in a single iteration. It takes the current loop state (carry) as input and must return an updated loop state with the same structure (shape and dtype).
- init_val: The initial state or "carry" value provided to the first iteration of the loop.
The basic signature looks like this:
final_val = jax.lax.while_loop(cond_fun, body_fun, init_val)
The loop proceeds as follows:
1. cond_fun is called with init_val.
2. If it returns True, body_fun is called with init_val. It returns next_val.
3. cond_fun is called with next_val.
4. If it returns True, body_fun is called with next_val. It returns another_val.
5. This continues until cond_fun returns False.
6. The last value returned by body_fun (the one passed to cond_fun which resulted in False) is the final result returned by lax.while_loop. If cond_fun(init_val) is already False, body_fun never runs and init_val is returned unchanged.
Crucially, like other JAX transformations and control flow primitives, lax.while_loop operates functionally. The loop state is not modified in place; instead, body_fun must explicitly return the new state for the next iteration. This state (init_val and subsequent outputs of body_fun) can be any JAX-compatible type, including scalars, arrays, or pytrees (nested tuples, lists, dictionaries) of arrays. The structure, shapes, and dtypes of the state must remain consistent across iterations.
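The semantics of the steps above can be sketched as a plain Python function and checked against jax.lax.while_loop. The helper reference_while_loop and the dictionary keys "step" and "total" are illustrative assumptions, not part of the JAX API; the carry is a small pytree to show that structured state works too.

```python
import jax
import jax.numpy as jnp

def reference_while_loop(cond_fun, body_fun, init_val):
    # Pure-Python model of the loop semantics (illustrative only).
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val

def cond_fun(state):
    # Keep looping while the running total is below 20.
    return state["total"] < 20

def body_fun(state):
    # Return a NEW state with the same keys, shapes, and dtypes.
    step = state["step"] + 1
    return {"step": step, "total": state["total"] + step}

init_val = {"step": jnp.int32(0), "total": jnp.int32(0)}

expected = reference_while_loop(cond_fun, body_fun, init_val)
actual = jax.lax.while_loop(cond_fun, body_fun, init_val)

print(expected)
print(actual)
```

Both calls should agree: the running total passes 20 on the sixth step (1 + 2 + 3 + 4 + 5 + 6 = 21), so the final state holds step 6 and total 21.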
Let's illustrate with a simple example: finding the smallest power of 2 that is greater than or equal to a given number n.
import jax
import jax.numpy as jnp

def find_power_of_two(n):
    # Condition function: continue while current_power < n
    def cond_fun(loop_state):
        current_power = loop_state
        return current_power < n

    # Body function: double the current power
    def body_fun(loop_state):
        current_power = loop_state
        return current_power * 2

    # Initial state: start with power = 1
    init_val = 1

    # Run the while loop
    final_power = jax.lax.while_loop(cond_fun, body_fun, init_val)
    return final_power

# Example usage
target_number = 100
result = find_power_of_two(target_number)
print(f"Smallest power of 2 >= {target_number}: {result}")
# Output: Smallest power of 2 >= 100: 128

# We can JIT-compile the function
jit_find_power_of_two = jax.jit(find_power_of_two)
result_jit = jit_find_power_of_two(target_number)
print(f"JIT result: {result_jit}")
# Output: JIT result: 128
In this example:
- init_val is 1.
- cond_fun checks if the current power (loop_state) is less than n (100).
- body_fun takes the current power and returns the doubled value as the new state.
- The loop executes:
  cond_fun(1) -> True. body_fun(1) -> 2.
  cond_fun(2) -> True. body_fun(2) -> 4.
  ...
  cond_fun(64) -> True. body_fun(64) -> 128.
  cond_fun(128) -> False. Loop terminates.
- lax.while_loop returns the last state passed to cond_fun that caused it to return False, which is 128.
When JAX encounters lax.while_loop inside a function being JIT-compiled, it traces both cond_fun and body_fun once, rather than once per iteration, to understand the operations they perform. The shapes and dtypes of the loop state (init_val and the return value of body_fun) must be static and determinable during tracing. The values can change dynamically, but the structure cannot.
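The requirement that the carry keeps a fixed shape and dtype can be seen by writing a body that violates it. In this sketch the names growing_body and fixed_body are hypothetical; the first changes the carry's shape on every iteration and is rejected at trace time, while the second only changes values and runs normally.

```python
import jax
import jax.numpy as jnp

def cond_fun(xs):
    # Continue while the elements sum to less than 10.
    return xs.sum() < 10.0

def growing_body(xs):
    # Invalid: the carry shape changes from (3,) to (6,).
    return jnp.concatenate([xs, xs])

def fixed_body(xs):
    # Valid: values change, shape and dtype stay the same.
    return xs * 2.0

init_val = jnp.ones(3)

try:
    jax.lax.while_loop(cond_fun, growing_body, init_val)
    shape_error = None
except TypeError as err:
    shape_error = err

result = jax.lax.while_loop(cond_fun, fixed_body, init_val)

print(type(shape_error).__name__)
print(result)
```

Algorithms that naturally grow a buffer usually need to be restated with a preallocated array of maximum size plus an index stored in the carry.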
The entire loop is then compiled into a single optimized operation (often a while operation in the underlying XLA representation). This is fundamentally different from a Python while loop, which would typically cause JAX to either unroll the loop if the number of iterations is fixed and known during tracing or fail compilation if the condition depends on traced values. lax.while_loop explicitly tells JAX that this is a dynamic loop structure intended for compilation.
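One way to confirm that the loop is staged out as a single operation is to inspect the traced program with jax.make_jaxpr. This is a compact sketch that re-declares find_power_of_two with lambdas so the block is self-contained; the exact text of the printed jaxpr varies across JAX versions, so only the presence of a while primitive is checked.

```python
import jax

def find_power_of_two(n):
    # Same logic as the earlier example, written compactly.
    return jax.lax.while_loop(lambda p: p < n, lambda p: p * 2, 1)

# Trace the function without running it and look at the staged program.
jaxpr = jax.make_jaxpr(find_power_of_two)(100)
print(jaxpr)

# The loop appears as one `while` primitive, not as unrolled multiplications.
has_while = "while" in str(jaxpr)
print(has_while)
```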
Comparison with lax.scan
It's helpful to contrast lax.while_loop with lax.scan:
- lax.scan executes a fixed number of iterations, determined by the length of the input sequence(s). lax.while_loop executes a variable number of iterations until cond_fun returns False.
- lax.scan typically accumulates outputs from each step. lax.while_loop only returns the final state of the loop carry.
- lax.scan is ideal for processing sequences (like in RNNs) or applying an operation repeatedly for a known count. lax.while_loop is suited for iterative algorithms with convergence criteria or simulations running until a condition is met.
A few practical considerations apply when using lax.while_loop:
- You must ensure that cond_fun will eventually become False. An infinite loop in lax.while_loop will cause the compiled program to hang. JAX cannot statically verify termination for all possible loops.
- While lax.while_loop enables dynamic iteration counts, loops with very large or highly variable iteration counts might have performance implications compared to statically sized operations or lax.scan, where the total amount of work is known at compile time. Loop overhead exists, though XLA optimizes it well.
- lax.while_loop supports forward-mode automatic differentiation (jax.jvp, jax.jacfwd) but not reverse-mode differentiation (jax.grad, jax.vjp). A backward pass would need to store intermediate values for an iteration count that is unknown at compile time, and attempting it raises an error. When gradients are required, express the loop with lax.scan or with lax.fori_loop using static bounds. This interaction is explored further in Chapter 4.
lax.while_loop adds another important tool for expressing complex computations in JAX, allowing you to handle algorithms where the flow depends dynamically on computed values while maintaining the benefits of JAX's compilation and hardware acceleration.
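As an illustration of a convergence-driven loop that also guards against non-termination, here is a minimal sketch of Newton's iteration for a square root. The function name newton_sqrt and the parameters tol and max_iters are assumptions made for this example. The carry holds an iteration counter so that cond_fun becomes False after max_iters steps even if the tolerance is never met.

```python
import jax
import jax.numpy as jnp

def newton_sqrt(a, tol=1e-6, max_iters=50):
    # Assumes a > 0. Use a fixed dtype so the carry types stay consistent.
    a = jnp.asarray(a, dtype=jnp.float32)

    def cond_fun(state):
        x, prev_x, i = state
        not_converged = jnp.abs(x - prev_x) > tol
        # Stop on convergence OR when the iteration cap is reached.
        return jnp.logical_and(not_converged, i < max_iters)

    def body_fun(state):
        x, _, i = state
        new_x = 0.5 * (x + a / x)  # Newton update for f(x) = x^2 - a
        return new_x, x, i + 1

    # Carry: (current estimate, previous estimate, iteration count).
    init_val = (a, jnp.float32(jnp.inf), jnp.int32(0))
    x, _, iters = jax.lax.while_loop(cond_fun, body_fun, init_val)
    return x, iters

root, iters = jax.jit(newton_sqrt)(2.0)
print(root, iters)
```

Returning the iteration count alongside the result makes it easy to tell whether the loop stopped because it converged or because it hit the cap.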
© 2026 ApX Machine Learning