尽管弗雷歇起始距离(FID)等指标提供一个分数来总结GAN的性能,但它们将两个不同方面混为一谈:生成样本的逼真度(质量或真实性)以及这些样本相对于真实数据分布的多样性(种类或覆盖范围)。低FID分数是理想的,但它没有告诉我们模型表现不佳的原因。是因为它生成了不真实的图像(低逼真度),还是未能捕获数据的全部种类(低多样性,可能出现模式崩溃)?
为了获得更具体的认识,我们可以调整准确率和召回率这些观念,这些观念传统上用于信息检索和分类,现在可以用于评估生成模型。在GAN的背景下,这些指标有助于将样本质量评估与分布覆盖评估分开。
理解GAN的准确率与召回率
设想比较生成样本的分布Pg与真实样本的分布Pr。
- 准确率衡量生成样本的逼真度。它回答了这个问题:“在生成器产生的样本中,有多少比例是目标分布中真实或可信的样本?”高准确率表示GAN生成的“看起来假”或带有伪影的样本很少。它与质量方面有关。
- 召回率衡量生成样本相对于真实数据的多样性或覆盖范围。它回答:“在真实数据集中存在的所有类型样本中,生成器实际能产生多少比例?”高召回率表示生成器捕获了真实数据分布中存在的大部分变体和模式。它与多样性方面有关。
考虑以下情况:
- 高准确率,低召回率: GAN生成了非常真实的样本(高质量),但只覆盖了真实数据种类的一小部分(例如,模式崩溃)。
- 低准确率,高召回率: GAN生成了多种样本,覆盖了真实数据的大部分模式,但其中许多样本不真实或包含伪影。
- 低准确率,低召回率: 最差的情况。GAN产生了低质量的样本,并且未能捕获真实数据的多样性。
- 高准确率,高召回率: 理想的情况。GAN产生了高质量、真实的样本,覆盖了真实数据分布的全部多样性。
计算的实用方法
计算连续、高维分布的准确率和召回率需要一种实用方法。一种常见的方法,由Kynkäänniemi等人(2019)提出,涉及分析样本在合适特征空间中的接近程度:
- 嵌入: 将大量真实样本 (Nr) 和生成样本 (Ng) 嵌入特征空间。这通常通过使用预训练网络(如Inception V3)中间层的激活来完成,类似于FID的计算方式。令这些特征向量,真实样本为 xr∈Rd,生成样本为 xg∈Rd。
- 最近邻: 对于每个真实样本 xr,在特征空间中基于欧几里得距离,找到它在所有其他真实样本中的第 k 个最近邻。令到此邻居的距离为 d(xr,NNk(xr))。同样,对于每个生成样本 xg,找到它在所有其他生成样本中的第 k 个最近邻的距离,d(xg,NNk(xg))。k 的值(例如 k=3 或 k=5)是一个超参数。这些距离定义了每个点周围的近似局部密度估计。
- 准确率计算: 对于每个生成样本 xg,判断它是否落在真实数据的“流形”内。这是通过检查是否存在至少一个真实样本 xr,使得距离 ∥xg−xr∥2≤d(xr,NNk(xr)) 来完成的。简单来说,我们检查生成样本 xg 是否比 xr 到其第 k 个最近真实邻居的距离更接近某个真实样本 xr。准确率是满足此条件的生成样本 xg 的比例。
准确率=Ng1i=1∑NgI(∃j:∥xg,i−xr,j∥2≤d(xr,j,NNk(xr,j)))
其中 I(⋅) 是指示函数。
- 召回率计算: 对于每个真实样本 xr,判断它是否被生成分布良好地表示。这是通过检查是否存在至少一个生成样本 xg,使得距离 ∥xr−xg∥2≤d(xg,NNk(xg)) 来完成的。召回率是满足此条件的真实样本 xr 的比例。
召回率=Nr1j=1∑NrI(∃i:∥xr,j−xg,i∥2≤d(xg,i,NNk(xg,i)))
准确率与召回率分数的解释
这些指标通常一起报告,经常在准确率-召回率图上进行可视化。这使得不同模型或训练检查点之间能够进行更详细的比较。
在准确率-召回率图上比较GAN模型。理想的模型位于右上角(高准确率,高召回率)。不同位置表示不同的权衡或失败模式。模型B显示高准确率但低召回率,这表明模式崩溃。模型C显示高召回率但低准确率,这表明尽管多样性良好,但样本质量不佳。模型D在两个轴上都表现不佳。模型A代表了更好的平衡。
模型主要沿着准确率轴改进时,它在生成真实样本方面会更好,即使它没有覆盖所有模式。沿着召回率轴的改进意味着它正在捕获更多数据多样性,最初可能以牺牲一些真实性为代价。目标是向右上角推进。
优点与考量
优点:
- 分离评估: 为逼真度(准确率)和多样性(召回率)提供独立分数,相比FID等单一分数指标,提供更明确的诊断。
- 模式崩溃检测: 低召回率与高准确率相结合是模式崩溃的有力指标。
- 质量评估: 低准确率表明样本真实性或伪影问题,无论多样性如何。
考量:
- 特征空间: 嵌入网络的选取会显著影响结果。使用Inception V3等标准预训练模型是为了可比性而常见的做法,但可能并非所有数据集或任务的最佳选择。
- 超参数 k: 最近邻数量 k 影响比较的局部性。原始论文建议 k=3 或 k=5,但可能需要进行敏感性分析。
- 计算成本: 在高维特征空间中为大型数据集计算k-NN距离相比FID可能计算成本较高。高效的k-NN算法(例如,使用Faiss等库)通常是必要的。
- 偏差: 与FID和IS一样,这些指标会继承预训练特征提取器中存在的任何偏差。
通过将准确率和召回率与其他指标(如FID)结合使用,您可以对GAN的性能有更全面的认识,明确需要改进的具体方面,包括生成样本的质量和多样性。