趋近智
虽然 jit、grad 和 vmap 为加速和转换数值函数提供了有效的工具,但它们主要一次性处理整个数组或批次。许多重要算法,尤其是在诸如序列建模、信号处理和优化等方面,包含序列依赖性,一步的结果依赖于前一步的输出。在 JAX 中高效实现这些需要一个专用工具:jax.lax.scan。
可以想象处理时间序列、逐步模拟物理系统,或者实现循环神经网络(RNN)的前向传播。一个简单的 Python for 循环迭代这些步骤在纯 Python 中可行,但它对 JAX 的编译模型提出了挑战。当 JAX 跟踪一个包含 Python 循环且其长度取决于运行时值的函数时,它在编译时通常需要“展开”循环。这意味着在编译图中为每次迭代复制循环体的操作。对于长序列,这种展开会导致非常大的计算图,显著增加编译时间并可能超出内存限制。
lax.scan 提供了一个函数式替代方案,专门为这些情况设计。它允许你表达序列计算,JAX 可以将其编译成在 GPU 和 TPU 等加速器上的高效循环原语,避免了在 jit 编译函数中使用显式 Python 循环的弊端。
其核心是,lax.scan 通过重复应用你定义的函数来工作。这个函数在每一步接收两个参数:
carry 状态:这保存着从上一步传递到当前步的信息。这是在整个序列处理过程中保持状态的方式。xs 的当前切片 x:这是当前步骤的特定输入(可选)。该函数必须返回一个包含以下内容的元组:
carry 状态:这将传递给下一步。y:这会在所有步骤中收集,形成最终的输出序列。lax.scan 的整体签名是:
jax.lax.scan(f, init, xs, length=None, reverse=False, unroll=1)
让我们分解一下主要参数:
f:要重复应用的函数。它必须具有签名 f(carry, x) -> (new_carry, y)。如果 xs 为 None,则签名为 f(carry, None) -> (new_carry, y) 或者甚至 f(carry) -> (new_carry, y)(隐式接受 None)。init:第一步之前的 carry 状态的初始值。其结构必须与 f 输出的 carry 部分匹配。xs:一个可选的 PyTree(例如,数组、数组的元组/列表/字典),表示输入序列。lax.scan 遍历 xs 中数组的主轴。传递给 f 的每个切片 x 将具有与 xs 匹配的结构。如果 xs 为 None,scan 将迭代 length 次,不带每步输入。length:可选整数,指定步数。通常从 xs 的主轴维度推断。如果 xs 为 None,则必需。reverse:可选布尔值。如果为 True,则从 xs 的末尾扫描到开头。默认为 False。unroll:可选整数。用于性能调优,向编译器建议要展开多少次循环迭代。默认为 1。更大的值有时可以通过减少循环开销来提高某些硬件的性能,但会增加编译时间和代码大小。请谨慎使用并进行性能分析。lax.scan 返回一个元组 (final_carry, ys),final_carry 是最后一步后的 carry 状态,而 ys 是一个 PyTree,包含每一步堆叠的输出 y。ys 的结构与 f 输出的 y 部分匹配,但增加了一个与序列长度对应的主维度。
让我们看看如何实现累加和,其中每个元素是输入数组中所有前面元素(包括其自身)的总和。
import jax
import jax.numpy as jnp
import jax.lax as lax
# 定义扫描函数:f(carry, x) -> (new_carry, y)
# carry: 到前一个元素的总和
# x: 当前元素
# new_carry: 包含当前元素的总和 (carry + x)
# y: 此步骤的输出 (也是 carry + x)
def cumulative_sum_step(carry_sum, current_x):
new_sum = carry_sum + current_x
return new_sum, new_sum # 返回新的 carry 和此步骤的输出
# 输入数组
input_array = jnp.array([1, 2, 3, 4, 5])
# 初始 carry 为 0(第一个元素之前的和)
initial_carry = 0
# 应用 lax.scan
final_carry, result_sequence = lax.scan(cumulative_sum_step, initial_carry, input_array)
print("输入数组:", input_array)
print("初始 Carry:", initial_carry)
print("最终 Carry(总和):", final_carry)
print("结果序列(累加和):", result_sequence)
# 预期输出:
# 输入数组: [1 2 3 4 5]
# 初始 Carry: 0
# 最终 Carry(总和): 15
# 结果序列(累加和): [ 1 3 6 10 15]
在这个示例中:
initial_carry 从 0 开始。f(0, 1) 返回 (1, 1)。carry 变为 1,ys 开始收集 [1]。f(1, 2) 返回 (3, 3)。carry 变为 3,ys 是 [1, 3]。f(3, 3) 返回 (6, 6)。carry 变为 6,ys 是 [1, 3, 6]。f(6, 4) 返回 (10, 10)。carry 变为 10,ys 是 [1, 3, 6, 10]。f(10, 5) 返回 (15, 15)。carry 变为 15,ys 是 [1, 3, 6, 10, 15]。lax.scan 返回 final_carry (15) 和收集到的 ys 序列 ([1, 3, 6, 10, 15])。一个更实际的应用是实现一个 RNN。一个简单的 RNN 根据前一个隐藏状态 ht−1 和当前输入 xt 来更新其隐藏状态 ht:
ht=tanh(Whhht−1+Wxhxt+bh)我们可以使用 lax.scan 对输入序列迭代执行此更新步骤。carry 将是隐藏状态 h,而 xs 将是输入序列 x。
import jax
import jax.numpy as jnp
import jax.lax as lax
import jax.random as random
key = random.PRNGKey(0)
# 定义维度
input_features = 3
hidden_features = 5
sequence_length = 4
# 初始化参数(作为 carry 的一部分或闭包)
key, w_key, b_key = random.split(key, 3)
W_hh = random.normal(w_key, (hidden_features, hidden_features)) * 0.1
W_xh = random.normal(w_key, (hidden_features, input_features)) * 0.1
b_h = random.normal(b_key, (hidden_features,)) * 0.1
# 生成一个虚拟输入序列 (sequence_length, input_features)
x_key = random.split(key)
input_sequence = random.normal(x_key, (sequence_length, input_features))
# 初始隐藏状态
h_initial = jnp.zeros((hidden_features,))
# 定义 RNN 步骤函数:f(h_prev, x_t) -> (h_new, h_new)
# 我们也将 h_new 作为每步结果 'y' 输出
def rnn_step(h_prev, x_t):
h_new = jnp.tanh(jnp.dot(W_hh, h_prev) + jnp.dot(W_xh, x_t) + b_h)
# 返回新状态,作为下一个 carry 和此步骤的输出
return h_new, h_new
# 应用 lax.scan
final_hidden_state, hidden_states_sequence = lax.scan(rnn_step, h_initial, input_sequence)
print("输入序列形状:", input_sequence.shape)
print("初始隐藏状态形状:", h_initial.shape)
print("最终隐藏状态形状:", final_hidden_state.shape)
print("隐藏状态序列形状:", hidden_states_sequence.shape)
# 预期输出形状:
# 输入序列形状: (4, 3)
# 初始隐藏状态形状: (5,)
# 最终隐藏状态形状: (5,)
# 隐藏状态序列形状: (4, 5)
这里,rnn_step 封装了核心的循环关系。lax.scan 从 h_initial 开始,高效地将此函数应用于 input_sequence。carry 完美地表示了演变的隐藏状态 ht,而收集到的 ys(这里是 hidden_states_sequence)则提供了每个时间步的隐藏状态。
请注意,参数 W_hh、W_xh 和 b_h 是在 rnn_step 外部定义的。JAX 在跟踪和编译时会自动处理这些闭包。
lax.scan 为何表现出色:性能与编译lax.scan 相较于 @jit 内部的 Python for 循环的主要优势在于性能,尤其是在加速器上。
lax.scan 翻译为专门的 XLA HLO(高级优化器)循环操作(例如 While)。XLA 随后可以非常高效地为目标硬件(GPU/TPU)编译这种循环表示,而无需为每次迭代复制计算图。这使得编译时间可控,并减小了编译后的代码大小。lax.scan 的编译表示通常紧凑得多。此外,在执行过程中,XLA 可以优化循环内的内存使用,可能比完全展开的图更有效地重用缓冲区。scan 表示的循环体和控制流应用复杂的分析和转换,从而加快执行速度。lax.scan 与其他转换lax.scan 旨在与其他 JAX 转换结合使用:
jit: 如前所述,lax.scan 非常适合在 @jit 内部使用。它提供了 XLA 高效编译所需的结构。grad: JAX 可以自动对 lax.scan 进行求导。它计算相对于 init 状态、由 f 闭包捕获的参数以及输入序列 xs 的梯度。这对于训练 RNN 和其他序列模型非常重要。我们将在第 4 章中更详细地了解控制流的求导。vmap: 你可以使用 vmap 并行运行多个独立的扫描,例如,同时处理一批序列。如果你有一个形状为 (batch_size, sequence_length, input_features) 的 batch_input_sequences,你可以用 vmap 封装 lax.scan:jax.vmap(lambda seq: lax.scan(rnn_step, h_initial, seq))(batch_input_sequences)。(注意:正确处理批处理初始状态可能还需要对 h_initial 进行 vmap 操作)。xs=None)有时,你可能希望仅基于 carry 状态来迭代函数,而不必在每一步消耗外部输入序列。这对于生成序列或运行固定步数的迭代过程很有用。你可以通过传递 xs=None 并指定 length 参数来实现这一点。
import jax
import jax.numpy as jnp
import jax.lax as lax
# 生成前 10 个 2 的幂:1, 2, 4, ...
# carry: 前一个 2 的幂
# x: 为 None,被忽略
# new_carry: 下一个 2 的幂 (carry * 2)
# y: 当前 2 的幂 (carry)
def generate_powers_of_2(carry, _): # 使用 _ 表示 x 未被使用
next_val = carry * 2
return next_val, carry # 返回下一个 carry,输出当前值
initial_value = 1
num_steps = 10
final_val, powers_of_2 = lax.scan(generate_powers_of_2,
initial_value,
xs=None, # 无输入序列
length=num_steps)
print("最终值 (2^10):", final_val)
print("生成的 2 的幂:", powers_of_2)
# 预期输出:
# 最终值 (2^10): 1024
# 生成的 2 的幂: [ 1 2 4 8 16 32 64 128 256 512]
总而言之,lax.scan 是你高级 JAX 工具包中不可或缺的工具。它提供了表达复杂序列和循环计算的方式,以一种在用 jit 编译时既函数式优雅又高性能的方式,为在加速硬件上实现许多复杂模型和算法提供了方法。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造