训练变分自编码器(VAE)需要优化一个独特的损失函数 (loss function)。与主要关注最小化重建误差的标准自编码器不同,VAE损失函数具有双重目标:精确重建输入数据,并对潜在空间施加特定结构。这种结构使得VAE能够成为有效的生成模型。
VAE的总损失,通常表示为 LVAE,是两个不同项的总和:
LVAE=L重建+LKL
让我们分别分析这些组成部分,以理解它们的作用以及它们如何促成VAE的学习过程。
重建损失:确保数据保真
VAE损失的第一部分,L重建,衡量解码器从潜在表示 z 重建原始输入 x 的能力。潜在向量 (vector) z 是从编码器学习到的分布 q(z∣x) 中采样的。然后解码器试图生成一个输出 x^,使其尽可能接近 x。
重建损失的具体公式取决于输入数据的性质:
-
对于连续数据,例如像素值归一化 (normalization)在0到1之间的图像,或一般的实值特征,均方误差 (MSE) 是一个常用选择。它计算原始输入像素(或特征)与重建像素之间的平均平方差。
LMSE=N1∑i=1N(xi−x^i)2
其中 N 是输入中的特征或像素数量。
-
对于二值数据,或输入为概率的数据(例如像素为0或1的黑白图像,或解码器最后一层sigmoid激活的输出),通常使用二元交叉熵 (BCE)。BCE衡量两个概率分布之间的差异,在此是原始输入分布和重建输入分布之间的差异。
LBCE=−∑i=1N[xilog(x^i)+(1−xi)log(1−x^i)]
同样,N 是特征/像素的数量。这种损失促使解码器的输出 x^i 在每个维度上都接近 xi。
重建损失促使VAE学习编码器和解码器对,使其在潜在空间的限制内尽可能多地保留输入数据的信息。如果没有这一项,VAE就没有动机去学习有意义的压缩或生成可识别的数据。
KL散度:构建潜在空间
第二个组成部分,LKL,是Kullback-Leibler (KL) 散度项。这一项真正区分了VAE与标准自编码器,并且对其生成能力非常重要。它充当潜在空间的正则化 (regularization)器。
回想一下,VAE编码器不仅仅在潜在空间中输出一个点;相反,对于每个输入 x,它输出定义概率分布 q(z∣x) 的参数 (parameter)(通常是均值 μ(x) 和对数方差 log(σ2(x)))。这个分布通常是高斯分布:q(z∣x)=N(z;μ(x),diag(σ2(x)))。
KL散度项衡量这个学习到的分布 q(z∣x) 与所选先验分布 p(z) 的差异程度。先验 p(z) 通常是一个标准正态分布,N(0,I),意味着一个高斯分布,其每个潜在维度均值为零,方差为一,且维度之间没有相关性。
LKL=DKL(q(z∣x)∣∣p(z))
对于一个编码器,其为每个潜在维度 j(从 1 到 J)输出 μj 和 logvarj(对数方差),且先验 p(z)=N(0,I),KL散度可以计算为:
D_{KL}(q(z|x) || p(z)) = \frac{1}{2} \sum_{j=1}^{J} (\exp(\text{log_var}_j) + \mu_j^2 - 1 - \text{log_var}_j)
最小化这个KL散度项促使编码器产生接近标准正态先验 p(z) 的分布 q(z∣x)。这带来几个重要结果:
- 连续性: 它迫使不同输入的编码在潜在空间的原点附近形成“聚类”,并且方差接近一。这有助于确保潜在空间是连续的,没有大的“间隙”。如果潜在空间是连续的,那么潜在向量 (vector) z 的微小变化会导致生成输出 x^ 的微小平滑变化。
- 正则化: 它阻止编码器学习一种“恒等式”函数,使得每个输入都被映射到潜在空间中一个非常特定、孤立的点(即,使 σ2 非常小)。这会使重建变得容易,但会产生一个不佳、无序的潜在空间,不适合生成。
- 用于生成的采样: 通过强制 q(z∣x) 趋近 p(z),VAE学习了一个潜在空间,其中从简单先验 p(z)(例如 N(0,I))中采样的点很可能被解码为看起来真实的数据。这是因为解码器已在平均而言来自与 p(z) 相似分布的潜在向量上进行训练。
本质上,KL散度项确保潜在空间表现良好,使得从 p(z) 中采样并生成新的数据点成为可能。
权衡:重建与正则化 (regularization)
VAE训练过程涉及最小化这两个损失项的总和。这产生了一种基本的矛盾:
- 重建损失希望编码器尽可能多地保留信息,可能会在潜在空间中广泛分散 q(z∣x) 分布,以使每个输入不同。
- KL散度希望“压缩”所有 q(z∣x) 分布,使其看起来像标准正态先验 N(0,I),这可能意味着丢失一些特定于单个输入的信息。
优化器的任务是为编码器和解码器找到一组权重 (weight),以在这两个相互竞争的目标之间取得平衡。一个成功的VAE会学习:
- 将足够的信息编码到潜在分布 q(z∣x) 中(特别是其均值 μ(x) 和方差 σ2(x)),以实现良好的重建。
- 使这些潜在分布 q(z∣x) 足够接近先验 p(z),从而使潜在空间保持结构化,并适合通过从 p(z) 中抽取 z 来生成新样本。
这种平衡使得VAE不仅能够重建数据,还能生成新的、合理的数据样本,并学习一个平滑、有意义的潜在空间,其中相似的输入被映射到相邻区域。
有时会引入一个系数 β,以调节KL散度项在总损失中的权重:
LVAE=L重建+β⋅DKL(q(z∣x)∣∣p(z))
这是 β-VAE 的公式。当 β>1 时,会更强调KL项,这可以导致更解耦的潜在表示(其中每个潜在维度对应数据中不同、可解释的变异因素),但可能以牺牲重建质量为代价。当 β=1 时,我们得到标准VAE损失。
理解这个复合损失函数 (loss function)对于掌握VAE如何学习其结构化潜在空间以及执行生成任务非常重要。通过仔细平衡重建输入的需求与组织潜在空间的需求,VAE为表示学习和数据生成提供了一个强大的框架。