变分自编码器使用证据下界(ELBO)作为训练目标函数。这个目标函数包含KL散度项,它作为潜在空间的正则化 (regularization)项。重构损失项是训练中的一个主要组成部分。此项衡量了变分自编码器在将原始输入数据编码到潜在空间并解码回来后,重构原始输入数据的程度。
回顾ELBO公式:
L ELBO ( θ , ϕ ; x ) = E q ϕ ( z ∣ x ) [ log p θ ( x ∣ z ) ] − D K L ( q ϕ ( z ∣ x ) ∣ ∣ p ( z ) ) \mathcal{L}_{\text{ELBO}}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_{KL}(q_\phi(z|x) || p(z)) L ELBO ( θ , ϕ ; x ) = E q ϕ ( z ∣ x ) [ log p θ ( x ∣ z )] − D K L ( q ϕ ( z ∣ x ) ∣∣ p ( z ))
第一个项,E q ϕ ( z ∣ x ) [ log p θ ( x ∣ z ) ] \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] E q ϕ ( z ∣ x ) [ log p θ ( x ∣ z )] ,是重构项 。它表示给定潜在变量 z z z 时数据 x x x 的期望对数似然,其中 z z z 从编码器定义的近似后验分布 q ϕ ( z ∣ x ) q_\phi(z|x) q ϕ ( z ∣ x ) 中采样得到。最大化此项促使由 θ \theta θ 参数 (parameter)化的解码器网络,学习从潜在空间返回到数据空间的映射,从而准确地再现输入数据 x x x 。
对数似然与损失函数 (loss function)的关联
重构损失的具体形式直接取决于我们对由解码器建模的分布 p θ ( x ∣ z ) p_\theta(x|z) p θ ( x ∣ z ) 所做的假设。我们来考察两种最常见的情况:
二进制数据: 如果输入数据 x x x 包含二进制值(例如,黑白MNIST图像中的像素值,通常被视为 { 0 , 1 } \{0, 1\} { 0 , 1 } 中的值),我们通常将解码器的输出分布建模为独立伯努利分布的乘积。对于输入向量 (vector) x x x 的每个维度 i i i ,解码器输出一个概率 x ^ i \hat{x}_i x ^ i (该维度伯努利分布的参数 (parameter))。单个数据点 x x x 的对数似然为:
log p θ ( x ∣ z ) = ∑ i [ x i log x ^ i + ( 1 − x i ) log ( 1 − x ^ i ) ] \log p_\theta(x|z) = \sum_i [x_i \log \hat{x}_i + (1 - x_i) \log(1 - \hat{x}_i)] log p θ ( x ∣ z ) = i ∑ [ x i log x ^ i + ( 1 − x i ) log ( 1 − x ^ i )]
其中 x ^ = decoder θ ( z ) \hat{x} = \text{decoder}_\theta(z) x ^ = decoder θ ( z ) 。注意,这正是分类中常用的**二元交叉熵(BCE)**损失函数的负值,并对所有输入维度求和。因此,最大化此对数似然项等同于最小化原始输入 x x x 和重构输出 x ^ \hat{x} x ^ 之间的BCE损失。
连续数据: 如果输入数据 x x x 包含实数值(例如,归一化 (normalization)到 [ 0 , 1 ] [0, 1] [ 0 , 1 ] 或 R \mathbb{R} R 的像素强度),一个常见选择是将解码器的输出分布建模为各向同性高斯分布,其均值 μ = x ^ = decoder θ ( z ) \mu = \hat{x} = \text{decoder}_\theta(z) μ = x ^ = decoder θ ( z ) ,且方差 σ 2 \sigma^2 σ 2 固定。对数似然变为:
log p θ ( x ∣ z ) = ∑ i log ( 1 2 π σ 2 exp ( − ( x i − x ^ i ) 2 2 σ 2 ) ) \log p_\theta(x|z) = \sum_i \log \left( \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x_i - \hat{x}_i)^2}{2\sigma^2}\right) \right) log p θ ( x ∣ z ) = i ∑ log ( 2 π σ 2 1 exp ( − 2 σ 2 ( x i − x ^ i ) 2 ) )
简化此表达式得到:
log p θ ( x ∣ z ) = − 1 2 σ 2 ∑ i ( x i − x ^ i ) 2 − D 2 log ( 2 π σ 2 ) \log p_\theta(x|z) = -\frac{1}{2\sigma^2} \sum_i (x_i - \hat{x}_i)^2 - \frac{D}{2} \log(2\pi\sigma^2) log p θ ( x ∣ z ) = − 2 σ 2 1 i ∑ ( x i − x ^ i ) 2 − 2 D log ( 2 π σ 2 )
其中 D D D 是 x x x 的维度。当最大化ELBO(或最小化其负值)时,项 ∑ i ( x i − x ^ i ) 2 \sum_i (x_i - \hat{x}_i)^2 ∑ i ( x i − x ^ i ) 2 是与重构 x ^ \hat{x} x ^ 相关的核心部分。这正是输入 x x x 和重构 x ^ \hat{x} x ^ 之间的均方误差(MSE) 。缩放因子 1 / ( 2 σ 2 ) 1/(2\sigma^2) 1/ ( 2 σ 2 ) 和常数项通常可以被学习率吸收,或者在假定 σ \sigma σ 是常数(例如 σ = 1 \sigma=1 σ = 1 )的情况下被简单地忽略,因为它们不影响相对于网络参数 θ \theta θ 的最优位置。因此,在这些假设下最大化高斯对数似然对应于最小化MSE损失。
重构与正则化 (regularization)的权衡
重构项促使变分自编码器学习潜在表示 z z z ,从而可以忠实地恢复原始数据 x x x 。它确保解码器生成的输出接近输入。然而,如果这是唯一的项,变分自编码器可能只会学习一个恒等函数(如果潜在维度允许),或者以不适合生成的方式使潜在空间退化。
这就是与 D K L ( q ϕ ( z ∣ x ) ∣ ∣ p ( z ) ) D_{KL}(q_\phi(z|x) || p(z)) D K L ( q ϕ ( z ∣ x ) ∣∣ p ( z )) 项的平衡变得重要的地方。KL散度项鼓励近似后验分布 q ϕ ( z ∣ x ) q_\phi(z|x) q ϕ ( z ∣ x ) 保持接近先验分布 p ( z ) p(z) p ( z ) (通常是标准高斯分布 N ( 0 , I ) \mathcal{N}(0, I) N ( 0 , I ) )。这种正则化组织了潜在空间,使其更平滑,更适合采样新的点 z ∼ p ( z ) z \sim p(z) z ∼ p ( z ) 并将它们解码成合理的新数据样本 x ^ = decoder θ ( z ) \hat{x} = \text{decoder}_\theta(z) x ^ = decoder θ ( z ) 。
训练变分自编码器需要找到一个平衡点:
过分侧重重构(例如,重度加权重 (weight)构项或KL项较弱)会导致极好的重构效果,但可能产生无序的潜在空间,不利于生成。
过分侧重KL散度会强制编码分布 q ϕ ( z ∣ x ) q_\phi(z|x) q ϕ ( z ∣ x ) 趋向先验分布,可能牺牲重构的准确性,因为潜在编码会丢失每个输入 x x x 的特定信息。
这种平衡通常由模型架构和优化器隐式控制,或通过如 β \beta β -VAE 等技术显式控制,其中引入系数 β \beta β 来缩放KL项:L = E [ log p θ ( x ∣ z ) ] − β D K L ( q ϕ ( z ∣ x ) ∣ ∣ p ( z ) ) \mathcal{L} = \mathbb{E}[\log p_\theta(x|z)] - \beta D_{KL}(q_\phi(z|x) || p(z)) L = E [ log p θ ( x ∣ z )] − β D K L ( q ϕ ( z ∣ x ) ∣∣ p ( z )) 。
实际应用中的实现
在使用TensorFlow或PyTorch等框架实现变分自编码器时,重构项直接转化为计算输入数据批次与解码器生成的相应重构数据批次之间的BCE损失或MSE损失。这个损失值随后被加到(或减去,取决于你是最大化ELBO还是最小化负ELBO)批次的计算KL散度上,构成用于反向传播 (backpropagation)的最终损失值。
例如,在PyTorch中,你可以这样计算:
# 假设 decoder_output 和 input_data 是张量批次
# 对于二进制数据(例如,MNIST)
reconstruction_loss = F.binary_cross_entropy(decoder_output, input_data, reduction='sum') / input_data.shape[0]
# 对于连续数据(例如,归一化图像)
reconstruction_loss = F.mse_loss(decoder_output, input_data, reduction='sum') / input_data.shape[0]
# 变分自编码器总损失(负ELBO)
total_loss = reconstruction_loss + kl_divergence
(注意:具体实现可能会对维度和批次元素进行求和或平均;保持一致性非常重要。)
理解重构损失项的作用和表现是有效训练变分自编码器的根本。它代表了目标函数的数据准确性方面,确保模型学会生成与输入数据分布相似的输出,同时与KL散度协同作用,组织潜在空间以完成生成任务。