In the preceding sections, we discussed the concepts of generalization, overfitting, and underfitting. You learned how learning curves can serve as a diagnostic tool, showing the gap between training performance and validation performance. Now, let's put this knowledge into practice by generating a scenario where overfitting is likely and visualizing it directly. This exercise will solidify your understanding of how to spot overfitting using both performance metrics and the model's behavior.
We'll use a simple regression problem: fitting a model to data generated from a known function with added noise. By comparing a simple model to a complex one, we can observe how the complex model might learn the noise in the training data, leading to poor performance on unseen data.
First, ensure you have your Python environment ready with PyTorch, NumPy, and Matplotlib (or another plotting library like Plotly or Seaborn) installed, as outlined in the "Setting up the Development Environment" section.
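If you want to confirm that the libraries are importable before running anything else, a quick check like the one below can help. This is a minimal sketch; the version numbers printed will depend on your installation.
import torch
import numpy as np
import matplotlib

# Print installed versions to confirm the environment is ready
print(f"PyTorch: {torch.__version__}")
print(f"NumPy: {np.__version__}")
print(f"Matplotlib: {matplotlib.__version__}")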
We'll create a small synthetic dataset based on a simple underlying function, like a sine wave or a low-degree polynomial, and add some random noise. This noise mimics the imperfections and random variations present in real-world data.
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
# Generate synthetic data
torch.manual_seed(42) # for reproducibility
n_samples = 30
X = torch.linspace(-np.pi, np.pi, n_samples).unsqueeze(1)
# True function: sin(x) + small linear trend
y_true = torch.sin(X) + 0.1 * X
# Add noise
y = y_true + torch.randn(X.size()) * 0.2
# Split into training and validation sets (a random split so both sets cover the full x-range)
n_train = 20
perm = torch.randperm(n_samples)
X_train, y_train = X[perm[:n_train]], y[perm[:n_train]]
X_val, y_val = X[perm[n_train:]], y[perm[n_train:]]
# Function to plot data
def plot_data(X_train, y_train, X_val, y_val, X_full=None, y_pred=None, model_name=None):
    plt.figure(figsize=(8, 5))
    plt.scatter(X_train.numpy(), y_train.numpy(), label='Training Data', c='#1f77b4', s=50, alpha=0.7)
    plt.scatter(X_val.numpy(), y_val.numpy(), label='Validation Data', c='#ff7f0e', s=50, alpha=0.7, marker='x')
    if X_full is not None and y_pred is not None:
        plt.plot(X_full.numpy(), y_pred.detach().numpy(), label=f'{model_name} Prediction', c='#2ca02c', linewidth=2)
    plt.xlabel("Feature (X)")
    plt.ylabel("Target (y)")
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.title("Synthetic Data and Model Fit")
    plt.show()
# Plot the generated data
# plot_data(X_train, y_train, X_val, y_val) # Uncomment to see the raw data
Now, let's define two neural network models using PyTorch. One will be relatively simple, and the other significantly more complex (with more layers or neurons), making it prone to overfitting on our small dataset.
# Simple Model (Potentially Underfitting or Good Fit)
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        # A single linear layer might underfit, so we use a small MLP
        self.layer1 = nn.Linear(1, 10)
        self.activation = nn.ReLU()
        self.output_layer = nn.Linear(10, 1)

    def forward(self, x):
        x = self.activation(self.layer1(x))
        x = self.output_layer(x)
        return x
# Complex Model (Prone to Overfitting)
class ComplexModel(nn.Module):
    def __init__(self):
        super().__init__()
        # More layers and neurons
        self.layer1 = nn.Linear(1, 128)
        self.activation1 = nn.ReLU()
        self.layer2 = nn.Linear(128, 128)
        self.activation2 = nn.ReLU()
        self.layer3 = nn.Linear(128, 64)
        self.activation3 = nn.ReLU()
        self.output_layer = nn.Linear(64, 1)

    def forward(self, x):
        x = self.activation1(self.layer1(x))
        x = self.activation2(self.layer2(x))
        x = self.activation3(self.layer3(x))
        x = self.output_layer(x)
        return x
simple_model = SimpleModel()
complex_model = ComplexModel()
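Before training, it can be helpful to quantify the capacity gap between the two networks. The snippet below is a small addition to the walkthrough that counts the learnable parameters in each model.
# Compare model capacity by counting learnable parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"SimpleModel parameters:  {count_parameters(simple_model)}")
print(f"ComplexModel parameters: {count_parameters(complex_model)}")
The complex model has tens of thousands of parameters while the training set contains only a couple dozen points, which is exactly the regime where memorizing noise becomes easy.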
We'll train both models using a standard training loop. We'll use Mean Squared Error (MSE) as the loss function and the Adam optimizer (which we'll cover in detail later, but it's a common default). We will track the training loss and validation loss for each epoch.
def train_model(model, X_train, y_train, X_val, y_val, epochs=2000, lr=0.005):
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    train_losses = []
    val_losses = []
    for epoch in range(epochs):
        model.train()  # Set model to training mode
        # Forward pass
        outputs = model(X_train)
        loss = criterion(outputs, y_train)
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Evaluate on validation set
        model.eval()  # Set model to evaluation mode
        with torch.no_grad():
            val_outputs = model(X_val)
            val_loss = criterion(val_outputs, y_val)
        train_losses.append(loss.item())
        val_losses.append(val_loss.item())
        if (epoch + 1) % 200 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {loss.item():.4f}, Val Loss: {val_loss.item():.4f}')
    return train_losses, val_losses
print("Training Simple Model...")
simple_train_losses, simple_val_losses = train_model(simple_model, X_train, y_train, X_val, y_val)
print("\nTraining Complex Model...")
complex_train_losses, complex_val_losses = train_model(complex_model, X_train, y_train, X_val, y_val)
Learning curves plot the training and validation loss over epochs. They are a primary tool for diagnosing fitting issues.
Let's plot the learning curves for both models.
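A plotting routine along these lines produces the comparison described below; this is a sketch using the loss lists returned by train_model, so adjust the styling to taste.
# Plot learning curves for both models on log-scaled axes
fig, axes = plt.subplots(1, 2, figsize=(12, 4), sharey=True)
for ax, name, tr, va in [
    (axes[0], "Simple Model", simple_train_losses, simple_val_losses),
    (axes[1], "Complex Model", complex_train_losses, complex_val_losses),
]:
    ax.plot(tr, label="Training Loss", c='#1f77b4')
    ax.plot(va, label="Validation Loss", c='#ff7f0e')
    ax.set_yscale('log')  # log scale makes small loss differences visible
    ax.set_xlabel("Epoch")
    ax.set_ylabel("MSE Loss")
    ax.set_title(name)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.5)
plt.tight_layout()
plt.show()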
Comparison of learning curves for the simple and complex models. Note the logarithmic scale on the y-axis. The complex model shows a clear divergence between training loss (decreasing) and validation loss (increasing), indicating overfitting. The simple model's losses converge more closely. (Note: Actual loss values depend on the specific run; these are illustrative.)
Observe the learning curves. The complex model's training loss likely drops very low, indicating it fits the training data well. However, its validation loss, after decreasing initially, probably starts to increase or levels off at a much higher value than the training loss. This widening gap is the classic sign of overfitting. The simple model might show less of a gap, or both losses might plateau, suggesting it's not complex enough to overfit (or perhaps even slightly underfits).
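As a quick numeric complement to the curves, you can compare the final losses directly. This small sketch reuses the lists returned by train_model:
# Compare the final train/validation gap for both models
print(f"Simple  - final train: {simple_train_losses[-1]:.4f}, final val: {simple_val_losses[-1]:.4f}")
print(f"Complex - final train: {complex_train_losses[-1]:.4f}, final val: {complex_val_losses[-1]:.4f}")
A much lower training loss paired with a noticeably higher validation loss for the complex model is the numeric counterpart of the widening gap in the curves.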
Another way to visualize overfitting, especially in regression, is to plot the model's predictions against the actual data points. An overfit model will often show excessive fluctuations, trying to pass through every training point, including the noise.
# Plot predictions
simple_model.eval()
complex_model.eval()
with torch.no_grad():
    # Use the full range X for plotting the curve
    y_pred_simple = simple_model(X)
    y_pred_complex = complex_model(X)
# Plot Simple Model Fit
plot_data(X_train, y_train, X_val, y_val, X, y_pred_simple, "Simple Model")
# Plot Complex Model Fit
plot_data(X_train, y_train, X_val, y_val, X, y_pred_complex, "Complex Model")
# Optional: Plot true underlying function
plt.figure(figsize=(8, 5))
plt.scatter(X_train.numpy(), y_train.numpy(), label='Training Data', c='#1f77b4', s=50, alpha=0.7)
plt.scatter(X_val.numpy(), y_val.numpy(), label='Validation Data', c='#ff7f0e', s=50, alpha=0.7, marker='x')
plt.plot(X.numpy(), y_true.numpy(), label='True Function', c='black', linestyle=':', linewidth=2)
plt.plot(X.numpy(), y_pred_complex.detach().numpy(), label='Complex Model Prediction', c='#d62728', linewidth=2) # Use red for overfitting
plt.xlabel("Feature (X)")
plt.ylabel("Target (y)")
plt.legend()
plt.ylim(y.min()-0.5, y.max()+0.5) # Adjust y-limits for better view
plt.grid(True, linestyle='--', alpha=0.5)
plt.title("Complex Model vs. True Function")
plt.show()
When you examine the plots, look for these patterns: the complex model's prediction curve bends sharply to pass near individual training points, tracing the noise rather than the underlying sine shape, while the simple model produces a smoother curve that stays closer to the true function. The complex model's fit also tends to miss the validation points, which it never saw during training, by a wider margin.
This practice exercise demonstrated two visual methods for identifying overfitting: learning curves, where a widening gap between training and validation loss signals that the model is memorizing the training data, and prediction plots, where an overly wiggly fit that chases individual training points reveals that the model has learned the noise.
Recognizing overfitting is the first step. The following chapters will introduce techniques like regularization (L1/L2, Dropout, Batch Normalization) and sophisticated optimization algorithms designed to combat this problem and help our models generalize better to unseen data.
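As a small preview of what is ahead, PyTorch's Adam optimizer already exposes L2-style regularization through its weight_decay argument; the value below is only an illustrative choice, not a tuned recommendation.
# Preview: L2-style regularization via weight decay (covered in a later chapter)
optimizer = optim.Adam(complex_model.parameters(), lr=0.005, weight_decay=1e-4)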