我们将链式法则与训练神经网络 (neural network)的机制直接关联起来。正如我们之前讨论的,神经网络本质上是一个大型复合函数。最终输出(例如,分类分数或回归值)依赖于前一层的输出,而前一层的输出又依赖于再前一层的输出,依此类推,一直追溯到输入数据。重要的一点是,每个层计算的函数都包含权重 (weight)和偏置 (bias),这些是我们需要的调整参数 (parameter)。
训练过程包括最小化一个损失函数 (loss function),我们称之为 L,它衡量网络预测值与实际目标值之间的差异。为了使用梯度下降 (gradient descent)(或其变体)最小化 L,我们需要计算 L 对网络中每个权重 W 和偏置 b 的偏导数:∂W∂L 和 ∂b∂L。考虑到这种嵌套结构,直接计算这些似乎很困难。此时,链式法则就变得非常有用。
反向传播 (backpropagation)不是一种新的优化算法;它是一种在神经网络中计算梯度的高效算法。它通过从损失函数反向工作,系统地应用多元链式法则来计算所有必需的偏导数(对于所有层 l、神经元 i 和输入连接 j 的 ∂Wij[l]∂L 和 ∂bi[l]∂L)。
反向传播 (backpropagation)过程详解
设想前向传播过程:输入数据流经网络,层层递进,经历线性变换(乘以权重 (weight),加上偏置 (bias))和非线性激活,最终生成输出预测,然后是损失值 L。
反向传播则逆转这个流程:
-
从末端开始: 该过程始于计算损失 L 对网络最终输出激活值 a[L](其中 L 表示最后一层)的导数。这通常很简单,取决于所使用的具体损失函数 (loss function)和最终激活函数 (activation function)。我们将 ∂a[L]∂L 记为 δa[L]。
-
关于预激活值 (z[L]) 的梯度: 利用链式法则,我们找到损失对最后一层预激活线性输出 z[L] 的梯度。如果 a[L]=g(z[L]),其中 g 是激活函数,那么:
∂z[L]∂L=∂a[L]∂L⋅∂z[L]∂a[L]=δa[L]⋅g′(z[L])
我们将 ∂z[L]∂L 记为 δz[L]。项 g′(z[L]) 是激活函数在 z[L] 处求得的导数。
-
关于参数 (parameter) (W[L],b[L]) 的梯度: 现在我们有了 δz[L],我们可以找到最后一层的权重 W[L] 和偏置 b[L] 的梯度。因为 z[L]=W[L]a[L−1]+b[L]:
∂W[L]∂L=∂z[L]∂L⋅∂W[L]∂z[L]=δz[L]⋅(a[L−1])T
∂b[L]∂L=∂z[L]∂L⋅∂b[L]∂z[L]=δz[L]⋅1=δz[L]
(注意:∂W[L]∂L 的计算涉及矩阵/向量 (vector)运算,结果是一个与 W[L] 同形状的梯度矩阵。对于 ∂b[L]∂L,我们通常会在批量维度上对 δz[L] 求和。)
-
将梯度传播到前一层: 真正巧妙的部分是将误差梯度反向传播。我们需要找到 ∂a[L−1]∂L,即损失对前一层的激活值的梯度。再次通过 z[L] 使用链式法则:
∂a[L−1]∂L=∂z[L]∂L⋅∂a[L−1]∂z[L]=δz[L]⋅(W[L])T
我们将此记为 δa[L−1]。这表明了 L−1 层中的激活值对最终损失 L 的影响程度。
-
对 L−1 层重复: 现在我们有了 δa[L−1]。我们可以对 L−1 层重复步骤 2、3 和 4:
- 计算 δz[L−1]=δa[L−1]⋅g′(z[L−1])。
- 计算 ∂W[L−1]∂L=δz[L−1]⋅(a[L−2])T。
- 计算 ∂b[L−1]∂L=δz[L−1]。
- 计算 δa[L−2]=δz[L−1]⋅(W[L−1])T。
-
持续到输入层: 这个过程会重复,逐层反向进行,直到我们到达输入层。在每一步 l,我们使用传入的梯度 δa[l](从 l+1 层计算得到)来计算 δz[l],然后计算 W[l] 和 b[l] 的梯度,最后计算梯度 δa[l−1] 以传回给前一层。
这种系统性的链式法则反向应用,保证我们能够高效地计算损失 L 对网络中每个参数的梯度,并尽可能地重用计算。所有层的梯度 ∂W[l]∂L 和 ∂b[l]∂L 随后被用于梯度下降 (gradient descent)等优化算法来更新参数:
W[l]:=W[l]−α∂W[l]∂L
b[l]:=b[l]−α∂b[l]∂L
其中 α 是学习率。
这是一个两层网络中前向传播(计算损失)和反向传播(计算梯度)的简化视图。反向传播从损失开始,逐层应用链式法则,以计算损失对每个参数 (W[l],b[l]) 和激活值 (a[l]) 的梯度。红色节点和箭头表示梯度的流动和计算。
理解这种反向流动以及链式法则在每一步中的应用,是掌握神经网络 (neural network)如何学习的基础。下一节将讨论计算图如何帮助更清晰地展示这个过程。