Having explored the theoretical underpinnings of various evaluation metrics, we now turn to the practical implementation of one of the most widely used metrics for assessing generative models: the Fréchet Inception Distance (FID). FID evaluates the quality and diversity of generated images by comparing the distribution of features extracted from real images with those extracted from synthetic images. A lower FID score suggests that the distribution of generated images is closer to the distribution of real images, indicating higher fidelity and diversity.
The calculation relies on feature representations extracted using a pre-trained Inception-v3 network. We'll use PyTorch and torchvision for this practical exercise.
Calculating the FID score involves these main steps: extracting Inception-v3 features for a large sample of real and generated images, computing the mean vector and covariance matrix of each feature set, and plugging those statistics into the FID formula.
The FID formula is given by:

$$\mathrm{FID}(x, g) = \lVert \mu_x - \mu_g \rVert_2^2 + \mathrm{Tr}\left(\Sigma_x + \Sigma_g - 2\left(\Sigma_x \Sigma_g\right)^{1/2}\right)$$

Here, $\lVert \mu_x - \mu_g \rVert_2^2$ is the squared Euclidean distance between the mean vectors, and $\mathrm{Tr}$ denotes the trace of a matrix. The term $(\Sigma_x \Sigma_g)^{1/2}$ represents the matrix square root of the product of the covariance matrices.
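To build intuition, consider the one-dimensional case: the covariance matrices reduce to scalar variances, and because $\mathrm{Tr}(\sigma_x^2 + \sigma_g^2 - 2\sigma_x\sigma_g) = (\sigma_x - \sigma_g)^2$, the whole expression collapses to $(\mu_x - \mu_g)^2 + (\sigma_x - \sigma_g)^2$. A minimal sketch with made-up parameters:

# Closed-form FID for two one-dimensional Gaussians (illustrative values)
mu_x, sigma_x = 0.0, 1.0  # "real" distribution statistics
mu_g, sigma_g = 0.5, 1.5  # "generated" distribution statistics
fid_1d = (mu_x - mu_g) ** 2 + (sigma_x - sigma_g) ** 2
print(fid_1d)  # 0.5: the score grows if either the means or the spreads diverge

The score is zero only when both the means and the variances match, which is exactly the behavior the full matrix version generalizes to high dimensions.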
Let's implement the FID calculation step-by-step. Ensure you have torch, torchvision, numpy, and scipy installed.
We need the Inception-v3 model pre-trained on ImageNet, which torchvision provides. We'll modify it slightly to output features from an intermediate layer instead of classification logits. The output of the final average pooling layer (dimension 2048) is the standard choice for FID.
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
from scipy.linalg import sqrtm
# Build a feature extractor around the pretrained Inception v3 model.
# We use the output of the final average pooling layer (2048 dimensions)
# by replacing the classification head with an identity mapping.
class InceptionV3FeatureExtractor(torch.nn.Module):
def __init__(self):
super().__init__()
        # Load ImageNet weights; the modern torchvision API takes `weights`
        # rather than the deprecated `pretrained` flag. In eval mode the
        # auxiliary classifier is unused, so it can stay at its default.
        self.model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
        # Replace the final fully connected layer with an identity so the model
        # outputs the 2048-dimensional pooled features instead of logits
        self.model.fc = torch.nn.Identity()
# Ensure input transformation matches Inception V3 requirements
self.transform = transforms.Compose([
transforms.Resize(299),
transforms.CenterCrop(299),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def forward(self, x):
        # Ideally the DataLoader already yields normalized 299x299 tensors and
        # the transform below is never needed; PIL inputs and wrongly sized
        # tensors are handled as a fallback.
        # Handle potential input types (PIL Images vs tensors)
        if not isinstance(x, torch.Tensor):
            # Assume x is a list of PIL Images
            x = torch.stack([self.transform(img) for img in x])
        elif x.shape[-2:] != (299, 299):
            # Tensor batch of the wrong spatial size: resize on-device instead
            # of round-tripping through PIL (adjust to your input pipeline)
            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
# Pass input through the Inception model
# Make sure the model is in eval mode
self.model.eval()
with torch.no_grad():
features = self.model(x)
        # With `fc` replaced by Identity, the model output is already the
        # [N, 2048] feature vector from the final average pooling layer,
        # so no extra pooling or reshaping is needed here.
return features
# Instantiate the feature extractor
fid_extractor = InceptionV3FeatureExtractor()
# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fid_extractor.to(device)
fid_extractor.eval() # Set to evaluation mode
print("InceptionV3 model loaded for feature extraction.")
Important: The Inception-v3 model expects input images of size 299×299 with specific normalization. Ensure your data loading pipeline correctly preprocesses both real and generated images. The example code includes a basic transform, but you might need to adapt it based on how your data is stored and loaded (e.g., using a PyTorch DataLoader).
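As a minimal sketch of handling preprocessing in the data pipeline itself, assuming your images are stored in an ImageFolder-style directory (the path below is a placeholder):

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Same preprocessing as inside the feature extractor
preprocess = transforms.Compose([
    transforms.Resize(299),
    transforms.CenterCrop(299),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# "path/to/real_images" is a placeholder; point it at your own image folder
real_dataset = datasets.ImageFolder("path/to/real_images", transform=preprocess)
real_dataloader = DataLoader(real_dataset, batch_size=64, shuffle=False, num_workers=4)

With preprocessing done here, the tensors arriving at the model are already 299×299 and normalized, so the fallback resize inside the extractor never triggers.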
Now, write a function to iterate through your datasets (real and generated), extract features using the prepared model, and collect them.
def get_features(dataloader, model, device, max_samples=None):
features = []
count = 0
for batch in dataloader:
# Assuming dataloader yields batches of image tensors
# Move batch to the appropriate device
if isinstance(batch, (list, tuple)): # Handle cases like (images, labels)
images = batch[0].to(device)
else:
images = batch.to(device) # Assuming batch is just images
        # Images should already be 299x299 and normalized; doing the transform
        # in the DataLoader (as shown earlier) is more efficient than relying
        # on the fallback resize inside the feature extractor.
batch_features = model(images).detach().cpu().numpy()
features.append(batch_features)
count += images.shape[0]
if max_samples is not None and count >= max_samples:
break
features = np.concatenate(features, axis=0)
if max_samples is not None:
features = features[:max_samples]
return features
# Example usage (replace with your actual DataLoaders)
# Assume real_dataloader and fake_dataloader are PyTorch DataLoaders
# yielding batches of correctly preprocessed image tensors of size (N, 3, 299, 299)
# print("Extracting features for real images...")
# real_features = get_features(real_dataloader, fid_extractor, device, max_samples=10000)
# print(f"Extracted {real_features.shape[0]} real features.")
# print("Extracting features for generated images...")
# fake_features = get_features(fake_dataloader, fid_extractor, device, max_samples=10000)
# print(f"Extracted {fake_features.shape[0]} fake features.")
# For demonstration, let's create dummy features:
feature_dim = 2048
num_samples = 10000
print(f"Generating dummy features ({num_samples} samples, dim={feature_dim})...")
real_features = np.random.rand(num_samples, feature_dim)
# Make fake features slightly different for a non-zero FID
fake_features = np.random.rand(num_samples, feature_dim) * 1.1 + 0.1
print("Dummy features generated.")
For reliable FID scores, it's recommended to use a substantial number of samples, typically 10,000 or 50,000, from both the real and generated distributions.
With the features extracted, calculate the means and covariance matrices, then plug them into the FID formula.
def calculate_fid(features1, features2):
# Calculate mean and covariance statistics
mu1, sigma1 = np.mean(features1, axis=0), np.cov(features1, rowvar=False)
mu2, sigma2 = np.mean(features2, axis=0), np.cov(features2, rowvar=False)
# Calculate squared difference between means
ssdiff = np.sum((mu1 - mu2)**2.0)
    # Calculate sqrt of product of covariances, adding a small epsilon to the
    # diagonals for numerical stability
    eps = 1e-6
    offset = np.eye(sigma1.shape[0]) * eps
    covmean_sqrt, _ = sqrtm((sigma1 + offset).dot(sigma2 + offset), disp=False)
# Check and correct imaginary numbers from matrix sqrt
if np.iscomplexobj(covmean_sqrt):
# print("Warning: Complex number encountered in sqrtm. Using real part.")
covmean_sqrt = covmean_sqrt.real
# Calculate score
fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean_sqrt)
# Handle potential negative values due to numerical instability
if fid < 0:
# print(f"Warning: Negative FID ({fid}) detected. Clipping to 0.")
fid = 0.0
return fid
# Calculate FID using the extracted (or dummy) features
print("Calculating FID score...")
fid_score = calculate_fid(real_features, fake_features)
print(f"Calculated FID Score: {fid_score:.4f}")
# Example dummy FID: Will vary based on random numbers, likely large.
# A real FID calculation might yield values like 5.0, 10.0, 50.0 etc. Lower is better.
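The earlier recommendation to use 10,000 or more samples is not arbitrary: the FID estimator is biased, and small sample counts inflate the score even for identical distributions. A quick illustration using the calculate_fid function above, with synthetic Gaussian features and a reduced dimension to keep sqrtm fast:

# FID between two samples of the SAME Gaussian should approach 0,
# but the estimate is inflated at small sample counts
rng = np.random.default_rng(0)
for n in (100, 1000, 5000):
    a = rng.normal(size=(n, 64))
    b = rng.normal(size=(n, 64))
    print(f"n={n}: FID={calculate_fid(a, b):.3f}")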
The scipy.linalg.sqrtm function computes the matrix square root. Note the handling of potential complex numbers that can arise from numerical precision issues; we take the real part in such cases. Adding a small epsilon to the covariance diagonals further guards against ill-conditioned matrices. Small negative FID values can also occur due to floating-point inaccuracies, especially when the distributions are very close; these are typically clipped to zero.
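A quick sanity check ties these numerical notes together: the FID of a feature set against itself should be zero, and with the diagonal epsilon added above the raw value typically comes out as a tiny negative number, which the clipping turns into an exact 0.

# Self-FID should be ~0; the epsilon can make the raw value slightly
# negative, so this also exercises the clipping branch
rng = np.random.default_rng(1)
feats = rng.normal(size=(500, 64))
print(f"Self-FID: {calculate_fid(feats, feats):.6f}")  # expect 0.000000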
This practical exercise provides the tools to compute FID scores for your own generative models. Consistent and rigorous evaluation using metrics like FID is essential for tracking progress and comparing the performance of different GAN and diffusion model architectures and training strategies.