在明确了批标准化(BN)在前向传播过程中如何工作,即使用小批量统计数据对输入进行归一化,再进行缩放和平移之后,我们现在转向反向传播。为了通过梯度下降训练网络,我们需计算损失 L L L 相对于BN层输入 x i x_i x i 及可学习参数 γ \gamma γ 和 β \beta β 的变化情况。这要求通过BN变换应用链式法则。
我们回顾一下迷你批次 B = { x 1 , . . . , x m } \mathcal{B} = \{x_1, ..., x_m\} B = { x 1 , ... , x m } 中单个激活值 x i x_i x i 的前向传播过程:
计算迷你批次均值:
μ B = 1 m ∑ i = 1 m x i \mu_\mathcal{B} = \frac{1}{m} \sum_{i=1}^{m} x_i μ B = m 1 i = 1 ∑ m x i
计算迷你批次方差:
σ B 2 = 1 m ∑ i = 1 m ( x i − μ B ) 2 \sigma_\mathcal{B}^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_\mathcal{B})^2 σ B 2 = m 1 i = 1 ∑ m ( x i − μ B ) 2
归一化输入:
x ^ i = x i − μ B σ B 2 + ϵ \hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}} x ^ i = σ B 2 + ϵ x i − μ B
(其中 ϵ \epsilon ϵ 是一个用于数值稳定的小常数)
缩放和平移:
y i = γ x ^ i + β y_i = \gamma \hat{x}_i + \beta y i = γ x ^ i + β
在反向传播过程中,我们从后续层接收到损失相对于BN层输出的梯度 ∂ L ∂ y i \frac{\partial L}{\partial y_i} ∂ y i ∂ L 。我们的目标是计算 ∂ L ∂ x i \frac{\partial L}{\partial x_i} ∂ x i ∂ L 、∂ L ∂ γ \frac{\partial L}{\partial \gamma} ∂ γ ∂ L 和 ∂ L ∂ β \frac{\partial L}{\partial \beta} ∂ β ∂ L 。
可学习参数(γ \gamma γ 和 β \beta β )的梯度
这些是使用链式法则计算的最直接的梯度:
相对于 β \beta β 的梯度 :参数 β \beta β 直接加到输出 y i y_i y i 上。
∂ L ∂ β = ∑ i = 1 m ∂ L ∂ y i ∂ y i ∂ β = ∑ i = 1 m ∂ L ∂ y i ( 1 ) = ∑ i = 1 m ∂ L ∂ y i \frac{\partial L}{\partial \beta} = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i} \frac{\partial y_i}{\partial \beta} = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i} (1) = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i} ∂ β ∂ L = i = 1 ∑ m ∂ y i ∂ L ∂ β ∂ y i = i = 1 ∑ m ∂ y i ∂ L ( 1 ) = i = 1 ∑ m ∂ y i ∂ L
β \beta β 的梯度就是来自输出 y i y_i y i 的传入梯度的总和。
相对于 γ \gamma γ 的梯度 :参数 γ \gamma γ 缩放归一化输入 x ^ i \hat{x}_i x ^ i 。
∂ L ∂ γ = ∑ i = 1 m ∂ L ∂ y i ∂ y i ∂ γ = ∑ i = 1 m ∂ L ∂ y i ( x ^ i ) = ∑ i = 1 m ∂ L ∂ y i x ^ i \frac{\partial L}{\partial \gamma} = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i} \frac{\partial y_i}{\partial \gamma} = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i} (\hat{x}_i) = \sum_{i=1}^{m} \frac{\partial L}{\partial y_i} \hat{x}_i ∂ γ ∂ L = i = 1 ∑ m ∂ y i ∂ L ∂ γ ∂ y i = i = 1 ∑ m ∂ y i ∂ L ( x ^ i ) = i = 1 ∑ m ∂ y i ∂ L x ^ i
γ \gamma γ 的梯度是传入梯度的总和,每个梯度都由对应的归一化输入 x ^ i \hat{x}_i x ^ i 加权。
相对于输入(x i x_i x i )的梯度
计算相对于输入 x i x_i x i 的梯度更需细致考虑,因为 x i x_i x i 通过多种方式影响输出 y i y_i y i :
直接通过 x ^ i \hat{x}_i x ^ i 中的分子 ( x i − μ B ) (x_i - \mu_\mathcal{B}) ( x i − μ B ) 。
间接通过迷你批次均值 μ B \mu_\mathcal{B} μ B ,它取决于批次中的所有 x j x_j x j 。
间接通过迷你批次方差 σ B 2 \sigma_\mathcal{B}^2 σ B 2 ,它也取决于所有 x j x_j x j (包括 x i x_i x i )和 μ B \mu_\mathcal{B} μ B 。
我们需要仔细应用链式法则,考虑所有这些路径。设 σ B , ϵ = σ B 2 + ϵ \sigma_{\mathcal{B},\epsilon} = \sqrt{\sigma_\mathcal{B}^2 + \epsilon} σ B , ϵ = σ B 2 + ϵ 。梯度计算通过以下操作逆向进行:
相对于归一化输入 x ^ i \hat{x}_i x ^ i 的梯度 :
∂ L ∂ x ^ i = ∂ L ∂ y i ∂ y i ∂ x ^ i = ∂ L ∂ y i γ \frac{\partial L}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i} \frac{\partial y_i}{\partial \hat{x}_i} = \frac{\partial L}{\partial y_i} \gamma ∂ x ^ i ∂ L = ∂ y i ∂ L ∂ x ^ i ∂ y i = ∂ y i ∂ L γ
相对于 μ B \mu_\mathcal{B} μ B 和 σ B 2 \sigma_\mathcal{B}^2 σ B 2 的梯度 :这些需要汇总迷你批次中所有 x ^ j \hat{x}_j x ^ j 的贡献,因为这两个统计量都会影响所有归一化输入。
∂ L ∂ σ B 2 = ∑ i = 1 m ∂ L ∂ x ^ i ∂ x ^ i ∂ σ B 2 = ∑ i = 1 m ∂ L ∂ x ^ i ( x i − μ B ) ( − 1 2 ( σ B 2 + ϵ ) − 3 / 2 ) \frac{\partial L}{\partial \sigma_\mathcal{B}^2} = \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i} \frac{\partial \hat{x}_i}{\partial \sigma_\mathcal{B}^2} = \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i} (x_i - \mu_\mathcal{B}) \left( -\frac{1}{2} (\sigma_\mathcal{B}^2 + \epsilon)^{-3/2} \right) ∂ σ B 2 ∂ L = i = 1 ∑ m ∂ x ^ i ∂ L ∂ σ B 2 ∂ x ^ i = i = 1 ∑ m ∂ x ^ i ∂ L ( x i − μ B ) ( − 2 1 ( σ B 2 + ϵ ) − 3/2 )
∂ L ∂ μ B = ∑ i = 1 m ∂ L ∂ x ^ i ∂ x ^ i ∂ μ B = ( ∑ i = 1 m ∂ L ∂ x ^ i − 1 σ B , ϵ ) + ∂ L ∂ σ B 2 ∂ σ B 2 ∂ μ B \frac{\partial L}{\partial \mu_\mathcal{B}} = \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i} \frac{\partial \hat{x}_i}{\partial \mu_\mathcal{B}} = \left( \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i} \frac{-1}{\sigma_{\mathcal{B},\epsilon}} \right) + \frac{\partial L}{\partial \sigma_\mathcal{B}^2} \frac{\partial \sigma_\mathcal{B}^2}{\partial \mu_\mathcal{B}} ∂ μ B ∂ L = i = 1 ∑ m ∂ x ^ i ∂ L ∂ μ B ∂ x ^ i = ( i = 1 ∑ m ∂ x ^ i ∂ L σ B , ϵ − 1 ) + ∂ σ B 2 ∂ L ∂ μ B ∂ σ B 2
其中 ∂ σ B 2 ∂ μ B = 1 m ∑ j = 1 m 2 ( x j − μ B ) ( − 1 ) = − 2 m ∑ j = 1 m ( x j − μ B ) = 0 \frac{\partial \sigma_\mathcal{B}^2}{\partial \mu_\mathcal{B}} = \frac{1}{m} \sum_{j=1}^{m} 2(x_j - \mu_\mathcal{B})(-1) = \frac{-2}{m} \sum_{j=1}^{m} (x_j - \mu_\mathcal{B}) = 0 ∂ μ B ∂ σ B 2 = m 1 ∑ j = 1 m 2 ( x j − μ B ) ( − 1 ) = m − 2 ∑ j = 1 m ( x j − μ B ) = 0 。
因此,第二项消失,简化了 μ B \mu_\mathcal{B} μ B 的梯度:
∂ L ∂ μ B = ∑ i = 1 m ∂ L ∂ x ^ i − 1 σ B , ϵ \frac{\partial L}{\partial \mu_\mathcal{B}} = \sum_{i=1}^{m} \frac{\partial L}{\partial \hat{x}_i} \frac{-1}{\sigma_{\mathcal{B},\epsilon}} ∂ μ B ∂ L = i = 1 ∑ m ∂ x ^ i ∂ L σ B , ϵ − 1
相对于输入 x i x_i x i 的梯度 :现在我们将路径结合起来。输入 x i x_i x i 通过 x ^ i \hat{x}_i x ^ i 、μ B \mu_\mathcal{B} μ B 和 σ B 2 \sigma_\mathcal{B}^2 σ B 2 影响损失。
∂ L ∂ x i = ∂ L ∂ x ^ i ∂ x ^ i ∂ x i + ∂ L ∂ σ B 2 ∂ σ B 2 ∂ x i + ∂ L ∂ μ B ∂ μ B ∂ x i \frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i} \frac{\partial \hat{x}_i}{\partial x_i} + \frac{\partial L}{\partial \sigma_\mathcal{B}^2} \frac{\partial \sigma_\mathcal{B}^2}{\partial x_i} + \frac{\partial L}{\partial \mu_\mathcal{B}} \frac{\partial \mu_\mathcal{B}}{\partial x_i} ∂ x i ∂ L = ∂ x ^ i ∂ L ∂ x i ∂ x ^ i + ∂ σ B 2 ∂ L ∂ x i ∂ σ B 2 + ∂ μ B ∂ L ∂ x i ∂ μ B
我们需要统计量相对于单个输入 x i x_i x i 的偏导数:
∂ x ^ i ∂ x i = 1 σ B , ϵ \frac{\partial \hat{x}_i}{\partial x_i} = \frac{1}{\sigma_{\mathcal{B},\epsilon}} ∂ x i ∂ x ^ i = σ B , ϵ 1 (直接路径,忽略此项通过均值/方差的依赖关系)
∂ σ B 2 ∂ x i = 2 ( x i − μ B ) m \frac{\partial \sigma_\mathcal{B}^2}{\partial x_i} = \frac{2(x_i - \mu_\mathcal{B})}{m} ∂ x i ∂ σ B 2 = m 2 ( x i − μ B )
∂ μ B ∂ x i = 1 m \frac{\partial \mu_\mathcal{B}}{\partial x_i} = \frac{1}{m} ∂ x i ∂ μ B = m 1
代入这些项,得到 ∂ L ∂ x i \frac{\partial L}{\partial x_i} ∂ x i ∂ L 的最终表达式:
∂ L ∂ x i = ∂ L ∂ x ^ i 1 σ B , ϵ + ∂ L ∂ σ B 2 2 ( x i − μ B ) m + ∂ L ∂ μ B 1 m \frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial \hat{x}_i} \frac{1}{\sigma_{\mathcal{B},\epsilon}} + \frac{\partial L}{\partial \sigma_\mathcal{B}^2} \frac{2(x_i - \mu_\mathcal{B})}{m} + \frac{\partial L}{\partial \mu_\mathcal{B}} \frac{1}{m} ∂ x i ∂ L = ∂ x ^ i ∂ L σ B , ϵ 1 + ∂ σ B 2 ∂ L m 2 ( x i − μ B ) + ∂ μ B ∂ L m 1
综合所有并简化(完整的推导过程相当细致,通常在论文或教科书的附录中可查阅),结果可以更紧凑地表示。一种常见形式是:
∂ L ∂ x i = 1 m σ B , ϵ ( m ∂ L ∂ x ^ i − ∑ j = 1 m ∂ L ∂ x ^ j − x ^ i ∑ j = 1 m ∂ L ∂ x ^ j x ^ j ) \frac{\partial L}{\partial x_i} = \frac{1}{m \sigma_{\mathcal{B},\epsilon}} \left( m \frac{\partial L}{\partial \hat{x}_i} - \sum_{j=1}^{m} \frac{\partial L}{\partial \hat{x}_j} - \hat{x}_i \sum_{j=1}^{m} \frac{\partial L}{\partial \hat{x}_j} \hat{x}_j \right) ∂ x i ∂ L = m σ B , ϵ 1 ( m ∂ x ^ i ∂ L − j = 1 ∑ m ∂ x ^ j ∂ L − x ^ i j = 1 ∑ m ∂ x ^ j ∂ L x ^ j )
请注意 ∂ L ∂ x ^ j = ∂ L ∂ y j γ \frac{\partial L}{\partial \hat{x}_j} = \frac{\partial L}{\partial y_j} \gamma ∂ x ^ j ∂ L = ∂ y j ∂ L γ 。
主要的一点是,梯度 ∂ L ∂ x i \frac{\partial L}{\partial x_i} ∂ x i ∂ L 不仅取决于与该特定激活对应的梯度 ∂ L ∂ y i \frac{\partial L}{\partial y_i} ∂ y i ∂ L ,还由于共享的均值和方差计算,取决于迷你批次中所有其他激活 (j = 1... m j=1...m j = 1... m )的梯度和值。
梯度流的可视化
反向传播过程中的依赖关系可以被可视化。我们考虑单个输出 y i y_i y i 的计算图,以及损失梯度如何流回输入 x i x_i x i ,其中包含了共享的 μ B \mu_\mathcal{B} μ B 和 σ B 2 \sigma_\mathcal{B}^2 σ B 2 的影响。
此图表展示了批标准化计算中的依赖关系以及反向传播过程中梯度的流动。请注意,输入 x i x_i x i 如何直接从 x ^ i \hat{x}_i x ^ i 接收梯度贡献,并间接通过迷你批次统计量 μ B \mu_\mathcal{B} μ B 和 σ B 2 \sigma_\mathcal{B}^2 σ B 2 接收。
框架中的实现
幸运的是,你很少需要手动实现此反向传播。当您定义一个包含BN层的模型时,PyTorch和TensorFlow等深度学习框架会使用自动微分(autograd)来自动计算这些梯度。例如,在PyTorch中:
import torch
import torch.nn as nn
# 示例设置
batch_size = 4
num_features = 10
input_tensor = torch.randn(batch_size, num_features, requires_grad=True)
# 定义一个批标准化层(affine=True 表示 gamma 和 beta 是可学习的)
bn_layer = nn.BatchNorm1d(num_features=num_features, affine=True)
# 前向传播
output = bn_layer(input_tensor)
# 假设一个用于演示的虚拟损失
loss = output.mean()
# 反向传播
loss.backward()
# 梯度现在已被计算和存储
# 相对于输入的梯度: input_tensor.grad
# 相对于 gamma(权重)的梯度: bn_layer.weight.grad
# 相对于 beta(偏置)的梯度: bn_layer.bias.grad
print("输入梯度的形状:", input_tensor.grad.shape)
print("gamma 梯度的形状:", bn_layer.weight.grad.shape)
print("beta 梯度的形状:", bn_layer.bias.grad.shape)
# >>> 输入梯度的形状: torch.Size([4, 10])
# >>> gamma 梯度的形状: torch.Size([10])
# >>> beta 梯度的形状: torch.Size([10])
尽管框架处理了具体实现,但理解其背后的计算原理,特别是在输入梯度 ∂ L ∂ x i \frac{\partial L}{\partial x_i} ∂ x i ∂ L 上对整个迷你批次的依赖,对于理解模型行为和训练期间可能出现的问题很有价值。这种理解有助于我们知晓BN为何影响训练动态和泛化性能,我们将在下一部分讨论这一点。