The Fréchet Inception Distance (FID) is one of the most widely used metrics for assessing generative models. FID evaluates the quality and diversity of generated images by comparing the distribution of features extracted from real images with the distribution of features extracted from synthetic images. A lower FID score suggests that the distribution of generated images is closer to the distribution of real images, indicating higher fidelity and diversity.

The calculation relies on feature representations extracted with a pre-trained Inception-v3 network. We'll use PyTorch and torchvision for this practical exercise.

## Core Steps in FID Calculation

Calculating the FID score involves three main steps:

1. **Feature Extraction:** Process both a set of real images and a set of generated images through a pre-trained Inception-v3 network (typically up to the penultimate layer, before the final classification). This yields a high-dimensional feature vector for each image.
2. **Distribution Modeling:** Assume the extracted features for the real set ($X$) and the generated set ($G$) each follow a multivariate Gaussian distribution. Calculate the mean vector ($\mu_x$, $\mu_g$) and the covariance matrix ($\Sigma_x$, $\Sigma_g$) for each set of features.
3. **Fréchet Distance Calculation:** Compute the Fréchet distance between these two Gaussian distributions, $N(\mu_x, \Sigma_x)$ and $N(\mu_g, \Sigma_g)$.

The FID formula is given by:

$$ FID(x, g) = ||\mu_x - \mu_g||^2_2 + \text{Tr}\left(\Sigma_x + \Sigma_g - 2(\Sigma_x \Sigma_g)^{1/2}\right) $$

Here, $||\mu_x - \mu_g||^2_2$ is the squared Euclidean distance between the mean vectors, and $\text{Tr}$ denotes the trace of a matrix. The term $(\Sigma_x \Sigma_g)^{1/2}$ is the matrix square root of the product of the covariance matrices.

## Implementation using PyTorch

Let's implement the FID calculation step by step. Ensure you have `torch`, `torchvision`, `numpy`, and `scipy` installed.

### 1. Loading the Inception-v3 Model

We need the Inception-v3 model pre-trained on ImageNet, which `torchvision` provides. We'll modify it slightly to output features from an intermediate layer instead of classification logits. The standard choice is the output of the final average pooling layer, a 2048-dimensional feature vector per image.

```python
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
from scipy.linalg import sqrtm


class InceptionV3FeatureExtractor(torch.nn.Module):
    """Wraps Inception-v3 so the forward pass returns 2048-dim pooled features."""

    def __init__(self):
        super().__init__()
        # Load the pretrained Inception v3 model; the auxiliary classifier is
        # only active in training mode, so we can leave it in place.
        self.model = models.inception_v3(pretrained=True, transform_input=False)
        # Replace the final fully connected layer so the model outputs the
        # 2048-dim activations of the last pooling layer instead of logits.
        self.model.fc = torch.nn.Identity()
        # Input transformation matching Inception V3 requirements
        self.transform = transforms.Compose([
            transforms.Resize(299),
            transforms.CenterCrop(299),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def forward(self, x):
        # Handle potential input types (PIL images vs. tensors). Ideally the
        # DataLoader already applies the transform; these branches are a fallback.
        if not isinstance(x, torch.Tensor):
            # Assume x is a list/batch of PIL Images
            x = torch.stack([self.transform(img) for img in x])
        elif x.shape[-2] != 299 or x.shape[-1] != 299:
            # Tensor batch with the wrong spatial size: convert back to PIL and
            # re-apply the full transform (adjust to your own input pipeline)
            pil_images = [transforms.ToPILImage()(img.cpu()) for img in x]
            x = torch.stack([self.transform(img) for img in pil_images])
        x = x.to(next(self.model.parameters()).device)

        # Pass the batch through the Inception model in evaluation mode
        self.model.eval()
        with torch.no_grad():
            features = self.model(x)
        # With `fc` replaced by Identity, the output is already [N, 2048].
        return features


# Instantiate the feature extractor and move it to the GPU if available
fid_extractor = InceptionV3FeatureExtractor()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fid_extractor.to(device)
fid_extractor.eval()  # Set to evaluation mode

print("InceptionV3 model loaded for feature extraction.")
```

**Important:** The Inception-v3 model expects input images of size $299 \times 299$ with specific normalization. Ensure your data loading pipeline correctly preprocesses both real and generated images. The example code includes a basic transform as a fallback, but you will usually want to apply it in your data loading code (e.g., in a PyTorch DataLoader).
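If your images live on disk, one convenient option is to apply the Inception preprocessing inside the data pipeline so the extractor receives ready-made $299 \times 299$ tensors. The sketch below is one possible setup rather than required code: the `real_images/` and `generated_images/` directories are hypothetical, and `ImageFolder` expects images to sit inside at least one subdirectory.

```python
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

# Same preprocessing the extractor uses internally
inception_transform = transforms.Compose([
    transforms.Resize(299),
    transforms.CenterCrop(299),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Hypothetical layout: real_images/all/*.png and generated_images/all/*.png
real_dataset = ImageFolder("real_images", transform=inception_transform)
fake_dataset = ImageFolder("generated_images", transform=inception_transform)

real_dataloader = DataLoader(real_dataset, batch_size=64, shuffle=False, num_workers=4)
fake_dataloader = DataLoader(fake_dataset, batch_size=64, shuffle=False, num_workers=4)
```

Handled this way, the extractor's fallback conversion branch is never triggered and feature extraction stays purely tensor-based.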
### 2. Extracting Features

Now, write a function that iterates through a dataset (real or generated), extracts features with the prepared model, and collects them into a single array.

```python
def get_features(dataloader, model, device, max_samples=None):
    """Run all batches from `dataloader` through `model` and return the features."""
    features = []
    count = 0
    for batch in dataloader:
        # DataLoaders often yield (images, labels) tuples; keep only the images
        if isinstance(batch, (list, tuple)):
            images = batch[0].to(device)
        else:
            images = batch.to(device)

        # The images should already be preprocessed (299x299, normalized);
        # doing this in the DataLoader is more efficient than inside the model.
        batch_features = model(images).detach().cpu().numpy()
        features.append(batch_features)

        count += images.shape[0]
        if max_samples is not None and count >= max_samples:
            break

    features = np.concatenate(features, axis=0)
    if max_samples is not None:
        features = features[:max_samples]
    return features


# Example usage (replace with your actual DataLoaders yielding preprocessed
# image tensors of shape (N, 3, 299, 299)):
# print("Extracting features for real images...")
# real_features = get_features(real_dataloader, fid_extractor, device, max_samples=10000)
# print(f"Extracted {real_features.shape[0]} real features.")
# print("Extracting features for generated images...")
# fake_features = get_features(fake_dataloader, fid_extractor, device, max_samples=10000)
# print(f"Extracted {fake_features.shape[0]} fake features.")

# For demonstration, create dummy features instead:
feature_dim = 2048
num_samples = 10000
print(f"Generating dummy features ({num_samples} samples, dim={feature_dim})...")
real_features = np.random.rand(num_samples, feature_dim)
# Make the fake features slightly different so the FID is non-zero
fake_features = np.random.rand(num_samples, feature_dim) * 1.1 + 0.1
print("Dummy features generated.")
```

For reliable FID scores, use a substantial number of samples, typically 10,000 or 50,000, from both the real and generated distributions.
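Feature extraction is by far the slowest part of the pipeline, and the real dataset's features do not change between evaluation runs, so it can pay off to cache them. A minimal sketch, where `real_features.npz` is just an example filename:

```python
# Save the real-image features once...
np.savez_compressed("real_features.npz", features=real_features)

# ...and reload them in later evaluation runs instead of re-running Inception-v3
real_features = np.load("real_features.npz")["features"]
```

Some implementations go a step further and store only the mean and covariance of the real features, which is even more compact.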
### 3. Calculating Statistics and FID

With the features extracted, calculate the means and covariance matrices, then plug them into the FID formula.

```python
def calculate_fid(features1, features2):
    # Mean and covariance statistics for each feature set
    mu1, sigma1 = np.mean(features1, axis=0), np.cov(features1, rowvar=False)
    mu2, sigma2 = np.mean(features2, axis=0), np.cov(features2, rowvar=False)

    # Squared difference between the means
    ssdiff = np.sum((mu1 - mu2) ** 2.0)

    # Matrix square root of the covariance product; if the result is not finite,
    # retry with a small epsilon added to the diagonals for numerical stability
    covmean_sqrt, _ = sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean_sqrt).all():
        eps = 1e-6
        offset = np.eye(sigma1.shape[0]) * eps
        covmean_sqrt, _ = sqrtm((sigma1 + offset).dot(sigma2 + offset), disp=False)

    # Numerical precision can introduce a small imaginary component; drop it
    if np.iscomplexobj(covmean_sqrt):
        covmean_sqrt = covmean_sqrt.real

    # FID score
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean_sqrt)

    # Floating-point error can produce a tiny negative value; clip it to zero
    if fid < 0:
        fid = 0.0

    return fid


# Calculate FID using the extracted (or dummy) features
print("Calculating FID score...")
fid_score = calculate_fid(real_features, fake_features)
print(f"Calculated FID Score: {fid_score:.4f}")
# The dummy FID will vary with the random numbers; a real FID calculation
# might yield values like 5.0, 10.0, or 50.0. Lower is better.
```

The `scipy.linalg.sqrtm` function computes the matrix square root. Note the handling of complex numbers that can arise from numerical precision issues; we take the real part in such cases. Small negative FID values can also occur due to floating-point inaccuracies, especially when the two distributions are very close; these are typically clipped to zero.
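A quick sanity check on `calculate_fid` is worthwhile before trusting it on real features. The example below is purely illustrative, using low-dimensional synthetic features so `sqrtm` runs instantly: identical feature sets should score approximately zero, and a pure mean shift should be recovered as the squared distance between the means.

```python
rng = np.random.default_rng(0)
feats_a = rng.normal(size=(5000, 16))  # 16-dim stand-in for Inception features
feats_b = feats_a + 0.5                # same samples, every dimension shifted by 0.5

print(calculate_fid(feats_a, feats_a.copy()))  # ~0.0: identical distributions
print(calculate_fid(feats_a, feats_b))         # ~16 * 0.5**2 = 4.0: pure mean shift
```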
## Interpreting the FID Score

- **Lower is Better:** A perfect score is 0, meaning the distributions of real and generated features are identical. Lower scores indicate closer alignment between the two distributions.
- **Relative Comparison:** FID is most useful for comparing different generative models, or different training stages of the same model, on the same real dataset. Absolute FID values depend heavily on the dataset and on implementation details.
- **Sensitivity:** FID is sensitive to the number of samples used. Always use the same number of real and generated samples for a fair comparison, and report this number alongside the score (e.g., FID-10k, FID-50k).
- **Limitations:** FID doesn't capture all aspects of image quality. It primarily measures distributional similarity in Inception feature space, so it can be fooled by models that overfit the training set or that produce artifacts which barely change the feature statistics. Always complement FID with qualitative visual inspection and, where possible, other metrics.

This practical exercise provides the tools to compute FID scores for your own generative models. Consistent and rigorous evaluation with metrics like FID is essential for tracking progress and comparing the performance of different GAN and diffusion model architectures and training strategies.
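Because absolute FID values depend on implementation details (image resizing, which Inception weights are used, and so on), it is also worth cross-checking a hand-rolled implementation against an established one before reporting numbers. A minimal sketch using `torchmetrics`, assuming it is installed; the API shown follows recent `torchmetrics` versions, where `update` expects `uint8` image tensors by default:

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

fid_metric = FrechetInceptionDistance(feature=2048)  # 2048-dim pooled features, as above

# Placeholder batches: replace with your real and generated images,
# shaped (N, 3, H, W) as uint8 tensors in [0, 255]
real_uint8 = torch.randint(0, 256, (64, 3, 299, 299), dtype=torch.uint8)
fake_uint8 = torch.randint(0, 256, (64, 3, 299, 299), dtype=torch.uint8)

fid_metric.update(real_uint8, real=True)
fid_metric.update(fake_uint8, real=False)
print(f"torchmetrics FID: {fid_metric.compute().item():.4f}")
```

If the two implementations disagree substantially on the same images, inspect the preprocessing first, since resizing and normalization differences are the most common cause.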