简单的循环神经网络在学习长序列中的依赖关系时面临很大困难,主要原因是梯度消失和梯度爆炸。网络从较早时间步传递相关信息的能力受到影响。长短期记忆(LSTM)网络正是为了解决这一局限而开发,通过引入更复杂的单元结构,使其能够长时间保持记忆。
LSTM网络的主要组成部分是LSTM单元。它用一个由门控和专用单元状态组成的复杂系统,替代了标准RNN单元中的简单变换。这种结构使得网络能够随时间选择性地添加、删除或保留信息。
LSTM单元有两个主要组成部分,实现了这种受控的信息流动:
- 单元状态 (Ct): 这是LSTM的一个突出特性。你可以将其视为一条内部记忆轨道,水平贯穿一系列单元。信息可以在这条状态线上流动,只发生轻微的线性交互。这种结构使得信息(以及训练时的梯度)更容易在多个时间步中持续存在,而不会明显衰减。
- 门控: 它们是神经网络层(通常使用Sigmoid σ激活函数),用于调节信息进出单元状态的流动。因为Sigmoid函数的输出值介于0和1之间,这些门控就像过滤器一样发挥作用:接近0的值表示“只让很少信息通过”,而接近1的值表示“让大部分信息通过”。一个LSTM单元包含三个主要门控:
- 遗忘门: 决定从上一时间步的单元状态 (Ct−1) 中丢弃哪些信息。
- 输入门: 决定将当前输入 (xt) 和上一时间步的隐藏状态 (ht−1) 中的哪些新信息存储到单元状态中。
- 输出门: 控制当前单元状态 (Ct) 的哪些部分应作为当前时间步的隐藏状态 (ht) 传递出去。
我们来描绘这些组成部分如何在一个时间步 t 的单个LSTM单元中相互作用:
数据流和组成部分在一个时间步 t 的单个LSTM单元中。它接收当前输入 xt、上一隐藏状态 ht−1 和上一单元状态 Ct−1。它计算新的单元状态 Ct 和新的隐藏状态 ht。圆圈表示运算(Sigmoid σ、双曲正切 tanh、元素级乘法 ×、元素级加法 +)。
LSTM单元内的信息处理过程如下:
-
遗忘门 (ft): 单元首先决定从上一单元状态 Ct−1 中丢弃哪些信息。它查看上一隐藏状态 ht−1 和当前输入 xt。这些数据通过一个Sigmoid函数 σ。输出 ft 包含 Ct−1 中每个数值介于0到1之间的值。1表示“完全保留”,而0表示“完全丢弃”。
ft=σ(Wf[ht−1,xt]+bf)
-
输入门 (it) 和候选值 (C~t): 接下来,单元决定要将哪些新信息存储到单元状态中。这包含两个步骤:
- 一个Sigmoid层,称为“输入门层”,接收 ht−1 和 xt,并决定更新哪些值 (it)。
- 一个 tanh 层接收 ht−1 和 xt,并生成一个新候选值向量 C~t,这些值有可能被添加到状态中。
it=σ(Wi[ht−1,xt]+bi)
C~t=tanh(WC[ht−1,xt]+bC)
-
单元状态更新 (Ct): 现在,旧的单元状态 Ct−1 被更新为新的单元状态 Ct。上一状态 Ct−1 与遗忘向量 ft 进行元素级乘法 (⊙),从而遗忘选定的部分。接着,将元素级乘法 it⊙C~t 的结果(新信息,根据我们决定更新的程度进行缩放)相加。
Ct=ft⊙Ct−1+it⊙C~t
-
输出门 (ot) 和隐藏状态 (ht): 最后,单元决定输出,即隐藏状态 ht。这个输出是单元状态的过滤版本。
- 首先,一个Sigmoid层接收 ht−1 和 xt,以决定单元状态的哪些部分应该作为输出 (ot)。
- 然后,新计算出的单元状态 Ct 通过 tanh(将值压缩到-1和1之间)。
- 这个 tanh(Ct) 与输出门的激活值 ot 进行元素级乘法 (⊙)。这会产生隐藏状态 ht。
ot=σ(Wo[ht−1,xt]+bo)
ht=ot⊙tanh(Ct)
在上述方程中,Wf,Wi,WC,Wo 代表权重矩阵,bf,bi,bC,bo 是偏置向量,它们在训练过程中被学习。符号 [ht−1,xt] 通常表示这两个向量的连接。
这种门控结构,特别是独立的单元状态 Ct,它仅通过门控进行轻微的加法和乘法运算,这是LSTM比简单RNN更擅长捕获长距离依赖关系的原因。信息可以在许多时间步中保持,门控学习控制哪些信息是相关的,通过提供更直接的梯度传播路径,有效地缓解了梯度消失问题。