趋近智
长短期记忆(LSTM)网络以有效解决梯度消失问题并捕获长距离依赖关系而闻名。然而,这些网络通常由于其架构(包含三个独立的门:输入、遗忘、输出以及一个独立的细胞状态)而引入了相当大的计算复杂性。2014年,Cho等人提出了一种名为门控循环单元(GRU)的变体,它在许多任务上实现了相似的表现,但结构更简单。GRU将遗忘门和输入门合并成一个“更新门”,并合并了细胞状态和隐藏状态。
让我们看看GRU单元的结构。与其他RNN一样,它接收当前输入 和前一个隐藏状态 来生成下一个隐藏状态 。其作用通过两个门来实现:重置门 () 和更新门 ()。
重置门决定了如何将新输入与前一个隐藏状态结合。具体来说,它控制在计算候选隐藏状态时,前一个隐藏状态 () 应该“遗忘”多少。其计算方式如下:
这里,、 和 是重置门的可学习权重 (weight)矩阵和偏置 (bias)向量 (vector)。Sigmoid函数 将输出压缩在0到1之间。值接近0表示前一个隐藏状态大部分被忽略,而值接近1表示它大部分被保留。
更新门决定了前一个隐藏状态 () 有多少应该传递到新的隐藏状态 (),以及有多少新的候选隐藏状态应该被使用。这个门基本结合了LSTM的遗忘门和输入门的作用。它的计算方式与重置门类似:
同样,、 和 是可学习参数 (parameter), 是Sigmoid函数。 的值接近1表示前一个状态 大部分被保留,而值接近0表示新的候选状态被主要使用。
在计算最终隐藏状态之前,GRU会计算一个候选隐藏状态 ()。这个计算受到重置门的影响,重置门决定了前一个隐藏状态 的贡献程度:
这里, 表示按元素相乘(哈达玛积)。如果重置门 的值接近0,那么 的贡献将被有效清除,使候选状态主要基于当前输入 。、 和 是另一组可学习的权重 (weight)和偏置 (bias)。 函数有助于调节网络中的值,通常将其压缩在-1到1之间。
最后,更新门 在前一个隐藏状态 和候选隐藏状态 之间进行调节,以生成当前时间步的最终隐藏状态 :
这个方程的作用类似于加权平均。如果 接近1,候选状态 贡献更多,从而有效地用新信息更新隐藏状态。如果 接近0,前一个隐藏状态 被保留更多,允许信息在多个时间步中不变地传递。这个机制是GRU如何维护长距离依赖关系的方式。
GRU单元内信息流的简化视图。 是输入, 是前一个隐藏状态。重置门 () 影响候选状态 (),更新门 () 将候选状态与前一个状态结合,生成最终隐藏状态 。
GRU通常被视为LSTM的更精简替代方案。
在实践中,LSTM和GRU之间的选择通常取决于具体的数据集和任务。两者在所有场景下都没有哪个能持续胜过另一个,尽管GRU因其相对简单和相近的表现而获得欢迎。
以下是一个PyTorch代码片段,呈现了单个GRU步骤的核心计算(假设输入为x_t、h_tm1以及预定义的权重 (weight)/偏置 (bias)张量):
import torch
import torch.nn.functional as F
# 示例张量 (batch_size, input_size/hidden_size)
# 替换为实际维度和已初始化的权重/偏置
batch_size = 1
input_size = 10
hidden_size = 20
x_t = torch.randn(batch_size, input_size)
h_tm1 = torch.randn(batch_size, hidden_size) # h_{t-1}
# --- 假设权重矩阵 (W_*, U_*) 和偏置 (b_*) 已定义 ---
# 示例初始化(替换为实际学习到的参数)
W_r = torch.randn(input_size, hidden_size)
U_r = torch.randn(hidden_size, hidden_size)
b_r = torch.randn(hidden_size)
W_z = torch.randn(input_size, hidden_size)
U_z = torch.randn(hidden_size, hidden_size)
b_z = torch.randn(hidden_size)
W_h = torch.randn(input_size, hidden_size)
U_h = torch.randn(hidden_size, hidden_size)
b_h = torch.randn(hidden_size)
# ---------------------------------------------------------------------
# 重置门计算
r_t = torch.sigmoid(x_t @ W_r + h_tm1 @ U_r + b_r)
# 更新门计算
z_t = torch.sigmoid(x_t @ W_z + h_tm1 @ U_z + b_z)
# 候选隐藏状态计算
h_tilde_t = torch.tanh(x_t @ W_h + (r_t * h_tm1) @ U_h + b_h)
# 最终隐藏状态计算
h_t = (1 - z_t) * h_tm1 + z_t * h_tilde_t
print("前一个隐藏状态的形状:", h_tm1.shape)
print("当前隐藏状态的形状:", h_t.shape)
这种简化结构虽然有效,但仍然依赖于顺序处理。时间步 的计算依赖于时间步 的结果。这种固有的顺序依赖性限制了训练期间的并行化,并且在处理非常长的序列时仍然是一个瓶颈,这为Transformer中使用的非循环注意力机制 (attention mechanism)奠定了基础。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造