训练神经网络需要调整其权重以最小化损失函数 L。梯度下降要求计算该损失函数相对于网络中每个权重和偏置的梯度,通常表示为 ∇L。对于一个可能拥有数百万个参数并分布在多层中的网络,计算网络内部较深层权重对最终损失的影响是一项挑战。损失函数并非该特定权重的直接函数;它是网络输出的函数,而网络输出又依赖于中间激活值,中间激活值又进一步依赖于更早的激活值和权重,形成了一条长长的依赖链。
微积分中的一个基本原理在此发挥作用:链式法则。链式法则提供了一种计算复合函数(即相互嵌套的函数)导数的方法。这正是我们在神经网络中遇到的情况。
链式法则:复合函数的导数
让我们回顾一下基本思路。假设我们有一个简单的函数组合。如果变量 y 依赖于变量 u,写作 y=f(u),而 u 又依赖于另一个变量 x,写作 u=g(x),那么 y 通过 u 间接依赖于 x:y=f(g(x))。
链式法则告诉我们如何求 y 相对于 x 的变化率,表示为 dy/dx。它表明这个导数是“外层”函数相对于其输入的导数与“内层”函数相对于其输入的导数的乘积:
dxdy=dudy×dxdu
可以将其视为敏感度的传递。如果 x 发生微小变化, y 会变化多少?这取决于当 x 变化时 u 变化了多少 (du/dx),再乘以当 u 变化时 y 变化了多少 (dy/du)。
我们可以将此推广到更长的链条。如果 z=f(y),y=g(x),且 x=h(w),那么 z 通过这个链条依赖于 w。z 相对于 w 的导数可以通过将路径上的导数相乘得到:
dwdz=dydz×dxdy×dwdx
将链式法则应用于神经网络
现在,让我们将其与神经网络联系起来。考虑一个非常简单的网络,它有一个输入 x,一个隐藏神经元 h,和一个输出神经元 y。计算步骤如下:
- 隐藏神经元的预激活值:z1=w1x+b1
- 隐藏神经元的激活值:h=σ(z1)(其中 σ 是激活函数)
- 输出神经元的预激活值:z2=w2h+b2
- 输出神经元的激活值(预测值):y=σ(z2)
- 损失计算:L=Loss(y,ytrue)(例如,平方误差 L=(y−ytrue)2)
简单神经网络中的前向传播计算。箭头表示依赖关系。计算损失 L 相对于早期权重(如 w1)的梯度,需要使用链式法则,通过这些依赖关系向后传递导数。
假设我们要找到损失 L 相对于权重 w1 的梯度。L 依赖于 y, y 依赖于 z2, z2 依赖于 h, h 依赖于 z1,而 z1 依赖于 w1。使用链式法则,我们可以将导数 ∂L/∂w1 写为:
∂w1∂L=∂y∂L×∂z2∂y×∂h∂z2×∂z1∂h×∂w1∂z1
让我们逐项分析:
- ∂L/∂y:损失函数如何随最终预测值 y 变化。这取决于所使用的具体损失函数。对于 L=(y−ytrue)2,其值为 2(y−ytrue)。
- ∂y/∂z2:输出激活值如何随其预激活值 z2 变化。这是输出激活函数的导数,即 σ′(z2)。
- ∂z2/∂h:输出预激活值 z2 如何随隐藏激活值 h 变化。由于 z2=w2h+b2,此导数即为 w2。
- ∂h/∂z1:隐藏激活值如何随其预激活值 z1 变化。这是隐藏层激活函数的导数,即 σ′(z1)。
- ∂z1/∂w1:隐藏预激活值 z1 如何随权重 w1 变化。由于 z1=w1x+b1,此导数即为 x。
将这个特定例子中的所有项组合起来:
∂w1∂L=∂L/∂y2(y−ytrue)×∂y/∂z2σ′(z2)×∂z2/∂hw2×∂h/∂z1σ′(z1)×∂z1/∂w1x
注意,此计算涉及到在前向传播过程中计算出的项(如 x,h,y,z1,z2),以及在预激活值处评估的激活函数导数。链式法则提供了一种系统方法,可以将这些局部导数相乘,以求得损失函数对特定权重的整体敏感度,无论该权重在网络中处于何层。
在每层有多个神经元的多层网络中,依赖关系变得更加复杂(一个神经元的输出会影响下一层中的多个神经元),涉及导数的求和。然而,核心原理不变:链式法则允许我们通过逐层向后传播导数信息来计算梯度。链式法则的这种系统应用正是反向传播算法所实现的,我们将在接下来的章节中详细说明。理解链式法则,是理解神经网络如何学习的重要一步。