增强变分自编码器(VAEs)的推断能力有多种方法。近似后验分布 qϕ(z∣x) 的质量对 VAE 性能有很大作用。虽然摊销推断提供了效率,但像重要性加权自编码器(IWAEs)这样的方法提供了一种实现更紧密证据下界(ELBO)和可能更准确的后验近似的方式。一个 IWAE 的实现及其作用分析将被呈现。其他进阶推断策略的考虑因素也将被讨论。
理解 IWAE 目标函数
请回顾,标准 VAE 最大化 ELBO:
LELBO(x)=Eqϕ(z∣x)[logpθ(x∣z)]−DKL(qϕ(z∣x)∣∣p(z))
Burda 等人(2015)提出的 IWAE,通过为每个数据点 x 使用来自近似后验分布 qϕ(z∣x) 的多个样本,为对数边缘似然 logpθ(x) 提供了一个更紧密的下界。IWAE 目标函数,通常记作 LK,为:
LK(x)=Ez1,...,zK∼qϕ(z∣x)[log(K1k=1∑Kqϕ(zk∣x)pθ(x,zk))]
此处 zk 是从 qϕ(z∣x) 中抽取的 K 个独立样本。这可被重写为更便于实现的形式:
LK(x)=Ez1,...,zK∼qϕ(z∣x)[log(K1k=1∑Kexp(logpθ(x∣zk)+logp(zk)−logqϕ(zk∣x)))]
请注意,当 K=1 时,L1 会恢复为标准 ELBO(这与实际中期望的处理方式存在细微的差别,但对于优化来说,它们非常相似)。随着 K→∞,LK→logpθ(x)。
实现 IWAE
我们来概述将标准 VAE 实现(你可能在第2章的实践练习中已有此实现)修改为 IWAE 的步骤。我们假设你有一个编码器,它为 qϕ(z∣x) 输出参数 (parameter)(例如,均值 μϕ(x) 和对数方差 logσϕ2(x)),以及一个解码器,它对 pθ(x∣z) 进行建模。
1. 采样多个潜在变量
对于小批量中的每个输入 x,你不需要只从 qϕ(z∣x) 中抽取一个样本 z,而是需要抽取 K 个样本。如果你的编码器为高斯 qϕ(z∣x) 输出 μϕ(x) 和 σϕ(x),则重参数化技巧将被应用 K 次:
zk=μϕ(x)+σϕ(x)⊙ϵk,这里 ϵk∼N(0,I),对于 k=1,…,K。
在张量操作方面(例如,在 PyTorch 或 TensorFlow 中):
- 如果你的编码器输出形状为
(batch_size, latent_dim) 的 mu 和 logvar。
- 你需要将它们扩展到
(batch_size, K, latent_dim) 的形状,或者以允许每个输入 K 个样本的方式处理它们。
- 生成形状为
(batch_size, K, latent_dim) 的 eps。
- 计算形状为
(batch_size, K, latent_dim) 的 z_samples。
# 伪代码/PyTorch 风格的示例
# mu, logvar 的形状: (batch_size, latent_dim)
# K: 重要性样本的数量
# 为 K 个样本扩展 mu 和 logvar
mu_expanded = mu.unsqueeze(1).expand(-1, K, -1) # (batch_size, K, latent_dim)
logvar_expanded = logvar.unsqueeze(1).expand(-1, K, -1) # (batch_size, K, latent_dim)
std_expanded = torch.exp(0.5 * logvar_expanded)
# 采样 epsilon
epsilon = torch.randn_like(std_expanded) # (batch_size, K, latent_dim)
# 为每个输入生成 K 个潜在样本
z_samples = mu_expanded + std_expanded * epsilon # (batch_size, K, latent_dim)
2. 计算每个样本的对数权重 (weight)
对于每个样本 zk,我们需要计算其未归一化 (normalization)的对数重要性权重 wk′:
logwk′=logpθ(x∣zk)+logp(zk)−logqϕ(zk∣x)
- logpθ(x∣zk): 重构对数似然。这涉及到将每个 zk 通过解码器,以获取 pθ(x∣zk) 的参数,然后计算 x 的对数概率。如果 x 为解码器进行了形状调整,请确保它与 K 个样本对齐 (alignment)。例如,如果 x 的形状为
(batch_size, data_dim),则在计算每个样本的重构损失时,可能需要将其扩展为 (batch_size, K, data_dim) 以匹配 z_samples。
- logp(zk): 先验对数概率。通常,p(z)=N(0,I),因此这很容易计算。
- logqϕ(zk∣x): 在近似后验 qϕ(z∣x) 下的对数概率。
这些分量将得到一个形状为 (batch_size, K) 的对数权重张量,例如 log_w_prime。
3. 使用 Log-Sum-Exp 进行平均
IWAE 目标函数涉及 log(K1∑kexp(logwk′))。直接对指数项求和可能导致数值下溢或上溢。log-sum-exp (LSE) 技巧在这里非常重要:
log(k=1∑Kexp(ak))=α+log(k=1∑Kexp(ak−α))
这里 α=maxkak。
单个数据点 x 的 IWAE 损失为:
LK(x)=LSE(logw1′,…,logwK′)−logK
小批量的最终损失是批次中所有 x 的 LK(x) 的平均值。
# 伪代码/PyTorch 风格的 IWAE 损失项示例
# log_p_x_given_z: (batch_size, K),每个样本的重构对数似然
# log_p_z: (batch_size, K),每个样本的先验对数概率
# log_q_z_given_x: (batch_size, K),每个样本的近似后验对数概率
log_w_prime = log_p_x_given_z + log_p_z - log_q_z_given_x # (batch_size, K)
# 为了数值稳定性进行 Log-sum-exp
log_sum_exp_w = torch.logsumexp(log_w_prime, dim=1) # (batch_size,)
# 每个数据点的 IWAE 目标函数
iwae_elbo_per_sample = log_sum_exp_w - torch.log(torch.tensor(K, dtype=torch.float32))
# 在批次上求平均
batch_iwae_elbo = torch.mean(iwae_elbo_per_sample)
# 需要最小化的损失是 -batch_iwae_elbo
loss = -batch_iwae_elbo
关于维度: 仔细管理张量维度。当你将 zk(形状为 (batch_size, K, latent_dim))传递给解码器时,它可能会将其处理为 (batch_size * K, latent_dim)。然后,输出的重构结果将是 (batch_size * K, data_dim)。你需要将 x 的形状调整以匹配这一点,从而计算 logpθ(x∣zk),然后将得到的对数似然调整回 (batch_size, K)。
实验与分析
一旦你实现了 IWAE,是时候进行实验了:
-
改变 K 值:使用不同的 K 值(例如,K=1,5,10,50)训练 IWAE 模型。
- 标准 VAE 作为基准:K=1 的情况有效地模拟了你的标准 VAE(尽管如前所述,估计器在理论上存在细微的差异,但实践中它常被用作 VAE 的基准)。
- 监控边界:在验证集上绘制报告的 LK 值。你会发现 LK 通常随 K 增加,这表明边界变得更紧密。
- 重构质量:直观检查重构结果。随着 K 增加,它们是否更清晰?如果合适,用 MSE 进行量化 (quantization),但请注意 IWAE 优化的是 logp(x) 的下界,而不是直接的重构误差。
- 样本质量:从先验 p(z) 中生成样本并进行解码。生成样本的质量是否随 K 的增加而提高?
- 训练时间:注意随着 K 增加,每个 epoch 的训练时间也会增加。这是一个直接的权衡。
上图说明了一个典型趋势:随着 K 的增加,IWAE 下界(LK)得到提升(变得不那么负),但每个 epoch 的计算成本也随之增加。
-
活跃单元:如果你对解耦或表示学习(在第5章中会讲到)也感兴趣,请考察使用 IWAE(当 K>1 时)是否会影响潜在空间中“活跃单元”的数量,与标准 VAE 相比。有时,更紧密的边界可以阻止过早的 KL 消失。
-
后验坍缩:对于容易出现后验坍缩(即 qϕ(z∣x) 变得与 p(z) 非常相似,导致潜在变量不提供信息)的模型,使用 K>1 的 IWAE 是否有助于缓解这个问题?IWAE 提供的更准确的梯度可能提供更好的优化路径。
其他进阶推断方法的考虑
IWAE 侧重于通过多样本改进边界,而其他方法则修改 qϕ(z∣x) 的结构或推断过程本身:
-
结构化变分推断:
- 归一化 (normalization)流:在编码器中实现带有归一化流的 VAE 涉及将简单的 Gaussian 样本转换为更复杂的分布。你会在从 N(μϕ(x),σϕ2(x)) 进行初始采样后插入流层(例如,平面流、径向流,或更进阶的自回归 (autoregressive)流如 MAF 或 IAF)。主要挑战是计算这些变换的雅可比行列式的对数,这需要加到 logqϕ(z0∣x) 中以获得 logqϕ(zK∣x)(这里 z0 是基础样本,zK 是变换后的样本)。
- 自回归后验:不同于对角协方差 Gaussian,qϕ(z∣x)=∏iqϕ(zi∣z<i,x)。这要求编码器使用自回归网络(例如 RNN 或掩码自编码器)。
-
辅助变量(例如,半摊销 VI,分层 VAEs):
- 引入辅助变量 u 来扩展潜在空间,例如 qϕ(z,u∣x)=qϕ(u∣x)qϕ(z∣x,u)。这通常会产生表达能力更强的后验。实现上涉及为这些辅助变量设计推断网络,并修改 ELBO 以考虑它们。
- 对于半摊销 VI,你可能需要从一个摊销提议开始,执行几步优化(例如 SGD)来精炼每个数据点的 z。这计算量更大,但可以得到非常准确的后验。
-
对抗变分贝叶斯 (AVB):
- AVB 用一个对抗鉴别器替换了 ELBO 中的 KL 散度项,该鉴别器试图区分来自 qϕ(z∣x) 的样本和来自真实先验 p(z) 的样本。编码器随后被训练来欺骗这个鉴别器。这需要在你的 VAE 框架内设置一个 GAN 类似的训练循环。
实现这些进阶方法通常涉及对概率建模的更细致研究和仔细的网络架构设计。共同点是转向捕捉更复杂的后验结构,或获取模型证据的更好估计。
后续展望
这项关于 IWAE 的实践练习应能为你理解如何改进 VAE 推断提供扎实的基础。通过使得 ELBO 更紧密,IWAE 可以带来更好的生成模型和更忠实的潜在表示。增加的计算成本是一种权衡,但通常,一个适中的 K 值(例如 5-10)可以提供一个良好的平衡。
随着你的学习深入,请思考 IWAE 的原理(通过多个样本获得更好估计)或其他进阶方法的结构改进,如何可以结合或适用于你的特定 VAE 应用。批判性评估和改进推断机制的能力是进阶 VAE 开发的一个标志。