变分自编码器 (VAE) 是使用变分推断的原理推导出来的。生成建模中的一个常见目标是估计观测数据 p(x) 的概率分布。对于带有潜在变量 z 的模型,这需要计算边际似然:
p(x)=∫p(x,z)dz=∫pθ(x∣z)p(z)dz
其中 p(z) 是潜在变量的先验分布,而 pθ(x∣z) 是给定潜在变量的数据似然,通常由一个参数为 θ 的解码器网络参数化。
然而,对于复杂模型和高维潜在空间,此积分通常难以计算。这种难以计算性也体现在真实的后验分布 p(z∣x)=p(x∣z)p(z)/p(x) 上,因为其分母 p(x) 正是我们无法计算的积分。变分推断通过引入真实后验的近似 qϕ(z∣x) 来解决此问题。此近似后验通常由一个参数为 ϕ 的编码器网络参数化。
核心思想是使 qϕ(z∣x) 尽可能接近真实的后验 p(z∣x)。我们使用库尔巴克-莱布勒 (KL) 散度 DKL(qϕ(z∣x)∣∣p(z∣x)) 来衡量这种“接近程度”。我们的目标是找到使此 KL 散度最小化的参数 ϕ。
让我们从数据对数似然 logp(x) 开始,看看 qϕ(z∣x) 和证据下界 (ELBO) 如何得到。
logp(x)=log∫p(x,z)dz
我们可以在积分号内乘以并除以 qϕ(z∣x)(假设当 p(x,z)>0 时 qϕ(z∣x)>0):
logp(x)=log∫qϕ(z∣x)qϕ(z∣x)p(x,z)dz
这可以重写为关于 qϕ(z∣x) 的期望的对数:
logp(x)=logEqϕ(z∣x)[qϕ(z∣x)p(x,z)]
由于对数是一个凹函数,我们可以应用琴生不等式 (logE[Y]≥E[logY]) 将对数移到期望内部:
logp(x)≥Eqϕ(z∣x)[logqϕ(z∣x)p(x,z)]
这个下界正是证据下界 (ELBO),通常表示为 LELBO 或简写为 L(ϕ,θ;x):
LELBO(ϕ,θ;x)=Eqϕ(z∣x)[logp(x,z)−logqϕ(z∣x)]
通过展开 p(x,z)=pθ(x∣z)p(z),我们得到另一个常用形式:
LELBO(ϕ,θ;x)=Eqϕ(z∣x)[logpθ(x∣z)+logp(z)−logqϕ(z∣x)]
真实对数似然 logp(x) 和 ELBO 之间的差值正是近似后验与真实后验之间的 KL 散度:
logp(x)−LELBO(ϕ,θ;x)=Eqϕ(z∣x)[logqϕ(z∣x)−logp(z∣x)]=DKL(qϕ(z∣x)∣∣p(z∣x))
所以,我们得到以下基本关系:
logp(x)=LELBO(ϕ,θ;x)+DKL(qϕ(z∣x)∣∣p(z∣x))
由于 KL 散度始终非负 (DKL≥0),ELBO 确实是数据对数似然的下界。相对于 ϕ 和 θ 最大化 ELBO 有以下两个作用:
- 它使 ELBO 更接近真实对数似然,有效地最小化 KL 散度 DKL(qϕ(z∣x)∣∣p(z∣x)),从而使我们的近似后验 qϕ(z∣x) 更好地近似真实的后验 p(z∣x)。
- 它间接最大化了我们的模型生成观测数据的对数似然 logp(x)。
对数边际似然 logp(x) 分解为证据下界 (ELBO) 和近似后验 qϕ(z∣x) 与真实后验 p(z∣x) 之间的 KL 散度。最大化 ELBO 有效地最大化了 logp(x),同时最小化了近似误差。
剖析 ELBO
ELBO 可以重排成一个更易于理解的形式,突出 VAE 的两个主要目标:
从 LELBO=Eqϕ(z∣x)[logpθ(x∣z)+logp(z)−logqϕ(z∣x)] 开始,我们可以将项分组:
LELBO(ϕ,θ;x)=Eqϕ(z∣x)[logpθ(x∣z)]−Eqϕ(z∣x)[logqϕ(z∣x)−logp(z)]
第二项是 qϕ(z∣x) 和 p(z) 之间 KL 散度的定义:
DKL(qϕ(z∣x)∣∣p(z))=Eqϕ(z∣x)[logp(z)qϕ(z∣x)]=Eqϕ(z∣x)[logqϕ(z∣x)−logp(z)]
因此,ELBO 变为:
LELBO(ϕ,θ;x)=重构似然Eqϕ(z∣x)[logpθ(x∣z)]−KL 正则项DKL(qϕ(z∣x)∣∣p(z))
让我们分别分析这两个组成部分:
-
期望重构对数似然: Eqϕ(z∣x)[logpθ(x∣z)]
此项衡量解码器 pθ(x∣z) 在给定从编码器近似后验 qϕ(z∣x) 中采样的潜在编码 z 时,能够多好地重构输入数据 x。它鼓励模型学习潜在表示 z,这些表示保留了足够的信息以重建 x。这是 VAE 的“自编码”部分。logpθ(x∣z) 的具体形式取决于数据类型:
- 对于二值数据(例如,黑白图像),pθ(x∣z) 通常建模为伯努利分布的乘积。最大化 logpθ(x∣z) 对应于最小化输入 x 和重构输出 x^=decoder(z) 之间的二元交叉熵 (BCE) 损失。
- 对于实值数据(例如,像素强度归一化到 [0,1] 的图像或连续信号),pθ(x∣z) 通常建模为高斯分布 N(x∣μθ(z),σ2I)。如果方差 σ2 是固定的,最大化此项等同于最小化 x 和解码器平均输出 μθ(z) 之间的均方误差 (MSE)。
-
KL 散度正则项: DKL(qϕ(z∣x)∣∣p(z))
此项充当潜在空间上的正则项。它衡量近似后验分布 qϕ(z∣x)(由编码器针对给定输入 x 生成)与潜在变量的先验分布 p(z) 之间的不相似性。先验 p(z) 通常选择为简单、固定的分布,最常见的是标准多元高斯分布 N(0,I)。
通过最小化此 KL 散度(请注意 ELBO 公式中的负号,这意味着我们通过最大化 ELBO 项 −DKL 来有效最小化 DKL),我们鼓励编码器生成潜在分布 qϕ(z∣x),使其平均而言接近先验 p(z)。这有若干好处:
- 平滑性和连续性: 它有助于组织潜在空间,使其更连续,并减少出现“空洞”或不相连区域的可能性。这对 VAE 的生成能力很重要,因为我们希望能够采样 z∼p(z) 并生成新颖、连贯的数据。
- 正则化: 它阻止编码器学习过于复杂或“作弊”的后验,这些后验可能只是简单地记住 z 中的输入数据。
ELBO 包含两个主要项。第一项是期望重构对数似然,它促使模型准确重构数据。第二项是 KL 散度项,通过促使近似后验 qϕ(z∣x) 接近预设先验 p(z) 来正则化潜在空间。
实践中的 KL 散度项
先验 p(z) 和近似后验 qϕ(z∣x) 的常见选择都是多元高斯分布。
设 p(z)=N(z∣0,I),一个均值为零、协方差矩阵为单位矩阵的标准高斯分布。
设近似后验 qϕ(z∣x) 也是一个高斯分布,但其均值为 μϕ(x),且协方差矩阵为对角矩阵 diag(σϕ,12(x),...,σϕ,J2(x)),其中 J 是潜在空间的维度。编码器网络将为每个输入 x 输出参数 μϕ(x) 和 log(σϕ2(x))(或直接是 σϕ(x))。
对于这些选择,KL 散度 DKL(qϕ(z∣x)∣∣p(z)) 有一个便捷的解析解:
DKL(N(μϕ(x),diag(σϕ2(x)))∣∣N(0,I))=21j=1∑J(μϕ,j(x)2+σϕ,j2(x)−log(σϕ,j2(x))−1)
这个闭式表达式可以直接纳入 VAE 的损失函数中,并通过梯度下降进行优化。重参数化技巧(我们将在下一节中讨论)对于通过期望 Eqϕ(z∣x) 所涉及的采样过程反向传播梯度至关重要。
总之,ELBO 为训练 VAE 提供了可计算的目标函数。它精妙地平衡了准确数据重构的需求与正则化、平滑且适合生成潜在空间的需求。通过最大化 ELBO,我们同时改进我们对数据 p(x) 的模型,并优化我们对真实、难以计算的后验 p(z∣x) 的近似 qϕ(z∣x)。理解此公式对于理解 VAE 如何学习以及开发更进阶的 VAE 架构和技术而言是基础性的。