When working with batched data or sequences, you often encounter situations where elements are not uniform. For instance, sentences in a batch have varying lengths, or you might only want to apply an operation to certain elements based on a condition. While control flow primitives like lax.cond can handle some conditional logic, applying operations element-wise based on a condition across large arrays often calls for masking.

Masking involves using auxiliary arrays, typically boolean, to specify which elements of your data arrays should be included in a computation and which should be ignored or treated differently. This approach is particularly effective within jit-compiled functions because it often translates to efficient, branchless element-wise operations on hardware accelerators.
Consider processing a batch of sentences represented as sequences of word embeddings. To use vmap or feed them into many standard neural network layers, these sequences typically need to be padded to the same maximum length. However, you don't want the padding tokens to influence the results of operations like calculating sequence averages or attention scores. Masking provides a way to systematically exclude these padding elements.

Similarly, within a lax.scan loop processing sequences, you might reach the end of some sequences before others. Masking allows you to selectively update states or compute outputs only for the sequences that are still active.
The most frequent tool for masking in JAX is jax.numpy.where. Its signature is:

jnp.where(condition, x, y)

It returns an array with elements selected from x where the condition array is True, and elements from y where the condition array is False. condition, x, and y are broadcast together.
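For a quick sense of the element-wise selection and broadcasting, here is a small standalone sketch (the example values are arbitrary):

import jax.numpy as jnp

condition = jnp.array([True, False, True, False])
x = jnp.array([10.0, 20.0, 30.0, 40.0])
# y can be a scalar; it broadcasts against condition and x
selected = jnp.where(condition, x, -1.0)
# selected: [10., -1., 30., -1.]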
This is a primary use case. Let's illustrate with padding variable-length sequences. The idea is to build a boolean mask where True indicates a valid data element and False indicates a padding element, then apply the mask with jnp.where or arithmetic operations.

import jax
import jax.numpy as jnp
# Example: Batch of sequences (represented by integers for simplicity)
sequences = [
    jnp.array([1, 2, 3]),
    jnp.array([4, 5]),
    jnp.array([6, 7, 8, 9])
]
# Assume a padding function (implementation omitted for brevity)
# that pads with 0 to the max length (4)
padded_sequences = jnp.array([
    [1, 2, 3, 0],
    [4, 5, 0, 0],
    [6, 7, 8, 9]
])
# Padded data shape: (3, 4) - Batch size 3, Max length 4
# Create the mask: True for data, False for padding
# Assuming padding value is 0
mask = (padded_sequences != 0)
# mask:
# [[ True, True, True, False],
# [ True, True, False, False],
# [ True, True, True, True]]
# Example 1: Masked Sum (avoid summing padding)
# Use jnp.where to replace padding with 0 before summing
masked_values_for_sum = jnp.where(mask, padded_sequences, 0)
sum_per_sequence = jnp.sum(masked_values_for_sum, axis=-1)
# sum_per_sequence: [ 6 9 30] (Correct sums: 1+2+3=6, 4+5=9, 6+7+8+9=30)
# Example 2: Masked Average (avoid counting padding in the denominator)
# Sum as above
sum_values = jnp.sum(jnp.where(mask, padded_sequences, 0.0), axis=-1)
# Count valid elements per sequence
num_valid_elements = jnp.sum(mask, axis=-1)
# Avoid division by zero for sequences that might be entirely padding (if possible)
average_per_sequence = sum_values / jnp.maximum(num_valid_elements, 1)
# average_per_sequence: [ 2. , 4.5 , 7.5 ]
# Example 3: Applying mask directly (often used in attention)
# Imagine scores calculated for all positions, including padding.
# Set scores for padding positions to a large negative number before softmax.
key = jax.random.PRNGKey(0)
scores = jax.random.normal(key, (3, 4))  # Example scores
masked_scores = jnp.where(mask, scores, -1e9) # Use large negative value
# Now apply softmax, padding scores will be near zero.
attention_weights = jax.nn.softmax(masked_scores, axis=-1)
# print("Padded Sequences:\n", padded_sequences)
# print("Mask:\n", mask)
# print("Sum per sequence:", sum_per_sequence)
# print("Average per sequence:", average_per_sequence)
# print("Masked scores for softmax (example):\n", masked_scores)
# print("Attention weights (example):\n", attention_weights)
Visual representation of padding sequences of varying lengths to a uniform length and generating the corresponding boolean mask. T stands for True (valid data) and F for False (padding).
Instead of jnp.where, you can sometimes use arithmetic, especially when the mask is composed of 0s and 1s. Casting the boolean mask to a numeric dtype achieves this.
# Cast the boolean mask to floats (True -> 1.0, False -> 0.0)
mask_float = mask.astype(jnp.float32)
# Masked sum using multiplication (the integer data promotes to float here)
sum_per_sequence_alt = jnp.sum(padded_sequences * mask_float, axis=-1)
# sum_per_sequence_alt: [ 6.  9. 30.]
# Masked average using multiplication
num_valid_elements_alt = jnp.sum(mask_float, axis=-1)
average_per_sequence_alt = sum_per_sequence_alt / jnp.maximum(num_valid_elements_alt, 1e-9) # Add epsilon for safety
# average_per_sequence_alt: [ 2. , 4.5 , 7.5 ]
# print("Mask float:\n", mask_float)
# print("Sum via multiplication:", sum_per_sequence_alt)
# print("Average via multiplication:", average_per_sequence_alt)
Be mindful when using multiplication for masking, particularly in log-space computations or when gradients are involved, as multiplying by zero might not always produce the desired mathematical or gradient behavior compared to explicitly selecting values with jnp.where.
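As a small illustration of the log-space pitfall (the values here are invented for demonstration), multiplying by a zero mask does not neutralize an infinite value, while jnp.where selects a finite one:

import jax.numpy as jnp

probs = jnp.array([1.0, 0.0])             # second position is padding
valid = jnp.array([True, False])
log_probs = jnp.log(probs)                # [0., -inf]

# Arithmetic masking: 0.0 * (-inf) evaluates to nan
via_multiply = valid.astype(jnp.float32) * log_probs   # [ 0., nan]

# jnp.where selects explicit values and stays finite
via_where = jnp.where(valid, log_probs, 0.0)            # [0., 0.]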
Masking interacts predictably with JAX transformations and control flow:

jit: Masking operations like jnp.where or arithmetic are readily compiled by jit into efficient low-level code.

vmap: If you vmap a function over a batch dimension and your inputs include padded data and corresponding masks, vmap will automatically vectorize the masking operations alongside the main computation (a short check later in this section demonstrates this).

grad: Automatic differentiation works correctly through jnp.where. The gradient flows back through the branch (x or y) that was selected by the condition for each element; gradients associated with the non-selected branch are effectively zero for that element. This is usually the desired behavior, preventing padding or masked-out elements from contributing to parameter updates (also demonstrated in the check later in this section).

lax.scan: Masks can be carried as part of the state (carry) in lax.scan or computed within the scan body. This allows you to perform sequential operations while respecting the validity boundaries of each sequence in a batch. For example, in an RNN you could use a mask to prevent updating the hidden state for sequences that have already ended (hit padding). A runnable sketch follows the conceptual example below.

# Conceptual example of masking within lax.scan
def scan_body(carry, x):
    hidden_state, current_mask = carry
    input_element = x
    # Compute potential new state (simplified)
    potential_new_state = hidden_state * 0.9 + input_element * 0.1
    # Update the state only if the mask is True for this step,
    # otherwise keep the old state
    new_state = jnp.where(current_mask, potential_new_state, hidden_state)
    # Output something based on the new state (maybe masked)
    output = jnp.where(current_mask, new_state * 2.0, 0.0)
    # Note: how the mask evolves depends on the application.
    # Here we assume the mask itself isn't changed by the scan body,
    # but it could be part of the input 'x' or updated based on state.
    new_carry = (new_state, current_mask)
    return new_carry, output
# initial_state = ...
# sequence_inputs = ...
# sequence_masks = ... # Shape matching sequence_inputs
# Need to structure masks correctly for scan, potentially stacking them
# or passing them as part of 'xs' in lax.scan(scan_body, init, xs=(inputs, masks))
# final_carry, outputs = lax.scan(scan_body, (initial_state, initial_mask_state), inputs_with_masks)
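To make this pattern concrete, here is a minimal runnable variant that passes per-step masks through xs instead of the carry, reusing padded_sequences and mask from the earlier example. The time-major transposes and the simple update rule are illustrative assumptions:

import jax.numpy as jnp
from jax import lax

def masked_scan_body(hidden_state, step):
    # step is one time slice of xs: (inputs at this step, mask at this step)
    input_element, step_mask = step              # each has shape (batch,)
    potential_new_state = hidden_state * 0.9 + input_element * 0.1
    # Freeze the state at padding positions, update it at valid positions
    new_state = jnp.where(step_mask, potential_new_state, hidden_state)
    output = jnp.where(step_mask, new_state * 2.0, 0.0)
    return new_state, output

inputs = padded_sequences.astype(jnp.float32)    # shape (batch, max_length) = (3, 4)
initial_state = jnp.zeros(inputs.shape[0])       # one hidden value per sequence
# lax.scan iterates over the leading axis, so switch to time-major layout
xs = (inputs.T, mask.T)                          # each has shape (4, 3)
final_state, outputs = lax.scan(masked_scan_body, initial_state, xs)
# outputs has shape (max_length, batch); transpose back with outputs.T if needed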
lax.cond: Masking applies a computation everywhere and then selects the result using jnp.where or arithmetic. This is often efficient on GPUs/TPUs because it avoids data-dependent branching within a SIMD/SIMT execution unit, which can cause threads/lanes to diverge. However, it means you compute results even for the masked-out elements. If the branches involve significantly different and expensive computations, lax.cond might be faster when the condition frequently skips a heavy computation, despite potential divergence costs. Profiling is the best way to determine the optimal approach for a specific case.
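The vmap and grad points above can be verified directly. The small check below is an illustrative sketch that reuses padded_sequences and mask from earlier; masked_mean is a hypothetical helper, not part of the original example. It computes a masked average per row and shows that padding positions receive zero gradient:

import jax
import jax.numpy as jnp

def masked_mean(values, row_mask):
    # Masked average, as in the earlier example
    total = jnp.sum(jnp.where(row_mask, values, 0.0))
    count = jnp.maximum(jnp.sum(row_mask), 1)
    return total / count

# vmap: lift the per-row function over the batch dimension
batch_means = jax.vmap(masked_mean)(padded_sequences.astype(jnp.float32), mask)
# batch_means: [2., 4.5, 7.5]

# grad: differentiate one row's masked mean w.r.t. its values;
# padding positions contribute nothing to the gradient
row_values = jnp.array([4.0, 5.0, 0.0, 0.0])
row_mask = jnp.array([True, True, False, False])
grads = jax.grad(masked_mean)(row_values, row_mask)
# grads: [0.5, 0.5, 0., 0.]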
For gradients, jnp.where is generally safe. If you use arithmetic masking, double-check that multiplying by zero provides the correct gradient behavior (it usually does for simple sums and averages, but can be tricky elsewhere). jax.lax.stop_gradient can be used explicitly if needed, but reach for jnp.where first.

Mastering masking techniques is important for handling the irregularities common in real-world data, especially when aiming for high performance on accelerators using JAX's compilation and vectorization capabilities. It allows you to write clean, composable code that correctly processes batches and sequences of varying structures.