训练变分自编码器涉及使用基于梯度的优化方法(如随机梯度下降 (SGD) 或 Adam)来优化证据下界 (ELBO)。这需要计算 ELBO 关于解码器参数 (θ) 和编码器参数 (ϕ) 的梯度。我们来回顾下典型的 VAE 流程:
- 编码器网络,由 ϕ 参数化,接收输入 x 并输出近似后验分布 qϕ(z∣x) 的参数。通常,这是一个对角高斯分布,因此编码器输出均值向量 μϕ(x) 和标准差向量 σϕ(x) (或为保持稳定性的对数方差)。
- 从该分布中采样一个潜在向量 z:z∼qϕ(z∣x)=N(z;μϕ(x),diag(σϕ2(x)))。
- 解码器网络,由 θ 参数化,接收采样的 z 并重建输入,对 pθ(x∣z) 进行建模。
- 计算损失函数 (负 ELBO),包括重建误差 (与 pθ(x∣z) 相关) 以及 qϕ(z∣x) 和先验 p(z) 之间的 KL 散度。
主要问题出现在步骤 2。采样操作 z∼qϕ(z∣x) 将随机性直接引入计算图,在编码器的输出 (μϕ(x),σϕ(x)) 和解码器的输入 (z) 之间。标准的反向传播无法处理此类随机采样节点;从解码器和 KL 散度项到编码器参数 ϕ 的梯度流动被阻断了。我们无法通过随机采样过程直接计算 ∇ϕ。如果采样本身是不可微的,我们如何根据采样 z 的后续影响来调整编码器的参数 ϕ 呢?
这就是重参数化技巧发挥作用的地方。这是一种巧妙的办法,用于重构采样过程,使得梯度能够流回编码器参数。核心思想是把随机性分离出来。我们不直接从由 μϕ(x) 和 σϕ(x) 定义的分布中采样 z,而是引入一个辅助噪声变量 ϵ,它来自一个固定的、简单的分布 (与 x 和 ϕ 无关),然后将 z 表示为 μϕ(x)、σϕ(x) 和 ϵ 的一个确定性函数。
工作原理:高斯分布情形
对于 qϕ(z∣x) 是对角高斯分布 N(μϕ(x),diag(σϕ2(x))) 的常见情况,重参数化按以下方式进行:
- 从标准正态分布中采样一个噪声向量 ϵ:ϵ∼N(0,I),其中 I 是单位矩阵。ϵ 的维度与潜在空间 z 的维度匹配。这一采样步骤发生在与 ϕ 相关的梯度主路径之外。
- 使用编码器输出 μϕ(x)、σϕ(x) 和采样的噪声 ϵ,通过以下确定性变换计算潜在向量 z:
z=μϕ(x)+σϕ(x)⊙ϵ
这里,⊙ 表示按元素乘法。
请注意,以这种方式生成的 z 仍然是一个随机变量,具有所需分布 N(μϕ(x),diag(σϕ2(x))),但随机性来源 (ϵ) 现在已外部化。该变换本身是关于 μϕ(x) 和 σϕ(x) 的一个简单可微函数。
使梯度能够流动
通过重参数化,计算图发生了变化。输入 x 流经编码器,生成 μϕ(x) 和 σϕ(x)。独立地采样一个随机 ϵ。然后,使用上述公式确定性地计算 z。这个 z 被送入解码器以计算重建损失。
重要的是,梯度现在可以从损失函数流回:
- 通过解码器到 z (∇z)。
- 通过确定性变换 z=μϕ(x)+σϕ(x)⊙ϵ 回到 μϕ(x) 和 σϕ(x) (∇μ,∇σ)。请注意,梯度也流经 ϵ,但我们不需要优化 ϵ 的分布。
- 最后,通过编码器网络流回以更新其参数 ϕ (∇ϕ)。
ELBO 中的 KL 散度项 DKL(qϕ(z∣x)∣∣p(z)) 直接依赖于 μϕ(x) 和 σϕ(x),因此其关于 ϕ 的梯度可以直接计算,而无需涉及采样过程。
重参数化技巧有效地将随机节点“移到了一边”,使得包含我们希望优化的参数 (ϕ 和 θ) 的主要计算路径能够完全可微。
重参数化技巧应用前后计算图和梯度流的对比。之前 (左侧),随机采样节点 (红色椭圆) 阻断了从解码器流回与重建损失相关的编码器参数的梯度。之后 (右侧),随机性通过外部变量 ϵ (青色椭圆) 注入,且计算 z 的变换 (靛蓝色方框) 是确定性的,使得梯度 (虚线蓝色线条) 能够流回编码器参数 (μ,σ)。
对优化的影响
通过使整个过程 (从输入 x 到最终损失计算) 关于 ϕ 和 θ 可微,重参数化技巧让我们能够使用标准的基于梯度的优化器。具体来说,我们可以计算 ELBO 梯度的蒙特卡洛估计。对于期望项 Eqϕ(z∣x)[logpθ(x∣z)],我们通常在每个训练步骤中对每个数据点 x 使用单个 ϵ 样本,以获得梯度 ∇ϕEqϕ(z∣x)[logpθ(x∣z)] 的无偏估计。KL 项的梯度通常通过解析方法计算。
虽然这里是针对高斯分布进行说明,但重参数化技巧也可以应用于其他分布,条件是样本可以通过参数的可微变换以及具有固定参数的基础分布来生成 (例如,用于分类分布的 Gumbel-Softmax)。这项技术对训练 VAE 和许多其他涉及从模型架构内参数化分布采样的深度生成模型来说必不可少。大多数深度学习库都提供了常见分布的实现,并内置了对重参数化采样的支持 (通常通过 rsample() 这样的方法)。