Okay, we've discussed the importance of generalization and the problems of underfitting and overfitting. But how do we actually see if our model is suffering from these issues during training? Simply looking at the final performance metric on a test set isn't enough; it doesn't tell us why the model performs the way it does or how it behaved during the learning process. This is where learning curves come in.
Learning curves are visual tools that plot a model's performance metric, typically the loss or accuracy, on both the training set and a separate validation set over time (usually epochs or iterations) or against the amount of training data used. By observing the trends and the gap between these two curves, we gain valuable insights into the model's learning dynamics and diagnose potential problems like high bias (underfitting) or high variance (overfitting).
How to Plot Learning Curves
The process generally involves these steps:
- Split Data: Divide your dataset into training, validation, and test sets. The learning curves are plotted using the training and validation sets. The test set is held out for final, unbiased evaluation.
- Train Iteratively: Train your model epoch by epoch.
- Evaluate Performance: At the end of each epoch (or sometimes more frequently), calculate the chosen performance metric (e.g., Mean Squared Error for regression, Cross-Entropy Loss or Accuracy for classification) on:
  - The entire training set (or the mini-batch used in that iteration, though evaluating on the full training set gives a less noisy curve).
  - The entire validation set.
- Plot: Create a plot where the x-axis represents the epoch number, and the y-axis represents the performance metric. Plot two lines: one for the training performance and one for the validation performance.
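The steps above can be sketched in a few lines. This is a minimal, illustrative example, not tied to any particular framework: it trains a linear model with gradient descent on synthetic data, records the train and validation MSE each epoch, and (commented out) plots the two curves. All names and the synthetic data are assumptions for illustration.

```python
import numpy as np

# Synthetic regression data (illustrative only)
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
true_w = np.array([1.5, -2.0, 0.5])
y = X @ true_w + rng.normal(scale=0.1, size=200)

# Step 1: split into training and validation sets (test set held out elsewhere)
X_train, y_train = X[:150], y[:150]
X_val, y_val = X[150:], y[150:]

w = np.zeros(3)
lr = 0.05
train_losses, val_losses = [], []

# Steps 2-3: train epoch by epoch, evaluating on both sets at each epoch
for epoch in range(100):
    grad = 2 * X_train.T @ (X_train @ w - y_train) / len(y_train)
    w -= lr * grad
    train_losses.append(np.mean((X_train @ w - y_train) ** 2))
    val_losses.append(np.mean((X_val @ w - y_val) ** 2))

# Step 4: plot both curves against the epoch number
# import matplotlib.pyplot as plt
# plt.plot(train_losses, label="train")
# plt.plot(val_losses, label="validation")
# plt.xlabel("epoch"); plt.ylabel("MSE"); plt.legend(); plt.show()
```

In a deep learning framework the structure is the same: the inner gradient step is replaced by a pass over mini-batches, and the two `append` calls record the epoch-level metrics.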
Interpreting the Curves
The shape and relationship between the training and validation curves reveal a lot about your model's fit. Let's look at common patterns:
1. Underfitting (High Bias)
If your model is too simple to capture the underlying patterns in the data, it will perform poorly on both the training and validation sets.
- Training Curve: The training loss will be high and might plateau relatively quickly, indicating the model isn't learning much even from the data it sees.
- Validation Curve: The validation loss will also be high, often very close to the training loss.
- Diagnosis: Both curves converge at a high error level. This suggests the model lacks the capacity (e.g., too few layers or neurons) to learn the task effectively. Adding more training data is unlikely to help significantly because the model fundamentally cannot represent the target function.
Both training and validation loss remain high, indicating the model isn't complex enough to capture the data's patterns.
2. Overfitting (High Variance)
When a model learns the training data too well, including its noise and specific quirks, it fails to generalize to new data.
- Training Curve: The training loss decreases steadily and reaches a very low value. The model fits the training data almost perfectly.
- Validation Curve: The validation loss initially decreases but then starts to level off or even increase after a certain point.
- Diagnosis: A significant gap emerges between the low training loss and the higher (and potentially increasing) validation loss. This indicates the model is memorizing the training examples instead of learning generalizable patterns.
Training loss becomes very low, while validation loss stagnates or increases, showing a large generalization gap.
3. Good Fit
The ideal scenario is a model that learns the relevant patterns without memorizing noise.
- Training Curve: The training loss decreases steadily and converges to a low value.
- Validation Curve: The validation loss also decreases steadily and converges to a low value, staying close to the training loss.
- Diagnosis: Both curves converge at a low error level with only a small gap between them. This indicates good generalization.
Both training and validation loss decrease and converge to low values with a small gap between them.
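The three patterns can even be checked programmatically. The sketch below is a rough heuristic, not a standard algorithm: the `gap_ratio` and `high_loss` thresholds are illustrative assumptions and would need tuning for a real problem, where visual inspection of the curves remains the primary tool.

```python
def diagnose(train_losses, val_losses, gap_ratio=0.5, high_loss=1.0):
    """Rough heuristic mapping final losses to the three curve patterns.
    Thresholds are illustrative and problem-dependent."""
    final_train, final_val = train_losses[-1], val_losses[-1]
    # Both curves converge at a high error level -> underfitting
    if final_train > high_loss and final_val > high_loss:
        return "underfitting"
    # Large generalization gap between the two curves -> overfitting
    if final_val - final_train > gap_ratio * max(final_val, 1e-12):
        return "overfitting"
    # Low losses with a small gap -> good fit
    return "good fit"
```

For example, `diagnose([0.02], [0.9])` flags the large gap as overfitting, while `diagnose([1.5], [1.6])` flags the jointly high losses as underfitting.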
Using Learning Curves for Action
Diagnosing the problem is the first step. Learning curves guide your next actions:
- If Underfitting: Try increasing model complexity (more layers, more neurons), training for longer (if the curves haven't plateaued yet), changing the model architecture, or engineering better features.
- If Overfitting: Try getting more training data, applying regularization techniques (like L1/L2, Dropout, Batch Normalization - which we'll cover soon!), reducing model complexity, or using data augmentation. Early stopping (stopping training when validation performance starts degrading) is also a common strategy identified using learning curves.
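Early stopping in particular falls straight out of the validation curve. A minimal sketch of the idea, assuming a `patience` parameter as used by common framework callbacks (e.g. Keras's `EarlyStopping`): keep the epoch with the best validation loss, and stop once that many epochs pass without improvement.

```python
def best_stopping_epoch(val_losses, patience=5):
    """Return the epoch with the best validation loss, scanning until
    `patience` consecutive epochs show no improvement. A minimal sketch
    of the early-stopping idea; real callbacks also restore the model
    weights saved at that epoch."""
    best_epoch, best_loss, waited = 0, float("inf"), 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch, waited = loss, epoch, 0
        else:
            waited += 1
            if waited >= patience:
                break  # validation performance has degraded for too long
    return best_epoch
```

Applied to a validation curve that bottoms out and then climbs, this returns the epoch at the bottom of the U, which is exactly where the overfitting diagnosis above says training should end.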
Learning curves are an indispensable part of the deep learning practitioner's toolkit. They provide a window into the training process, helping you understand how your model is learning and guiding you towards building models that generalize effectively to unseen data. Keep them handy whenever you train a new model.