正如我们所见,神经网络 (neural network)本质上是大型的嵌套函数。计算损失函数 (loss function)相对于这个嵌套结构中可能存在的数百万个参数 (parameter)的梯度,需要一种系统的方法。手动重复应用链式法则很快就会变得难以应对且容易出错。此时,计算图提供了一种强大且实用的工具。
计算图是一种将数学表达式或一系列操作表示为有向图的方法。图中的每个节点代表一个变量(输入数据、参数、中间值)或一个操作(如加法、乘法、激活函数 (activation function))。有向边表示数据流以及这些节点之间的依赖关系;从节点 A 到节点 B 的边意味着节点 B 的计算依赖于节点 A 的输出。
可视化计算
让我们看一个简单例子。假设我们有函数 L=(x⋅w+b−y)2。这可以表示一个非常简单的线性模型 z=x⋅w+b 的平方误差损失,其中 y 是目标值。我们可以将其分解为基本操作:
- p=x⋅w (乘法)
- a=p+b (加法)
- d=a−y (减法)
- L=d2 (平方)
我们可以将这个序列表示为计算图:
一个计算图,分解了计算 L=(x⋅w+b−y)2。椭圆形代表变量或计算值,矩形代表操作。边表示计算的方向。
前向传播
前向传播包括从输入到最终输出遍历图来计算表达式的值。你从输入(如 x, y)和参数 (parameter)(如 w, b)的值开始,然后根据其父节点的值计算每个操作节点的值。
对于我们的例子,如果 x=2,w=3,b=1,y=5:
- p=x⋅w=2⋅3=6
- a=p+b=6+1=7
- d=a−y=7−5=2
- L=d2=22=4
图结构清晰地定义了操作的顺序。
反向传播 (backpropagation):图上的反向传播
计算图的真正作用在反向传播过程中变得显而易见,这也是反向传播的实现方式。目标是计算最终输出(我们的损失 L)相对于每个输入和参数 (parameter)的梯度(即 ∂x∂L,∂w∂L,∂b∂L,∂y∂L)。
我们从最终输出节点(L)开始,向后遍历图。梯度计算依赖于在每个节点局部应用的链式法则。
-
从终点开始: 输出相对于自身的梯度总是1: ∂L∂L=1。这是我们“输入”到反向传播中的初始梯度值。
-
向后传播梯度: 对于任何产生输出 z 的节点 N,如果我们知道最终损失 L 相对于 z 的梯度(我们称之为 ∂z∂L),我们可以使用链式法则计算 L 相对于该节点 N 的任何输入 u 的梯度:
∂u∂L=∂z∂L⋅∂u∂z
这里,∂u∂z 是节点 N 处操作相对于其特定输入 u 的局部梯度。
-
在分叉处求和梯度: 如果一个变量(比如更复杂图中的 p)输入到多个后续操作中,其总梯度是每个路径反向传播回来的梯度之和。
让我们为我们的示例图追踪此过程:
-
节点 sq (L=d2):
- 输入梯度: ∂L∂L=1。
- 局部梯度: ∂d∂L=∂d∂(d2)=2d。根据我们的值,2d=2(2)=4。
- 输出梯度(到节点
d): ∂d∂L=∂L∂L⋅∂d∂L=1⋅(2d)=2d=4。
-
节点 sub (d=a−y):
- 输入梯度: ∂d∂L=4。
- 局部梯度: ∂a∂d=1, ∂y∂d=−1。
- 输出梯度:
- 到节点
a: ∂a∂L=∂d∂L⋅∂a∂d=4⋅1=4。
- 到节点
y: ∂y∂L=∂d∂L⋅∂y∂d=4⋅(−1)=−4。
-
节点 add (a=p+b):
- 输入梯度: ∂a∂L=4。
- 局部梯度: ∂p∂a=1, ∂b∂a=1。
- 输出梯度:
- 到节点
p: ∂p∂L=∂a∂L⋅∂p∂a=4⋅1=4。
- 到节点
b: ∂b∂L=∂a∂L⋅∂b∂a=4⋅1=4。
-
节点 mul (p=x⋅w):
- 输入梯度: ∂p∂L=4。
- 局部梯度: ∂x∂p=w, ∂w∂p=x。根据我们的值,∂x∂p=3, ∂w∂p=2。
- 输出梯度:
- 到节点
x: ∂x∂L=∂p∂L⋅∂x∂p=4⋅w=4⋅3=12。
- 到节点
w: ∂w∂L=∂p∂L⋅∂w∂p=4⋅x=4⋅2=8。
我们现在已经计算出所有需要的梯度:∂x∂L=12,∂w∂L=8,∂b∂L=4,以及 ∂y∂L=−4。图结构使我们能够系统地应用链式法则,而不会迷失在嵌套的函数结构中。
与深度学习 (deep learning)框架的相关性
现代深度学习库,如 TensorFlow 和 PyTorch,都高度依赖这种方法。它们根据你在模型代码中定义的操作,自动构建计算图(要么在执行前静态构建,要么在执行时动态构建)。这种图表示使得它们能够:
- 高效地进行前向传播计算。
- 自动计算所有参数 (parameter)的梯度,通过反向传播 (backpropagation)(通常称为“自动微分”或自动求导)。这使开发者无需手动推导和实现梯度计算,这对于复杂架构来说非常重要。
- 优化计算,例如,通过识别可并行化的操作或简化图的某些部分。
理解计算图有助于了解这些强大框架背后的机制,并阐明链式法则如何实现非常深的神经网络 (neural network)的训练。它将潜在复杂的梯度计算过程转变为图的结构化遍历。