Now that we've examined the core control flow primitives like lax.scan, let's apply them to a practical scenario often encountered in sequence modeling: implementing a custom recurrent cell. While basic RNNs are straightforward, many advanced architectures use more complex gating mechanisms. We'll implement a simplified Gated Recurrent Unit (GRU) cell using lax.scan to manage the sequential processing and hidden state updates. This exercise demonstrates how to structure non-trivial computations within the scan body.
A GRU is a type of recurrent neural network cell designed to capture dependencies over different time scales by using gating mechanisms. These gates control the flow of information, deciding what to keep from the past state and what to incorporate from the new input.
For our example, we'll implement a slightly simplified version. Let $x_t$ be the input vector at timestep $t$, and $h_{t-1}$ be the hidden state from the previous timestep. The hidden state $h_t$ at timestep $t$ is computed as follows:

$$z_t = \sigma(W_z x_t + U_z h_{t-1} + b_z)$$
$$r_t = \sigma(W_r x_t + U_r h_{t-1} + b_r)$$
$$\tilde{h}_t = \tanh(W_h x_t + U_h (r_t \odot h_{t-1}) + b_h)$$
$$h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$$

Here, $\sigma$ represents the sigmoid activation function, $\tanh$ is the hyperbolic tangent activation function, and $\odot$ denotes element-wise multiplication. $W_z, U_z, b_z, W_r, U_r, b_r, W_h, U_h, b_h$ are the learnable parameters (weight matrices and bias vectors) of the GRU cell.
We can implement the processing of an entire sequence through the GRU cell using lax.scan. The core idea is:

- The carry in lax.scan will hold the hidden state $h_t$ at each timestep.
- The xs argument will be the input sequence $(x_1, x_2, \ldots, x_T)$.
- The function passed to lax.scan (let's call it gru_step) will implement the four equations above, taking the previous hidden state h_prev (from the carry) and the current input x_t (from xs) to compute the new hidden state h_t.
- gru_step will return (h_t, h_t). The first h_t becomes the carry for the next step, and the second h_t is accumulated as the output sequence.

Let's write the code. We'll start with imports and defining the parameters. In a real scenario, these parameters would be part of a larger model structure (like Flax or Haiku), but here we'll define them directly for clarity.
import jax
import jax.numpy as jnp
import jax.lax as lax
from jax import random
# Define activation functions
sigmoid = jax.nn.sigmoid
tanh = jnp.tanh
def initialize_gru_params(key, input_dim, hidden_dim):
    """Initializes parameters for the simplified GRU cell."""
    keys = random.split(key, 6)  # One key per weight matrix: Wz, Uz, Wr, Ur, Wh, Uh (biases are zero-initialized)
    # Update gate parameters
    Wz = random.normal(keys[0], (hidden_dim, input_dim)) * 0.01
    Uz = random.normal(keys[1], (hidden_dim, hidden_dim)) * 0.01
    bz = jnp.zeros((hidden_dim,))
    # Reset gate parameters
    Wr = random.normal(keys[2], (hidden_dim, input_dim)) * 0.01
    Ur = random.normal(keys[3], (hidden_dim, hidden_dim)) * 0.01
    br = jnp.zeros((hidden_dim,))
    # Candidate hidden state parameters
    Wh = random.normal(keys[4], (hidden_dim, input_dim)) * 0.01
    Uh = random.normal(keys[5], (hidden_dim, hidden_dim)) * 0.01
    bh = jnp.zeros((hidden_dim,))
    params = {
        'Wz': Wz, 'Uz': Uz, 'bz': bz,
        'Wr': Wr, 'Ur': Ur, 'br': br,
        'Wh': Wh, 'Uh': Uh, 'bh': bh
    }
    return params
def gru_step(params, h_prev, x_t):
    """Performs one step of the simplified GRU computation."""
    # Update gate
    z_t = sigmoid(jnp.dot(params['Wz'], x_t) + jnp.dot(params['Uz'], h_prev) + params['bz'])
    # Reset gate
    r_t = sigmoid(jnp.dot(params['Wr'], x_t) + jnp.dot(params['Ur'], h_prev) + params['br'])
    # Candidate hidden state
    h_tilde_t = tanh(jnp.dot(params['Wh'], x_t) + jnp.dot(params['Uh'], (r_t * h_prev)) + params['bh'])
    # Final hidden state
    h_t = (1.0 - z_t) * h_prev + z_t * h_tilde_t
    # Return new hidden state as carry and output
    return h_t, h_t
def gru_sequence(params, initial_h, inputs):
    """Applies the GRU cell over a sequence of inputs using lax.scan."""
    # Define the scan function, closing over the parameters
    scan_fn = lambda carry, x: gru_step(params, carry, x)
    # Apply lax.scan
    final_h, outputs_h = lax.scan(scan_fn, initial_h, inputs)
    # final_h contains the last hidden state
    # outputs_h contains the sequence of hidden states [h_1, h_2, ..., h_T]
    return final_h, outputs_h
# Example Usage
key = random.PRNGKey(0)
seq_len = 10
input_dim = 5
hidden_dim = 8
# Initialize parameters
gru_params = initialize_gru_params(key, input_dim, hidden_dim)
# Create dummy input sequence (sequence_length, input_features)
key, subkey = random.split(key)
input_sequence = random.normal(subkey, (seq_len, input_dim))
# Initialize hidden state
initial_hidden_state = jnp.zeros((hidden_dim,))
# Run the GRU over the sequence
final_state, hidden_states_sequence = gru_sequence(gru_params, initial_hidden_state, input_sequence)
print("Input sequence shape:", input_sequence.shape)
print("Initial hidden state shape:", initial_hidden_state.shape)
print("Final hidden state shape:", final_state.shape)
print("Output sequence of hidden states shape:", hidden_states_sequence.shape)
In this code:

- initialize_gru_params sets up the necessary weight matrices and bias vectors with appropriate shapes, using small random values for initialization.
- gru_step implements the core logic for a single timestep. It takes the parameters, the previous hidden state h_prev, and the current input x_t, returning the new hidden state h_t twice (once as the new carry, once as the output for this step).
- gru_sequence orchestrates the process. It defines scan_fn, which is just gru_step with the params argument fixed (closed over). It then calls lax.scan with this function, the initial hidden state, and the input sequence.
- The example usage at the end calls the gru_sequence function. The output shapes confirm that the final state has the hidden dimension, and the output sequence has dimensions (sequence_length, hidden_dimension).

One of the advantages of using lax.scan is that the resulting gru_sequence function is fully compatible with other JAX transformations like jit, grad, and vmap.

For instance, to compile the GRU computation for faster execution, simply wrap the call with jax.jit:
# Compile the GRU function for efficiency
jit_gru_sequence = jax.jit(gru_sequence)
# Run the compiled version (first run includes compilation time)
key, subkey = random.split(key)
input_sequence_2 = random.normal(subkey, (seq_len, input_dim))
final_state_jit, hidden_states_sequence_jit = jit_gru_sequence(gru_params, initial_hidden_state, input_sequence_2)
print("\nRunning JIT-compiled version:")
print("Final hidden state shape (JIT):", final_state_jit.shape)
print("Output sequence shape (JIT):", hidden_states_sequence_jit.shape)
If you wanted to process a batch of sequences simultaneously, you could use jax.vmap. Assuming your inputs have a batch dimension, like (batch_size, seq_len, input_dim), you would map over the batch dimension for both the initial hidden state (batch_size, hidden_dim) and the inputs:
# Process a batch of sequences with jax.vmap
batch_size = 32
key, subkey = random.split(key)
batched_inputs = random.normal(subkey, (batch_size, seq_len, input_dim))
batched_initial_h = jnp.zeros((batch_size, hidden_dim))

# Map over the batch dimension (axis 0) of the initial hidden state and the inputs.
# The parameters are shared across the batch, so we pass None for their in_axes entry.
batched_gru = jax.vmap(gru_sequence, in_axes=(None, 0, 0))
final_states_batch, hidden_sequences_batch = batched_gru(gru_params, batched_initial_h, batched_inputs)

print("Batched final state shape:", final_states_batch.shape)           # (batch_size, hidden_dim)
print("Batched output sequences shape:", hidden_sequences_batch.shape)  # (batch_size, seq_len, hidden_dim)
Similarly, you could compute gradients with respect to the parameters (gru_params) or the inputs (input_sequence) using jax.grad, enabling the training of the GRU cell within a larger model.
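As a minimal sketch of how that might look, assume a toy mean-squared-error loss on the final hidden state; the loss function and dummy target below are illustrative additions, not part of the original example:

# Toy loss: mean squared error between the final hidden state and a dummy target
def mse_loss(params, initial_h, inputs, target):
    final_h, _ = gru_sequence(params, initial_h, inputs)
    return jnp.mean((final_h - target) ** 2)

key, subkey = random.split(key)
dummy_target = random.normal(subkey, (hidden_dim,))

# jax.grad differentiates with respect to the first argument (the parameter dictionary) by default
grads = jax.grad(mse_loss)(gru_params, initial_hidden_state, input_sequence, dummy_target)
print("Gradient keys:", sorted(grads.keys()))
print("dLoss/dWz shape:", grads['Wz'].shape)  # (hidden_dim, input_dim)

In practice these gradients would be passed to an optimizer (for example, one from optax) as part of a training loop.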
This example illustrates how lax.scan provides a powerful mechanism for implementing complex, stateful sequential computations in a way that integrates cleanly with JAX's compilation and automatic differentiation capabilities. By defining the logic for a single step and letting lax.scan handle the iteration, you can build sophisticated recurrent models efficiently.