趋近智
Fréchet Inception Distance (FID) 是评估生成模型时广为采用的指标之一。FID通过比较从真实图像中提取的特征分布与从合成图像中提取的特征分布,来评估生成图像的质量和多样性。较低的FID分数表示生成图像的分布更接近真实图像的分布,说明其具有更高的保真度和多样性。
此计算依赖于使用预训练的Inception-v3网络提取的特征表示。我们将使用PyTorch和torchvision进行本次实践。
计算FID分数涉及以下主要步骤:
FID公式如下所示:
FID(x,g)=∣∣μx−μg∣∣22+迹(Σx+Σg−2(ΣxΣg)1/2)这里,∣∣μx−μg∣∣22 是均值向量之间的欧几里得距离的平方,而 迹 表示矩阵的迹。项 (ΣxΣg)1/2 代表协方差矩阵乘积的矩阵平方根。
我们来逐步实现FID计算。请确保您已安装 torch、torchvision、numpy 和 scipy。
我们需要在ImageNet上预训练的Inception-v3模型。torchvision 轻松提供了这一点。我们将对其进行微小修改,使其从中间层输出特征,而不是分类logit。FIDInceptionA 块(输出维度2048)常被使用。
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torch.nn.functional import adaptive_avg_pool2d
import numpy as np
from scipy.linalg import sqrtm
# 加载预训练的Inception v3模型
inception_model = models.inception_v3(pretrained=True, transform_input=False)
# 修改模型,使其从所需层输出特征
# 我们将使用最后一个池化层(FIDInceptionA)的输出
class InceptionV3FeatureExtractor(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = models.inception_v3(pretrained=True, aux_logits=False)
# 移除最终的全连接层
self.model.fc = torch.nn.Identity()
# 确保输入转换符合Inception V3的要求
self.transform = transforms.Compose([
transforms.Resize(299),
transforms.CenterCrop(299),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def forward(self, x):
# 应用转换
# 注意:如果输入已是张量批次,请单独应用转换
# 或确保DataLoader处理转换。
# 为简单起见,假设x是PIL图像的批次或需要转换。
# 如果x已经是张量批次[N, C, H, W],如果之前已完成,则在此处跳过转换。
# 处理可能的输入类型(PIL图像与张量)
if not isinstance(x, torch.Tensor):
# 假设x是PIL图像的列表/批次
x = torch.stack([self.transform(img) for img in x])
elif x.shape[2] != 299 or x.shape[3] != 299:
# 如果是张量但尺寸不正确,应用转换(根据需要调整)
# 这部分可能需要根据输入管道进行仔细处理
x = torch.stack([transforms.ToPILImage()(img) for img in x]) # 转换为PIL
x = torch.stack([self.transform(img) for img in x]) # 应用完整转换
# 将输入通过Inception模型
# 确保模型处于评估模式
self.model.eval()
with torch.no_grad():
features = self.model(x)
# 输出可能需要根据具体的InceptionV3用法进行重塑
# 原始FID实现可能使用特定的池化层。
# 对于`self.model.fc = torch.nn.Identity()`,输出应直接是[N, 2048]。
# 如果在最终池化之前使用特征,可能需要自适应池化:
# features = adaptive_avg_pool2d(features, (1, 1))
# features = features.view(features.shape[0], -1)
return features
# 实例化特征提取器
fid_extractor = InceptionV3FeatureExtractor()
# 如果GPU可用,移动到GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fid_extractor.to(device)
fid_extractor.eval() # 设置为评估模式
print("InceptionV3模型已加载用于特征提取。")
重要提示: Inception-v3模型要求输入图像大小为 299×299,并进行特定归一化。请确保您的数据加载管道正确预处理真实图像和生成图像。示例代码包含一个基本转换,但您可能需要根据数据存储和加载方式(例如,使用PyTorch DataLoader)对其进行调整。
现在,编写一个函数来遍历您的数据集(真实和生成),使用准备好的模型提取特征,并将其收集起来。
def get_features(dataloader, model, device, max_samples=None):
features = []
count = 0
for batch in dataloader:
# 假设dataloader生成图像张量批次
# 将批次移动到合适的设备
if isinstance(batch, (list, tuple)): # 处理 (图像, 标签) 等情况
images = batch[0].to(device)
else:
images = batch.to(device) # 假设批次仅包含图像
# 如果DataLoader中未进行转换,则在此处应用
# 确保图像已为模型正确格式化(例如,大小299x299,已归一化)
# 上述InceptionV3FeatureExtractor包含一个基本转换,
# 但为了效率,通常最好在DataLoader中处理。
# 如果在模型内部使用转换,请传入PIL图像或原始张量。
batch_features = model(images).detach().cpu().numpy()
features.append(batch_features)
count += images.shape[0]
if max_samples is not None and count >= max_samples:
break
features = np.concatenate(features, axis=0)
if max_samples is not None:
features = features[:max_samples]
return features
# 示例用法(请替换为您的实际DataLoader)
# 假设real_dataloader和fake_dataloader是PyTorch DataLoader
# 它们生成批次,包含正确预处理的尺寸为(N, 3, 299, 299)的图像张量
# print("正在提取真实图像的特征...")
# real_features = get_features(real_dataloader, fid_extractor, device, max_samples=10000)
# print(f"已提取 {real_features.shape[0]} 个真实特征。")
# print("正在提取生成图像的特征...")
# fake_features = get_features(fake_dataloader, fid_extractor, device, max_samples=10000)
# print(f"已提取 {fake_features.shape[0]} 个生成特征。")
# 为演示目的,我们创建虚拟特征:
feature_dim = 2048
num_samples = 10000
print(f"正在生成虚拟特征({num_samples}个样本,维度={feature_dim})。..")
real_features = np.random.rand(num_samples, feature_dim)
# 使生成特征略有不同,以便FID不为零
fake_features = np.random.rand(num_samples, feature_dim) * 1.1 + 0.1
print("虚拟特征已生成。")
为了获得可靠的FID分数,建议从真实分布和生成分布中都使用大量的样本,通常是10,000或50,000个。
提取特征后,计算均值和协方差矩阵,然后将它们代入FID公式。
def calculate_fid(features1, features2):
# 计算均值和协方差统计量
mu1, sigma1 = np.mean(features1, axis=0), np.cov(features1, rowvar=False)
mu2, sigma2 = np.mean(features2, axis=0), np.cov(features2, rowvar=False)
# 计算均值之间的平方差
ssdiff = np.sum((mu1 - mu2)**2.0)
# 计算协方差乘积的平方根
# 如有需要,添加一个小的epsilon以提高数值稳定性
eps = 1e-6
covmean_sqrt, _ = sqrtm(sigma1.dot(sigma2), disp=False)
# 检查并修正矩阵平方根中的虚数
if np.iscomplexobj(covmean_sqrt):
# print("警告:sqrtm中遇到复数。使用实部。")
covmean_sqrt = covmean_sqrt.real
# 计算分数
fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean_sqrt)
# 处理由于数值不稳定导致的潜在负值
if fid < 0:
# print(f"警告:检测到负FID({fid})。截断为0。")
fid = 0.0
return fid
# 使用提取的(或虚拟的)特征计算FID
print("正在计算FID分数...")
fid_score = calculate_fid(real_features, fake_features)
print(f"计算得到的FID分数:{fid_score:.4f}")
# 虚拟FID示例:会根据随机数而变化,可能较大。
# 实际的FID计算可能得到5.0、10.0、50.0等值。越低越好。
scipy.linalg.sqrtm 函数计算矩阵的平方根。请注意处理因数值精度问题可能出现的复数;在这种情况下,我们取实部。由于浮点误差,特别是当分布非常接近时,也可能出现小的负FID值;这些值通常被截断为零。
本次实践提供了计算您自己的生成模型的FID分数的工具。使用FID等指标进行一致且严谨的评估,对于跟踪进度和比较不同GAN与扩散模型体系结构及训练策略的性能来说非常重要。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造