使GAN训练更稳定可以通过修改损失函数 (loss function)的基本距离度量(例如在WGAN中)或采用正则化 (regularization)技术(例如谱归一化 (normalization))来实现。相对论生成对抗网络 (GAN)提出了一种不同的方法来提升稳定性,它从根本上改变了判别器需要预测的内容。
在标准GAN的表达中,判别器D试图估计给定输入x为真实的绝对概率。其输出D(x)通常被解读为P(x 为真实)。生成器G则被训练来生成样本G(z),使其最大化这个概率D(G(z))。
相对论GANs认为,判别器预测给定真实数据比随机采样的虚假数据更具真实性的相对概率,可能更有效且更稳定。判别器的任务不再是输出一个绝对分数,而是变为比较性的。
设C(x)表示判别器在最终激活函数 (activation function)(例如sigmoid)之前的输出。在标准GAN (SGAN) 中,判别器损失包含形如log(σ(C(xreal)))和log(1−σ(C(xfake)))的项,其中σ是sigmoid函数。
相对平均生成对抗网络 (GAN) (RaGAN)
相对平均生成对抗网络 (RaGAN) 是一个特别有效的变体。RaGAN并非比较单个真实样本与单个虚假样本,而是将一个样本(真实或虚假)与其对立分布中样本的平均评估进行比较。
其核心思想在RaSGAN(相对平均标准GAN)的损失函数 (loss function)中得到了形式化。判别器D被训练来最大化:
LDRaSGAN=−Exreal∼Preal[log(σreal)]−Exfake∼Pfake[log(1−σfake)]
其中:
- σreal=σ(C(xreal)−Exfake∼Pfake[C(xfake)])
- σfake=σ(C(xfake)−Exreal∼Preal[C(xreal)])
这里,Exfake∼Pfake[C(xfake)]是批次中虚假样本的判别器平均输出,而Exreal∼Preal[C(xreal)]是批次中真实样本的判别器平均输出。判别器正在学习使C(xreal)大于平均C(xfake),并使C(xfake)小于平均C(xreal)。
生成器G被训练来最小化其相反的目标:
LGRaSGAN=−Exfake∼Pfake[log(σfake)]−Exreal∼Preal[log(1−σreal)]
请留意这种对称性。生成器既能从增加其生成样本相对于平均真实样本的感知真实性(σfake)中获益,也能从降低真实样本相对于平均虚假样本的感知真实性(1−σreal)中获益。这种结构为生成器提供了基于真实和虚假样本的梯度,这有助于更稳定的学习。
相对论生成对抗网络 (GAN)的优点
- 更高的稳定性: 通过将真实和虚假批次之间的比较直接纳入损失函数 (loss function),RaGAN与标准GAN目标相比,通常能带来更稳定的训练动态,减少如模式崩溃等问题。
- 更快的收敛: 实验结果表明,RaGAN比标准GANs甚至有时比WGAN-GP能更快收敛。
- 更高的样本质量: 这种相对比较能更有效地引导生成器,可能带来更高视觉逼真度的生成样本。
实现方面的考量
实现RaGAN涉及修改损失计算:
- 计算真实批次的判别器输出C(xreal)和虚假批次的判别器输出C(xfake)。
- 计算这些输出在其各自批次中的平均值:真实平均值=平均(C(xreal)) 和 虚假平均值=平均(C(xfake))。
- 计算相对论差异:C(xreal)−虚假平均值 和 C(xfake)−真实平均值。
- 应用sigmoid函数,并根据上述RaSGAN公式,为判别器和生成器的更新计算二元交叉熵损失。
相对论生成对抗网络 (GAN),特别是RaGAN,为GAN训练提供了一种不同的方法。通过将判别器的任务从绝对真实性评估转向相对真实性评估,它们提供了一种实用方式来达到更稳定有效的训练,为构建高级生成模型增添了又一项有价值的技术。虽然WGAN-GP或谱归一化 (normalization)等技术通过距离度量或正则化 (regularization)来处理稳定性问题,但RaGAN改变了判别器与生成器对抗游戏的基本目标本身。