趋近智
实现一个自定义的循环单元是序列建模中经常遇到的实际场景,这使得它成为像 lax.scan 这样的控制流原语的一个良好应用。尽管基本的 RNN 简单明了,但许多高级架构使用更复杂的门控机制。可以使用 lax.scan 来实现一个简化的门控循环单元(GRU)以处理序列并更新隐藏状态。这说明了如何在 scan 主体中构建非简单的计算。
GRU 是一种循环神经网络单元,它旨在通过门控机制捕获不同时间尺度上的依赖关系。这些门控制着信息的流动,决定保留过去状态的多少,以及从新输入中引入多少。
在这个例子中,我们将实现一个稍加简化的版本。设 xt 为时间步 t 的输入向量,ht−1 为前一时间步的隐藏状态。时间步 t 的隐藏状态 ht 的计算方式如下:
这里,σ 表示 sigmoid 激活函数,tanh 是双曲正切激活函数,⊙ 表示按元素乘法。Wz,Uz,bz,Wr,Ur,br,Wh,Uh,bh 是 GRU 单元的可学习参数(权重矩阵和偏置向量)。
我们可以使用 lax.scan 来实现 GRU 单元对整个序列的处理。其主要思想是:
lax.scan 中的 carry 将在每个时间步保存隐藏状态 ht。xs 参数将是输入序列 (x1,x2,...,xT)。lax.scan 的函数(我们称之为 gru_step)将实现上述四个方程,它接受前一个隐藏状态 h_prev(来自 carry)和当前输入 x_t(来自 xs)来计算新的隐藏状态 h_t。gru_step 将返回 (h_t, h_t)。第一个 h_t 成为下一步的 carry,第二个 h_t 作为输出序列被累积。我们来编写代码。首先是导入库和定义参数。在实际应用中,这些参数将是更大模型结构(如 Flax 或 Haiku)的一部分,但为了清晰起见,我们在这里直接定义它们。
import jax
import jax.numpy as jnp
import jax.lax as lax
from jax import random
# 定义激活函数
sigmoid = jax.nn.sigmoid
tanh = jnp.tanh
def initialize_gru_params(key, input_dim, hidden_dim):
"""初始化简化 GRU 单元的参数。"""
keys = random.split(key, 6) # 需要用于 Wz, Uz, bz, Wr, Ur, br, Wh, Uh, bh 的键(3 对 W,U + 3 个偏置)
# 更新门参数
Wz = random.normal(keys[0], (hidden_dim, input_dim)) * 0.01
Uz = random.normal(keys[1], (hidden_dim, hidden_dim)) * 0.01
bz = jnp.zeros((hidden_dim,))
# 重置门参数
Wr = random.normal(keys[2], (hidden_dim, input_dim)) * 0.01
Ur = random.normal(keys[3], (hidden_dim, hidden_dim)) * 0.01
br = jnp.zeros((hidden_dim,))
# 候选隐藏状态参数
Wh = random.normal(keys[4], (hidden_dim, input_dim)) * 0.01
Uh = random.normal(keys[5], (hidden_dim, hidden_dim)) * 0.01
bh = jnp.zeros((hidden_dim,))
params = {
'Wz': Wz, 'Uz': Uz, 'bz': bz,
'Wr': Wr, 'Ur': Ur, 'br': br,
'Wh': Wh, 'Uh': Uh, 'bh': bh
}
return params
def gru_step(params, h_prev, x_t):
"""执行简化 GRU 计算的一个步骤。"""
# 更新门
z_t = sigmoid(jnp.dot(params['Wz'], x_t) + jnp.dot(params['Uz'], h_prev) + params['bz'])
# 重置门
r_t = sigmoid(jnp.dot(params['Wr'], x_t) + jnp.dot(params['Ur'], h_prev) + params['br'])
# 候选隐藏状态
h_tilde_t = tanh(jnp.dot(params['Wh'], x_t) + jnp.dot(params['Uh'], (r_t * h_prev)) + params['bh'])
# 最终隐藏状态
h_t = (1.0 - z_t) * h_prev + z_t * h_tilde_t
# 返回新的隐藏状态作为 carry 和输出
return h_t, h_t
def gru_sequence(params, initial_h, inputs):
"""使用 lax.scan 将 GRU 单元应用于一系列输入。"""
# 定义扫描函数,闭包捕获参数
scan_fn = lambda carry, x: gru_step(params, carry, x)
# 应用 lax.scan
final_h, outputs_h = lax.scan(scan_fn, initial_h, inputs)
# final_h 包含最后一个隐藏状态
# outputs_h 包含隐藏状态序列 [h_1, h_2, ..., h_T]
return final_h, outputs_h
# 使用示例
key = random.PRNGKey(0)
seq_len = 10
input_dim = 5
hidden_dim = 8
# 初始化参数
gru_params = initialize_gru_params(key, input_dim, hidden_dim)
# 创建虚拟输入序列(序列长度,输入特征)
key, subkey = random.split(key)
input_sequence = random.normal(subkey, (seq_len, input_dim))
# 初始化隐藏状态
initial_hidden_state = jnp.zeros((hidden_dim,))
# 在序列上运行 GRU
final_state, hidden_states_sequence = gru_sequence(gru_params, initial_hidden_state, input_sequence)
print("输入序列形状:", input_sequence.shape)
print("初始隐藏状态形状:", initial_hidden_state.shape)
print("最终隐藏状态形状:", final_state.shape)
print("隐藏状态输出序列形状:", hidden_states_sequence.shape)
在这段代码中:
initialize_gru_params 使用小的随机值初始化,设置了必要的权重矩阵和偏置向量,使其具有适当的形状。gru_step 实现了一个时间步的核心逻辑。它接受参数、前一个隐藏状态 h_prev 和当前输入 x_t,并两次返回新的隐藏状态 h_t(一次作为新的 carry,一次作为该步骤的输出)。gru_sequence 协调整个过程。它定义了 scan_fn,该函数实际上是 gru_step,其中 params 参数是固定的(通过闭包捕获)。然后它使用此函数、初始隐藏状态和输入序列调用 lax.scan。gru_sequence 函数。输出形状证实最终状态具有隐藏维度,输出序列具有 (序列长度, 隐藏维度) 的维度。使用 lax.scan 的一个好处是,由此产生的 gru_sequence 函数完全兼容 JAX 的其他变换,如 jit、grad 和 vmap。
例如,为了编译 GRU 计算以实现更快的执行,只需使用 jax.jit 包装调用:
# 编译 GRU 函数以提高效率
jit_gru_sequence = jax.jit(gru_sequence)
# 运行编译版本(首次运行包含编译时间)
key, subkey = random.split(key)
input_sequence_2 = random.normal(subkey, (seq_len, input_dim))
final_state_jit, hidden_states_sequence_jit = jit_gru_sequence(gru_params, initial_hidden_state, input_sequence_2)
print("\n运行 JIT 编译版本:")
print("最终隐藏状态形状 (JIT):", final_state_jit.shape)
print("输出序列形状 (JIT):", hidden_states_sequence_jit.shape)
如果你想同时处理一批序列,可以使用 jax.vmap。假设你的输入具有批次维度,例如 (批次大小, 序列长度, 输入维度),你将对初始隐藏状态(批次大小, 隐藏维度)和输入都映射到批次维度:
# VMAP 使用示例(需要批处理输入/状态)
# 假设:
# batch_size = 32
# batched_inputs = random.normal(key, (batch_size, seq_len, input_dim))
# batched_initial_h = jnp.zeros((batch_size, hidden_dim,))
# 映射到批次维度(params=None, initial_h, inputs 的轴 0)
# 注意:参数在批次中共享,因此我们在 in_axes 中使用 None
# batched_gru = jax.vmap(gru_sequence, in_axes=(None, 0, 0))
# final_states_batch, hidden_sequences_batch = batched_gru(gru_params, batched_initial_h, batched_inputs)
# print("批处理最终状态形状:", final_states_batch.shape) # (批次大小, 隐藏维度)
# print("批处理输出序列形状:", hidden_sequences_batch.shape) # (批次大小, 序列长度, 隐藏维度)
同样,你可以使用 jax.grad 计算相对于参数(gru_params)或输入(input_sequence)的梯度,从而使得 GRU 单元可以在更大的模型中进行训练。
这个例子说明了 lax.scan 如何提供一个有效方法,以一种与 JAX 的编译和自动微分能力良好集成的方式,实现复杂的、有状态的序列计算。通过定义单个步骤的逻辑,并让 lax.scan 处理迭代,你可以高效地构建复杂的循环模型。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造