虽然对变分自编码器(VAE)目标函数进行的修改,例如β-VAE、FactorVAE和全相关性VAE(TCVAE)中的修改,提供了有助于促进解耦的有益方式,但对抗训练提供了一种独特且通常更直接的方法。与仅依赖VAE损失函数 (loss function)中的惩罚项(例如,像KL散度或全相关性这样的信息论相关项)不同,对抗方法引入了一个辅助网络,一个“对抗网络”或“判别器”。这个对抗网络经过训练,能够识别所学表示中特定形式的纠缠或不良属性。VAE的编码器反过来进行训练,以生成能够“欺骗”这个对抗网络的表示,从而使表示趋向于期望的解耦结构。
这个过程形成一种动态互动,常被视为一个极大极小博弈,编码器会适应对抗网络不断提升的能力。下面介绍这种方法如何用于促进解耦表示的生成。
解耦的对抗极大极小博弈
基本设置包括至少两个组成部分:
-
VAE编码器 (E):这个网络是VAE的一部分,将输入数据 x 映射到潜编码 z=E(x)。它的目标有两方面:
- 最小化标准VAE目标函数(即,最大化证据下界ELBO),这包括精确的数据重建和遵循先验分布 p(z)。
- 生成潜编码 z,使其能够在关于解耦属性方面欺骗辅助判别网络 (Dadv)。
-
对抗网络/判别器 (Dadv):这个网络经过训练来执行一个能显现潜编码 z 中纠缠的任务。例如,它可能试图区分VAE的聚合后验分布 q(z)=Epdata(x)[q(z∣x)] 与潜在维度统计独立的分布。
编码器和判别器迭代训练。判别器学习如何更好地完成其任务,编码器学习生成使判别器任务更难的表示。
FactorVAE的判别器:全相关性的对抗方法
你已经遇到过一个对抗训练的实例,在FactorVAE的背景下。FactorVAE旨在最小化潜编码 z 各维度间的全相关性 (TC),这是衡量这些维度相互依赖程度的一个量度。直接从样本估算TC可能存在困难。FactorVAE提出为此目的使用判别器 DTC(即我们的 Dadv)。
-
判别器 DTC 经过训练,用于区分:
- 从VAE的聚合后验分布 q(z) 中抽取的样本 z。
- 从 q(z) 的“打乱”版本中抽取的样本 z′,记作 qshuff(z)=∏jq(zj),其中每个维度 zj 都是从其在 q(z) 下的边缘分布中独立采样的。这个 qshuff(z) 表示一种阶乘分布,与 q(z) 的边缘分布相匹配。
-
此判别器的损失(例如,二元交叉熵)可能为:
LDTC=−(Ez∼q(z)[logDTC(z)]+Ez′∼qshuff(z)[log(1−DTC(z′))])
此处,DTC(z) 是 z 来自“真实”(可能纠缠的)q(z) 的概率,而 1−DTC(z′) 是 z′ 被正确识别为来自打乱的(维度独立的)分布的概率。
-
VAE编码器随后进行训练,不是以经典GAN中最大化 DTC(z) 的方式来“欺骗” DTC,而是为了最小化从 DTC 输出得出的全相关性估算值。例如,添加到VAE目标函数中的TC项可以近似表示为:
TC(z)≈Ez∼q(z)[logDTC(z)−log(1−DTC(z))]
最小化此项会使 q(z) 更接近 qshuff(z),从而减少依赖性并促进解耦。编码器的目标函数变为 LVAE+γ⋅TCestimated_by_DTC(z),其中 γ 是一个超参数 (parameter) (hyperparameter)。
以下图表展示了此设置:
一个促进解耦的对抗设置。VAE(编码器、解码器)将输入数据 x 处理为潜编码 z,然后用于重建 x^。潜编码 z 也由一个对抗网络 (Dadv) 评估。在这个类似FactorVAE的例子中,Dadv 比较聚合后验 q(z) 的样本和置换版本 zperm(代表 qshuff(z))的样本,以估算像全相关性这样的依赖关系。这个估算形成了一个对抗信号,引导编码器生成具有更少维度间依赖的潜编码,同时兼顾标准VAE目标。
更广泛的解耦对抗策略
FactorVAE方法只是利用对抗训练的一种方式。其他策略包括:
-
将聚合后验匹配到因子先验(AAE风格):
对抗自编码器(AAE)主要目标是使用判别器将聚合后验 q(z) 匹配到选定的先验 p(z)(例如,各向同性高斯分布 N(0,I))。此判别器经过训练,用于区分 q(z) 中的样本与 p(z) 中的样本。编码器反过来试图使 q(z) 分布与 p(z) 无法区分。如果 p(z) 选为因子分布(即 p(z)=∏jp(zj)),这种对抗性强制匹配会间接促使 z 的维度彼此独立,这是解耦的一个特点。这是一种替代方案,避免仅仅依赖VAE目标函数中的KL散度项 DKL(q(z∣x)∣∣p(z)) 来调整 q(z) 的形状。
-
带有因子监督的定向解耦:
如果某些潜在变化因子 (ys) 的真实标签可用(即使只是部分数据),对抗训练可以用于更具针对性的解耦。例如:
- 假设我们希望潜在维度 zk 专门表示因子 ys。
- 可以训练一个预测网络从 zk 预测 ys: P(ys∣zk)。编码器将被鼓励使 zk 对 ys 具有信息。
- 一个对抗网络 Dadv 尝试从剩余潜在维度 z∖k={zj∣j=k} 预测 ys。
- 编码器的目标将包含一个项,用于最大化对抗网络从 z∖k 预测 ys 时的误差,从而有效尝试从 z∖k 中移除关于 ys 的信息。
此设置旨在将关于 ys 的信息隔离到 zk 中。编码器总损失成为VAE损失、P(ys∣zk) 的监督损失以及与 Dadv 相关的对抗损失的加权和。
-
对抗信息掩蔽:
与上述类似,可以训练一个对抗网络从潜在维度的子集预测特定属性。编码器随后经过训练,使对抗网络无法成功,从而“掩蔽”掉这些潜在维度中的信息,并有望将其集中到别处。
对抗解耦的优势
采用对抗训练进行解耦有以下几点优势:
- 直接作用:相比于仅依赖KL散度等正则化 (regularization)项(其影响可能更分散),对抗网络能对特定解耦属性施加更直接、更明确的作用。
- 解耦定义的灵活性:对抗网络的设计(其架构、目标函数、其试图预测或区分的内容)允许编码不同的解耦操作定义。这比固定的数学正则化器更具适应性。
- 潜在更强的解耦效果:当稳定时,对抗作用有时能使表示在定量指标上展现出更好的解耦得分,因为模型正积极抵抗一个试图发现纠缠的组件。
挑战与实践考量
尽管对抗方法在解耦方面有其潜在优势,但也存在显著挑战:
- 训练不稳定性:这是对抗学习中的一个常见问题。平衡VAE(编码器/解码器)和对抗网络的训练非常重要。如果对抗网络过快变得过于熟练,编码器可能无法获得有益的学习信号。反之,弱的对抗网络会提供不足的作用。细致调整学习率、网络容量和更新计划非常必要。
- 超参数 (parameter) (hyperparameter)敏感性:引入对抗组件会增加更多的超参数需要调整,包括对抗网络的架构、对抗损失项相对于VAE目标函数的权重 (weight),以及优化器。
- 定义“合适”的对抗网络:这种方法的成功取决于设计一个能准确反映所需解耦属性的对抗任务。设计不当的对抗网络可能导致意想不到的结果,或无法促进有意义的解耦。例如,如果真实因子不是高斯分布,简单地将 q(z) 匹配到 N(0,I) 可能不足够。
- 计算成本:训练一个额外的网络(对抗网络)会增加每次迭代的计算负担。
- 模式崩塌或退化解:与其他生成对抗设置一样,存在编码器找到一些琐碎解来欺骗对抗网络的风险,这些解不对应于良好的、解耦的表示。
总而言之,对抗训练提供了一个强大而灵活的工具集,用于促进VAE中的解耦表示。通过引入一个主动寻找并惩罚纠缠的学习组件,这些方法可以提供一条更直接的途径,以实现结构化的潜在空间。然而,它们的成功应用需要细致的对抗博弈设计、精确的超参数调整以及管理训练稳定性的有力策略,这使得它们成为追求可解释和可控生成模型中的一种先进技术。