Automatic differentiation in JAX integrates seamlessly with its structured control flow primitives like lax.cond, lax.while_loop, and lax.scan. However, understanding how differentiation interacts with these operations is important for writing efficient code and interpreting results, especially since JAX needs to trace the computation graph before execution.
lax.cond
Recall that lax.cond(pred, true_fun, false_fun, operand) executes either true_fun(operand) or false_fun(operand) based on the boolean value of pred. When JAX encounters lax.cond inside a function being differentiated (e.g., using jax.grad), it must generate code that can compute the gradient regardless of which branch is taken during the actual execution.

To achieve this, JAX traces both the true_fun and false_fun branches during the differentiation pass. The resulting differentiated function will contain the logic for the gradients of both branches. During execution, the original forward pass determines which branch's result is used. The backward pass then computes gradients corresponding to the branch that was actually executed.
Consider this example:
import jax
import jax.numpy as jnp
from jax import lax

def conditional_computation(x, y):
    # Condition depends on input 'y'
    return lax.cond(y > 0,
                    lambda op: jnp.sin(op * 2.0),  # true_fun
                    lambda op: jnp.cos(op / 2.0),  # false_fun
                    x)                             # operand

grad_x = jax.grad(conditional_computation, argnums=0)
grad_y = jax.grad(conditional_computation, argnums=1)  # Note: Gradient w.r.t. pred is typically zero

# Example execution
x_val = jnp.pi / 4.0
y_val_pos = 1.0
y_val_neg = -1.0

print(f"f(x, y>0) = {conditional_computation(x_val, y_val_pos)}")
print(f"df/dx(x, y>0) = {grad_x(x_val, y_val_pos)}")
# Gradient of sin(2x) is 2*cos(2x). At x=pi/4, 2*cos(pi/2) = 0.0

print(f"\nf(x, y<0) = {conditional_computation(x_val, y_val_neg)}")
print(f"df/dx(x, y<0) = {grad_x(x_val, y_val_neg)}")
# Gradient of cos(x/2) is -0.5*sin(x/2). At x=pi/4, -0.5*sin(pi/8) approx -0.191

# Gradient w.r.t. 'y' (the predicate condition) is generally zero
# because the predicate itself is usually treated as discrete.
print(f"\ndf/dy(x, y>0) = {grad_y(x_val, y_val_pos)}")
print(f"df/dy(x, y<0) = {grad_y(x_val, y_val_neg)}")
Key implications:

- The outputs of true_fun and false_fun must have the same structure (shape and dtype), as JAX needs a consistent type signature (see the sketch below).
- A gradient with respect to pred itself is usually not meaningful or supported unless the predicate calculation involves differentiable operations and its output is used in a way that allows gradients (which is uncommon for typical boolean conditions). JAX often yields a zero gradient for arguments that only influence the predicate.
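To illustrate the first point, here is a minimal sketch (the function name mismatched_branches is hypothetical) of what happens when the two branches return outputs with different shapes:

import jax.numpy as jnp
from jax import lax

def mismatched_branches(x):
    return lax.cond(x > 0,
                    lambda op: op,                   # scalar output
                    lambda op: jnp.stack([op, op]),  # shape (2,) output: inconsistent
                    x)

try:
    mismatched_branches(1.0)
except TypeError as e:
    # JAX rejects the cond because the branch outputs do not share a single
    # shape/dtype structure.
    print(f"lax.cond rejected mismatched branches: {e}")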
lax.while_loop

The primitive lax.while_loop(cond_fun, body_fun, init_val) applies body_fun repeatedly as long as cond_fun returns True. Because the number of iterations depends on the data and is not known when JAX traces the function, reverse-mode differentiation (jax.grad) through lax.while_loop is not supported and raises an error. Forward-mode differentiation (jax.jvp or jax.jacfwd) does work: tangents are pushed through each iteration alongside the primal values, so derivatives with respect to the initial state init_val, or to parameters used inside body_fun, accumulate across all iterations.

When the trip count is fixed and known in advance, the loop can instead be expressed with lax.fori_loop (with static bounds) or lax.scan, whose reverse pass applies the chain rule back through the iterations performed during the forward pass, analogous to Backpropagation Through Time (BPTT) used for training Recurrent Neural Networks (RNNs).
import jax
import jax.numpy as jnp
from jax import lax

def loop_sum(max_val):
    # Sums sqrt(i) for i from 0 up to (but not including) max_val
    init_state = (0, 0.0)  # (current_i, current_sum)

    def cond_fun(state):
        i, _ = state
        return i < max_val  # Continue while i < max_val

    def body_fun(state):
        i, current_sum = state
        return (i + 1, current_sum + jnp.sqrt(i.astype(jnp.float32)))  # Example op

    final_state = lax.while_loop(cond_fun, body_fun, init_state)
    _, final_sum = final_state
    return final_sum

# 'max_val' only changes the *number* of iterations through the non-differentiable
# loop condition, so its gradient is not meaningful. Let's differentiate w.r.t.
# something used *inside* the loop instead.
def loop_sum_param(scale, n_iters):
    init_state = (0, 0.0)  # (i, current_sum)

    def cond_fun(state):
        i, _ = state
        return i < n_iters

    def body_fun(state):
        i, current_sum = state
        # Use 'scale' inside the loop
        return (i + 1, current_sum + scale * i)

    _, final_sum = lax.while_loop(cond_fun, body_fun, init_state)
    return final_sum

# lax.while_loop only supports forward-mode differentiation, so use jax.jacfwd;
# jax.grad would raise an error here because reverse-mode is unsupported.
grad_loop_sum = jax.jacfwd(loop_sum_param, argnums=0)  # Derivative w.r.t. 'scale'

scale_val = 2.0
iters = 5
print(f"Loop sum(scale={scale_val}, iters={iters}) = {loop_sum_param(scale_val, iters)}")
# Forward pass: 2*0 + 2*1 + 2*2 + 2*3 + 2*4 = 0 + 2 + 4 + 6 + 8 = 20
# Expected gradient d(sum)/d(scale) = 0 + 1 + 2 + 3 + 4 = 10
print(f"d(Sum)/d(scale) = {grad_loop_sum(scale_val, iters)}")
Key implications:

- Reverse-mode differentiation (jax.grad) is not available for lax.while_loop because the trip count is dynamic; use forward-mode (jax.jvp, jax.jacfwd), or restructure the loop with lax.fori_loop (static bounds) or lax.scan when reverse-mode gradients are required (see the sketch below).
- Once the loop is expressed with scan, the backward pass stores intermediate values from every iteration, which can be expensive for long loops. Memory usage can be reduced with gradient checkpointing (jax.checkpoint), which recomputes intermediate values during the backward pass instead of storing them. This trades compute for memory.
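If reverse-mode gradients are needed and the trip count is known ahead of time, one option is to restructure the loop. The sketch below (loop_sum_fori is a hypothetical rewrite of loop_sum_param) relies on the fact that lax.fori_loop with static Python-integer bounds lowers to lax.scan, which does support jax.grad:

import jax
from jax import lax

def loop_sum_fori(scale, n_iters):
    # Same accumulation as loop_sum_param, but with an explicit trip count.
    def body_fun(i, current_sum):
        return current_sum + scale * i
    # n_iters must be a concrete Python integer here for reverse-mode to work.
    return lax.fori_loop(0, n_iters, body_fun, 0.0)

print(jax.grad(loop_sum_fori, argnums=0)(2.0, 5))  # d(sum)/d(scale) = 0+1+2+3+4 = 10.0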
lax.scan

The lax.scan(f, init, xs) primitive applies a function f cumulatively over a sequence xs, carrying state init. It is often used for implementing RNNs or other sequential processes where the number of steps is known beforehand.
Differentiation through lax.scan is well defined and efficient. JAX handles the propagation of gradients through the carried state (carry) and the per-step outputs (y) automatically, and unlike lax.while_loop, reverse-mode differentiation is fully supported; the backward pass resembles BPTT.
import jax
import jax.numpy as jnp
from jax import lax

def simple_rnn_step(carry, x_t):
    # A very basic RNN cell: carry is the previous hidden state h_{t-1},
    # x_t is the input at time t. Returns the new hidden state h_t and an output y_t.
    prev_h = carry
    weight_hh = 0.5  # Fixed parameter for simplicity
    weight_xh = 1.5  # Fixed parameter for simplicity
    # Simple update with a tanh nonlinearity
    new_h = jnp.tanh(prev_h * weight_hh + x_t * weight_xh)
    y_t = new_h * 2.0  # Some output based on the hidden state
    return new_h, y_t  # (new_carry, y_t)

def run_rnn(initial_state, inputs):
    # initial_state: h_0
    # inputs: sequence of x_t values
    final_state, outputs_y = lax.scan(simple_rnn_step, initial_state, inputs)
    return jnp.sum(outputs_y)  # Return sum of outputs for a scalar loss

grad_rnn_params = jax.grad(run_rnn, argnums=0)  # Grad w.r.t. initial_state
grad_rnn_inputs = jax.grad(run_rnn, argnums=1)  # Grad w.r.t. inputs

h0 = jnp.zeros(())                      # Initial hidden state (scalar)
xts = jnp.array([0.1, 0.2, -0.1, 0.3])  # Sequence of inputs

total_output = run_rnn(h0, xts)
print(f"Total output = {total_output}")

# Gradients w.r.t. the initial state and the inputs
dh0 = grad_rnn_params(h0, xts)
dxts = grad_rnn_inputs(h0, xts)
print(f"d(Sum)/dh0 = {dh0}")
print(f"d(Sum)/dxts = {dxts}")
Key implications:

- lax.scan is generally preferred over lax.while_loop for sequences of fixed length because its structure is more explicit, often leading to simpler analysis and potentially more optimization by XLA; it also supports reverse-mode differentiation, which lax.while_loop does not.
- The backward pass for scan efficiently computes gradients by iterating backward through the steps performed in the forward pass, using the saved intermediate values.
- To reduce the memory needed for those saved intermediates, jax.checkpoint can also be applied to the function f used within scan (see the sketch below).
The data flow for differentiating lax.scan: the forward pass computes states (h) and outputs (y) sequentially, and the backward pass (VJP) propagates gradients (dL/d⋅) backward through the unrolled computation using the intermediate values from the forward pass.
In summary, JAX's automatic differentiation system is designed to work correctly with structured control flow primitives. While differentiation of these primitives largely "just works" from a user perspective, knowing that cond traces both branches, that while_loop supports only forward-mode differentiation, and that scan's backward pass resembles BPTT helps in understanding memory usage, potential numerical issues, and the applicability of techniques like gradient checkpointing.