生成模型的主要目标是学习观测数据 x x x 的潜在概率分布 p ( x ) p(x) p ( x ) 。变分自编码器通过引入隐变量 z z z 来达成此目的,这些隐变量是未观测变量,它们表示数据的内在结构或属性。我们假定数据 x x x 是通过生成过程 p h e t a ( x ∣ z ) p_ heta(x|z) p h e t a ( x ∣ z ) 由这些隐变量 z z z 生成的,其中 z z z 本身是从先验分布 p ( z ) p(z) p ( z ) 中采样的。观测值 x x x 的边际似然由以下公式给出:
p θ ( x ) = ∫ p θ ( x ∣ z ) p ( z ) d z p_\theta(x) = \int p_\theta(x|z) p(z) dz p θ ( x ) = ∫ p θ ( x ∣ z ) p ( z ) d z
我们目标是学习此生成模型(通常是一个神经网络 (neural network),即“解码器”)的参数 (parameter) θ \theta θ 。直接最大化 log p θ ( x ) \log p_\theta(x) log p θ ( x ) 很困难,因为这个积分通常难以计算,尤其当 p θ ( x ∣ z ) p_\theta(x|z) p θ ( x ∣ z ) 复杂(如深度神经网络)且 z z z 维度很高时。
此外,为了执行将数据编码为隐表示等任务,我们经常需要访问后验分布 p θ ( z ∣ x ) = p θ ( x ∣ z ) p ( z ) p θ ( x ) p_\theta(z|x) = \frac{p_\theta(x|z)p(z)}{p_\theta(x)} p θ ( z ∣ x ) = p θ ( x ) p θ ( x ∣ z ) p ( z ) 。这也难以计算,因为其分母 p θ ( x ) p_\theta(x) p θ ( x ) 难以计算。
这就是变分推断的作用所在。我们不计算真实后验 p θ ( z ∣ x ) p_\theta(z|x) p θ ( z ∣ x ) ,而是引入一个更简单、可计算的分布 q ϕ ( z ∣ x ) q_\phi(z|x) q ϕ ( z ∣ x ) 对其进行近似。这个分布 q ϕ ( z ∣ x ) q_\phi(z|x) q ϕ ( z ∣ x ) 通常被称为变分后验或推断网络(在VAE中,即“编码器”),它通常由 ϕ \phi ϕ 参数化(例如,另一个神经网络的参数)。我们的目标是使 q ϕ ( z ∣ x ) q_\phi(z|x) q ϕ ( z ∣ x ) 尽可能接近真实后验 p θ ( z ∣ x ) p_\theta(z|x) p θ ( z ∣ x ) 。
我们从数据点 x x x 的对数似然 log p θ ( x ) \log p_\theta(x) log p θ ( x ) 开始。我们可以使用我们的近似后验 q ϕ ( z ∣ x ) q_\phi(z|x) q ϕ ( z ∣ x ) 重写它:
log p θ ( x ) = E q ϕ ( z ∣ x ) [ log p θ ( x ) ] \log p_\theta(x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x)] log p θ ( x ) = E q ϕ ( z ∣ x ) [ log p θ ( x )]
之所以成立是因为 log p θ ( x ) \log p_\theta(x) log p θ ( x ) 不依赖于 z z z ,并且 ∫ q ϕ ( z ∣ x ) d z = 1 \int q_\phi(z|x) dz = 1 ∫ q ϕ ( z ∣ x ) d z = 1 。现在,我们使用条件概率的定义 p θ ( x ) = p θ ( x , z ) / p θ ( z ∣ x ) p_\theta(x) = p_\theta(x,z) / p_\theta(z|x) p θ ( x ) = p θ ( x , z ) / p θ ( z ∣ x ) :
log p θ ( x ) = E q ϕ ( z ∣ x ) [ log p θ ( x , z ) p θ ( z ∣ x ) ] \log p_\theta(x) = \mathbb{E}_{q_\phi(z|x)}\left[\log \frac{p_\theta(x,z)}{p_\theta(z|x)}\right] log p θ ( x ) = E q ϕ ( z ∣ x ) [ log p θ ( z ∣ x ) p θ ( x , z ) ]
接下来,我们在对数内部乘以和除以 q ϕ ( z ∣ x ) q_\phi(z|x) q ϕ ( z ∣ x ) ,这是一个常用技巧:
log p θ ( x ) = E q ϕ ( z ∣ x ) [ log ( p θ ( x , z ) q ϕ ( z ∣ x ) q ϕ ( z ∣ x ) p θ ( z ∣ x ) ) ] \log p_\theta(x) = \mathbb{E}_{q_\phi(z|x)}\left[\log \left(\frac{p_\theta(x,z)}{q_\phi(z|x)} \frac{q_\phi(z|x)}{p_\theta(z|x)}\right)\right] log p θ ( x ) = E q ϕ ( z ∣ x ) [ log ( q ϕ ( z ∣ x ) p θ ( x , z ) p θ ( z ∣ x ) q ϕ ( z ∣ x ) ) ]
利用性质 log ( a b ) = log a + log b \log(ab) = \log a + \log b log ( ab ) = log a + log b :
log p θ ( x ) = E q ϕ ( z ∣ x ) [ log p θ ( x , z ) q ϕ ( z ∣ x ) ] + E q ϕ ( z ∣ x ) [ log q ϕ ( z ∣ x ) p θ ( z ∣ x ) ] \log p_\theta(x) = \mathbb{E}_{q_\phi(z|x)}\left[\log \frac{p_\theta(x,z)}{q_\phi(z|x)}\right] + \mathbb{E}_{q_\phi(z|x)}\left[\log \frac{q_\phi(z|x)}{p_\theta(z|x)}\right] log p θ ( x ) = E q ϕ ( z ∣ x ) [ log q ϕ ( z ∣ x ) p θ ( x , z ) ] + E q ϕ ( z ∣ x ) [ log p θ ( z ∣ x ) q ϕ ( z ∣ x ) ]
我们审视这两个项。第二项是 q ϕ ( z ∣ x ) q_\phi(z|x) q ϕ ( z ∣ x ) 和真实后验 p θ ( z ∣ x ) p_\theta(z|x) p θ ( z ∣ x ) 之间的Kullback-Leibler (KL) 散度:
E q ϕ ( z ∣ x ) [ log q ϕ ( z ∣ x ) p θ ( z ∣ x ) ] = D K L ( q ϕ ( z ∣ x ) ∣ ∣ p θ ( z ∣ x ) ) \mathbb{E}_{q_\phi(z|x)}\left[\log \frac{q_\phi(z|x)}{p_\theta(z|x)}\right] = D_{KL}(q_\phi(z|x) || p_\theta(z|x)) E q ϕ ( z ∣ x ) [ log p θ ( z ∣ x ) q ϕ ( z ∣ x ) ] = D K L ( q ϕ ( z ∣ x ) ∣∣ p θ ( z ∣ x ))
KL散度总是非负的 (D K L ≥ 0 D_{KL} \ge 0 D K L ≥ 0 ),当且仅当 q ϕ ( z ∣ x ) = p θ ( z ∣ x ) q_\phi(z|x) = p_\theta(z|x) q ϕ ( z ∣ x ) = p θ ( z ∣ x ) 时为零。
第一项定义为证据下界 (ELBO),记作 L ( θ , ϕ ; x ) \mathcal{L}(\theta, \phi; x) L ( θ , ϕ ; x ) :
L ( θ , ϕ ; x ) = E q ϕ ( z ∣ x ) [ log p θ ( x , z ) q ϕ ( z ∣ x ) ] \mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}\left[\log \frac{p_\theta(x,z)}{q_\phi(z|x)}\right] L ( θ , ϕ ; x ) = E q ϕ ( z ∣ x ) [ log q ϕ ( z ∣ x ) p θ ( x , z ) ]
于是,我们得到了基本恒等式:
log p θ ( x ) = L ( θ , ϕ ; x ) + D K L ( q ϕ ( z ∣ x ) ∣ ∣ p θ ( z ∣ x ) ) \log p_\theta(x) = \mathcal{L}(\theta, \phi; x) + D_{KL}(q_\phi(z|x) || p_\theta(z|x)) log p θ ( x ) = L ( θ , ϕ ; x ) + D K L ( q ϕ ( z ∣ x ) ∣∣ p θ ( z ∣ x ))
边际对数似然分解为ELBO以及近似后验与真实后验之间的KL散度。最大化ELBO是我们的可计算目标。
由于 D K L ( q ϕ ( z ∣ x ) ∣ ∣ p θ ( z ∣ x ) ) ≥ 0 D_{KL}(q_\phi(z|x) || p_\theta(z|x)) \ge 0 D K L ( q ϕ ( z ∣ x ) ∣∣ p θ ( z ∣ x )) ≥ 0 ,因此有:
log p θ ( x ) ≥ L ( θ , ϕ ; x ) \log p_\theta(x) \ge \mathcal{L}(\theta, \phi; x) log p θ ( x ) ≥ L ( θ , ϕ ; x )
这就是为什么 L ( θ , ϕ ; x ) \mathcal{L}(\theta, \phi; x) L ( θ , ϕ ; x ) 被称为“证据下界”:它提供了数据对数似然(即“证据”)的下界。通过同时最大化ELBO,并优化生成模型参数 θ \theta θ 和变分参数 ϕ \phi ϕ ,我们实际上是在:
提高 log p θ ( x ) \log p_\theta(x) log p θ ( x ) 的下界,使其更高。
最小化KL散度 D K L ( q ϕ ( z ∣ x ) ∣ ∣ p θ ( z ∣ x ) ) D_{KL}(q_\phi(z|x) || p_\theta(z|x)) D K L ( q ϕ ( z ∣ x ) ∣∣ p θ ( z ∣ x )) ,使我们的近似 q ϕ ( z ∣ x ) q_\phi(z|x) q ϕ ( z ∣ x ) 更接近真实后验 p θ ( z ∣ x ) p_\theta(z|x) p θ ( z ∣ x ) 。如果 q ϕ ( z ∣ x ) q_\phi(z|x) q ϕ ( z ∣ x ) 变为一个完全近似,KL散度将变为零,ELBO将等于真实的对数似然。
ELBO可以重写成一种对VAE而言通常更直观的形式。从 L ( θ , ϕ ; x ) = E q ϕ ( z ∣ x ) [ log p θ ( x , z ) − log q ϕ ( z ∣ x ) ] \mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x,z) - \log q_\phi(z|x)] L ( θ , ϕ ; x ) = E q ϕ ( z ∣ x ) [ log p θ ( x , z ) − log q ϕ ( z ∣ x )] 开始并使用 p θ ( x , z ) = p θ ( x ∣ z ) p ( z ) p_\theta(x,z) = p_\theta(x|z)p(z) p θ ( x , z ) = p θ ( x ∣ z ) p ( z ) :
L ( θ , ϕ ; x ) = E q ϕ ( z ∣ x ) [ log p θ ( x ∣ z ) p ( z ) − log q ϕ ( z ∣ x ) ] \mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)p(z) - \log q_\phi(z|x)] L ( θ , ϕ ; x ) = E q ϕ ( z ∣ x ) [ log p θ ( x ∣ z ) p ( z ) − log q ϕ ( z ∣ x )]
= E q ϕ ( z ∣ x ) [ log p θ ( x ∣ z ) + log p ( z ) − log q ϕ ( z ∣ x ) ] = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z) + \log p(z) - \log q_\phi(z|x)] = E q ϕ ( z ∣ x ) [ log p θ ( x ∣ z ) + log p ( z ) − log q ϕ ( z ∣ x )]
= E q ϕ ( z ∣ x ) [ log p θ ( x ∣ z ) ] − E q ϕ ( z ∣ x ) [ log q ϕ ( z ∣ x ) − log p ( z ) ] = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - \mathbb{E}_{q_\phi(z|x)}[\log q_\phi(z|x) - \log p(z)] = E q ϕ ( z ∣ x ) [ log p θ ( x ∣ z )] − E q ϕ ( z ∣ x ) [ log q ϕ ( z ∣ x ) − log p ( z )]
第二项再次是一个KL散度:D K L ( q ϕ ( z ∣ x ) ∣ ∣ p ( z ) ) D_{KL}(q_\phi(z|x) || p(z)) D K L ( q ϕ ( z ∣ x ) ∣∣ p ( z )) 。这给我们提供了VAE中最常见的ELBO形式:
L ( θ , ϕ ; x ) = E q ϕ ( z ∣ x ) [ log p θ ( x ∣ z ) ] ⏟ 重构项 − D K L ( q ϕ ( z ∣ x ) ∣ ∣ p ( z ) ) ⏟ 正则化项 \mathcal{L}(\theta, \phi; x) = \underbrace{\mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)]}_{\text{重构项}} - \underbrace{D_{KL}(q_\phi(z|x) || p(z))}_{\text{正则化项}} L ( θ , ϕ ; x ) = 重构项 E q ϕ ( z ∣ x ) [ log p θ ( x ∣ z )] − 正则化项 D K L ( q ϕ ( z ∣ x ) ∣∣ p ( z ))
我们来分析这两个组成部分:
重构项: E q ϕ ( z ∣ x ) [ log p θ ( x ∣ z ) ] \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] E q ϕ ( z ∣ x ) [ log p θ ( x ∣ z )] 。此项衡量了解码器 p θ ( x ∣ z ) p_\theta(x|z) p θ ( x ∣ z ) 能多好地重构输入数据 x x x ,给定从编码器近似 q ϕ ( z ∣ x ) q_\phi(z|x) q ϕ ( z ∣ x ) 中采样的隐变量样本 z z z 。最大化此项鼓励 VAE 学习有意义的隐表示 z z z ,这些表示保留足够信息以重构 x x x 。例如,如果 p θ ( x ∣ z ) p_\theta(x|z) p θ ( x ∣ z ) 是一个高斯分布,此项将变为平方误差损失。如果 p θ ( x ∣ z ) p_\theta(x|z) p θ ( x ∣ z ) 是一个伯努利分布(用于二值数据),此项将变为二元交叉熵损失。
正则化 (regularization)项: D K L ( q ϕ ( z ∣ x ) ∣ ∣ p ( z ) ) D_{KL}(q_\phi(z|x) || p(z)) D K L ( q ϕ ( z ∣ x ) ∣∣ p ( z )) 。此项充当正则化器。它衡量了近似后验分布 q ϕ ( z ∣ x ) q_\phi(z|x) q ϕ ( z ∣ x ) (由编码器对给定 x x x 输出的)与隐变量的先验分布 p ( z ) p(z) p ( z ) 之间的散度。先验 p ( z ) p(z) p ( z ) 通常选择为简单的分布,例如标准多元高斯分布 N ( 0 , I ) \mathcal{N}(0, I) N ( 0 , I ) 。此KL项鼓励编码器将隐表示 z z z 分布得与先验类似。这种正则化对于确保隐空间结构良好且可用于生成新数据非常重要。没有它,编码器可能会学习生成 q ϕ ( z ∣ x ) q_\phi(z|x) q ϕ ( z ∣ x ) 分布,这些分布对不同的 x x x 彼此相距很远,从而导致非平滑或“有空隙的”隐空间。
因此,VAE 由以下部分组成:
一个编码器网络 q ϕ ( z ∣ x ) q_\phi(z|x) q ϕ ( z ∣ x ) ,它接收数据 x x x 并输出 z z z 分布的参数(例如,如果 q ϕ ( z ∣ x ) q_\phi(z|x) q ϕ ( z ∣ x ) 是高斯分布,则为均值和方差)。其参数为 ϕ \phi ϕ 。
一个解码器网络 p θ ( x ∣ z ) p_\theta(x|z) p θ ( x ∣ z ) ,它接收隐变量样本 z z z 并输出 x x x 分布的参数。其参数为 θ \theta θ 。
一个选定的隐变量先验 p ( z ) p(z) p ( z ) ,通常是 N ( 0 , I ) \mathcal{N}(0,I) N ( 0 , I ) 。
训练过程包括同时优化ELBO L ( θ , ϕ ; x ) \mathcal{L}(\theta, \phi; x) L ( θ , ϕ ; x ) ,并优化 θ \theta θ 和 ϕ \phi ϕ ,使用诸如随机梯度上升的技术。这个推导为VAE目标函数提供了理论依据,它平衡了数据重构与隐空间正则化。这种平衡使得VAE能够学习丰富、有结构的表示并生成新的数据样本。变分自编码器和变分推断中的“变分”一词指的是变分法中的方法,其中我们优化泛函(函数的函数),在此例中,是在由 ϕ \phi ϕ 参数化的一系列分布中找到最佳的 q ϕ ( z ∣ x ) q_\phi(z|x) q ϕ ( z ∣ x ) 。
在后续章节中,我们将探究这个目标的每个组成部分,优化它的实际操作(如重参数化技巧),以及这种公式表达的含义。