简单RNN在处理长距离依赖方面存在困难,部分原因在于它们缺乏明确控制内存的机制。多步之前的信息可能消失或使网络过载。长短期记忆(LSTM)网络引入了特殊组成部分,称为门,用于管理网络内存中存储的信息,其称为细胞状态 (C)。
我们首先查看的门是遗忘门。它的作用直接但重要:它决定细胞状态中的哪些信息应该被丢弃或保留。将细胞状态视为LSTM的长期记忆。当新输入到来时,遗忘门会查看先前的状态和新输入,以判断现有长期记忆的哪些部分不再有意义。
遗忘门如何运作
遗忘门通过将先前的隐藏状态 (ht−1) 和当前输入 (xt) 传入一个S型激活函数 (activation function) (σ) 来运作。S型函数在此处很合适,因为它输出0到1之间的值。
输入 (ht−1 和 xt) 由遗忘门的S型层处理,以生成遗忘向量 (vector) ft。
数学上,计算方法是:
ft=σ(Wf⋅[ht−1,xt]+bf)
我们来分析一下:
- [ht−1,xt]:这表示将先前的隐藏状态向量 ht−1 和当前输入向量 xt 拼接起来。如果 ht−1 的维度是 Nh,而 xt 的维度是 Nx,那么拼接后的向量维度是 Nh+Nx。
- Wf:这是与遗忘门关联的权重 (weight)矩阵。它是训练过程中调整的可学习参数 (parameter)。其维度通常是 (Nh+Nx)×Nh,其中 Nh 是细胞状态(和隐藏状态)的维度。
- bf:这是遗忘门的偏置 (bias)向量,也是一个可学习参数,维度为 Nh。
- σ:S型函数逐元素应用于矩阵乘法和加法的结果。
解释输出 (ft)
输出 ft 是一个向量 (vector),其维度与细胞状态 Ct−1 相同。 ft 中的每个元素都是一个介于0和1之间的数字,对应于细胞状态 Ct−1 中的一个元素。
- ft 中接近 0 的值表示“遗忘”或“丢弃” Ct−1 中的对应信息。
- ft 中接近 1 的值表示“保留”或“记住” Ct−1 中的对应信息。
- 0到1之间的值表示部分保留。
这个向量 ft 就像一个过滤器。它将与先前的细胞状态 Ct−1 进行逐元素相乘,以决定多少旧记忆应该传递到下一步。我们将在本章后面讨论更新细胞状态时看到这种乘法运算。
根据当前输入和过去上下文 (context)(通过 ht−1)选择性遗忘不相关信息的能力是LSTM相较于简单RNN能够长时间保留有用信息的一个主要原因。这避免了细胞状态被过时或不必要的细节所累积。