While metrics like Fréchet Inception Distance (FID) provide a single score summarizing GAN performance, they conflate two distinct aspects: the fidelity (quality or realism) of generated samples and the diversity (variety or coverage) of those samples relative to the real data distribution. A low FID score is desirable, but it doesn't tell us why a model might be performing poorly. Is it generating unrealistic images (low fidelity), or is it failing to capture the full variety of the data (low diversity, possibly mode collapse)?
To gain more specific insights, we can adapt the concepts of precision and recall, traditionally used in information retrieval and classification, to the evaluation of generative models. In the context of GANs, these metrics help disentangle the assessment of sample quality from the assessment of distributional coverage.
Understanding Precision and Recall for GANs
Imagine comparing the distribution of generated samples, $P_g$, with the distribution of real samples, $P_r$.
- Precision measures the fidelity of the generated samples. It answers the question: "Of the samples the generator produced, what fraction are realistic or plausible samples from the target distribution?" High precision indicates that the GAN generates few "fake-looking" or artifact-ridden samples. It relates to the quality aspect.
- Recall measures the diversity or coverage of the generated samples relative to the real data. It answers: "Of all the types of samples present in the real dataset, what fraction can the generator actually produce?" High recall indicates that the generator captures most of the variations and modes present in the real data distribution. It relates to the diversity aspect.
Consider these scenarios:
- High Precision, Low Recall: The GAN generates very realistic samples (high quality), but only covers a small subset of the true data variety (e.g., mode collapse).
- Low Precision, High Recall: The GAN generates a wide variety of samples covering most of the real data modes, but many of these samples are unrealistic or contain artifacts.
- Low Precision, Low Recall: The worst case. The GAN produces poor-quality samples and fails to capture the diversity of the real data.
- High Precision, High Recall: The ideal scenario. The GAN produces high-quality, realistic samples that cover the full diversity of the real data distribution.
A Practical Approach to Calculation
Calculating precision and recall for continuous, high-dimensional distributions requires a practical methodology. One common approach, proposed by Kynkäänniemi et al. (2019), involves analyzing the proximity of samples in a suitable feature space (code sketches for the embedding and the metric computation follow the list):
- Embedding: Embed a large number of real samples ($N_r$) and generated samples ($N_g$) into a feature space. This is often done using the activations from an intermediate layer of a pre-trained network, like Inception V3, similar to how FID is calculated. Let these feature vectors be $x_r \in \mathbb{R}^d$ for real samples and $x_g \in \mathbb{R}^d$ for generated samples.
- Nearest Neighbors: For each real sample $x_r$, find its $k$-th nearest neighbor among all other real samples in the feature space based on Euclidean distance. Let the distance to this neighbor be $d(x_r, \mathrm{NN}_k(x_r))$. Similarly, for each generated sample $x_g$, find the distance to its $k$-th nearest neighbor among all other generated samples, $d(x_g, \mathrm{NN}_k(x_g))$. The value of $k$ (e.g., $k=3$ or $k=5$) is a hyperparameter. These distances define approximate local density estimates around each point.
- Precision Calculation: For each generated sample $x_g$, determine whether it falls within the "manifold" of the real data. This is done by checking if there is at least one real sample $x_r$ such that $\|x_g - x_r\|_2 \le d(x_r, \mathrm{NN}_k(x_r))$. In simpler terms, we check whether the generated sample $x_g$ is closer to some real sample $x_r$ than $x_r$ is to its own $k$-th nearest real neighbor. Precision is the fraction of generated samples $x_g$ for which this condition holds.
$$\text{Precision} = \frac{1}{N_g} \sum_{i=1}^{N_g} \mathbb{I}\left(\exists\, j : \|x_{g,i} - x_{r,j}\|_2 \le d\big(x_{r,j}, \mathrm{NN}_k(x_{r,j})\big)\right)$$
where $\mathbb{I}(\cdot)$ is the indicator function.
- Recall Calculation: For each real sample $x_r$, determine whether it is well represented by the generated distribution. This is done by checking if there is at least one generated sample $x_g$ such that $\|x_r - x_g\|_2 \le d(x_g, \mathrm{NN}_k(x_g))$. Recall is the fraction of real samples $x_r$ for which this condition holds.
$$\text{Recall} = \frac{1}{N_r} \sum_{j=1}^{N_r} \mathbb{I}\left(\exists\, i : \|x_{r,j} - x_{g,i}\|_2 \le d\big(x_{g,i}, \mathrm{NN}_k(x_{g,i})\big)\right)$$
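For the embedding step, a minimal sketch using torchvision's pre-trained Inception V3 might look as follows. The `embed` helper and the random dummy batches are illustrative assumptions, not part of the original method, and other feature extractors can be substituted:

```python
import torch
from torchvision.models import inception_v3, Inception_V3_Weights

# Load a pre-trained Inception V3 and replace the classifier head with
# an identity, so the forward pass returns the 2048-d pooled features
# commonly used for FID-style comparisons.
weights = Inception_V3_Weights.DEFAULT
model = inception_v3(weights=weights).eval()
model.fc = torch.nn.Identity()

preprocess = weights.transforms()  # resize / crop / normalize preset

def embed(images):
    """Map a (N, 3, H, W) image batch to (N, 2048) feature vectors."""
    with torch.no_grad():
        return model(preprocess(images)).cpu().numpy()

# Dummy batches for illustration; in practice, pass real and generated images.
real_feats = embed(torch.rand(8, 3, 299, 299))
gen_feats = embed(torch.rand(8, 3, 299, 299))
```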
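Given such embeddings, the precision and recall estimators above reduce to $k$-NN radius computations. Here is a minimal sketch, with `knn_radii` and `manifold_coverage` as hypothetical helper names and random arrays standing in for real feature embeddings; it materializes the full pairwise distance matrix, which is fine for modest sample counts but scales poorly (see Considerations below):

```python
import numpy as np
from scipy.spatial.distance import cdist

def knn_radii(feats, k=3):
    # Distance from each point to its k-th nearest neighbor within the
    # same set. Column 0 of each sorted row is the self-distance (zero),
    # so column k is the k-th true neighbor.
    d = np.sort(cdist(feats, feats), axis=1)
    return d[:, k]

def manifold_coverage(queries, refs, k=3):
    # Fraction of `queries` that land inside the k-NN ball of at least
    # one reference point: ||q - r||_2 <= radius(r).
    radii = knn_radii(refs, k)
    d = cdist(queries, refs)
    return float(np.mean((d <= radii[None, :]).any(axis=1)))

# Random stand-ins for embedded features; in practice, use the
# Inception features from the previous sketch.
rng = np.random.default_rng(0)
real_feats = rng.normal(size=(1000, 64))
gen_feats = rng.normal(size=(1000, 64))

precision = manifold_coverage(gen_feats, real_feats, k=3)  # fidelity
recall = manifold_coverage(real_feats, gen_feats, k=3)     # diversity
print(f"precision={precision:.3f}, recall={recall:.3f}")
```

Note the symmetry: precision asks how much of the generated set the real manifold covers, while recall swaps the roles of the two sets.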
Interpreting Precision and Recall Scores
These metrics are typically reported together, often visualized on a Precision-Recall plot. This allows for a more detailed comparison between different models or training checkpoints.
Figure: Comparing GAN models on a precision-recall plot. The ideal model resides in the top-right corner (high precision, high recall); different locations indicate different trade-offs or failure modes. Model B shows high precision but low recall, suggesting mode collapse. Model C shows high recall but low precision, suggesting poor sample quality despite diversity. Model D performs poorly on both axes. Model A represents a better balance.
A model improving primarily along the precision axis is getting better at generating realistic samples, even if it doesn't cover all modes. Improvement along the recall axis means it's capturing more of the data's diversity, possibly at the expense of some realism initially. The goal is to push towards the top-right corner.
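A plot like the one described above takes only a few lines of matplotlib. The (precision, recall) pairs below are hypothetical values chosen to match the described models A-D, not measured results:

```python
import matplotlib.pyplot as plt

# Hypothetical (precision, recall) pairs matching the scenarios above.
models = {"A": (0.75, 0.70), "B": (0.90, 0.25),
          "C": (0.35, 0.85), "D": (0.30, 0.30)}

for name, (p, r) in models.items():
    plt.scatter(r, p)  # x axis: recall, y axis: precision
    plt.annotate(name, (r, p), textcoords="offset points", xytext=(6, 6))

plt.xlabel("Recall (diversity)")
plt.ylabel("Precision (fidelity)")
plt.xlim(0, 1)
plt.ylim(0, 1)
plt.title("Precision-recall comparison of GAN models")
plt.show()
```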
Advantages and Considerations
Advantages:
- Disentangled Evaluation: Provides separate scores for fidelity (precision) and diversity (recall), offering clearer diagnostics than single-score metrics like FID.
- Mode Collapse Detection: Low recall combined with high precision is a strong indicator of mode collapse.
- Quality Assessment: Low precision suggests problems with sample realism or artifacts, regardless of diversity.
Considerations:
- Feature Space: The choice of embedding network significantly impacts the results. Using standard pre-trained models like Inception V3 is common for comparability, but may not be optimal for all datasets or tasks.
- Hyperparameter $k$: The number of nearest neighbors $k$ influences the locality of the comparison. The original paper suggests $k=3$ or $k=5$, but sensitivity analysis might be needed.
- Computational Cost: Calculating $k$-NN distances for large datasets in high-dimensional feature spaces can be computationally expensive compared to FID. Efficient $k$-NN algorithms (e.g., using libraries like Faiss) are often necessary; see the sketch after this list.
- Bias: Like FID and IS, these metrics inherit any biases present in the pre-trained feature extractor.
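As a concrete illustration of the computational-cost point, the $k$-NN radius computation from the earlier sketch can be rewritten against Faiss's exact L2 index, which avoids materializing an $N \times N$ distance matrix in Python. This is a sketch assuming the `faiss-cpu` package is installed; `knn_radii_faiss` is a hypothetical helper name:

```python
import numpy as np
import faiss  # pip install faiss-cpu

def knn_radii_faiss(feats, k=3):
    # Build an exact L2 index over the feature set and query it
    # against itself.
    feats = np.ascontiguousarray(feats, dtype=np.float32)
    index = faiss.IndexFlatL2(feats.shape[1])
    index.add(feats)
    # Ask for k + 1 neighbors because each point matches itself at
    # distance 0; IndexFlatL2 returns *squared* distances, hence sqrt.
    sq_dists, _ = index.search(feats, k + 1)
    return np.sqrt(sq_dists[:, -1])
```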
By using precision and recall alongside other metrics like FID, you can develop a more comprehensive understanding of your GAN's performance, pinpointing specific areas for improvement in terms of both the quality and the diversity of the generated samples.