While optimizers and learning rate schedules steer the training process towards a minimum, regularization techniques are essential for ensuring that the learned model generalizes well to unseen data, preventing overfitting. Basic L1 or L2 weight decay, while useful, sometimes falls short, especially with complex architectures and large datasets. This section covers more advanced regularization strategies available within the PyTorch ecosystem.
Standard implementations of L2 regularization in adaptive optimizers like Adam often couple the weight decay term with the gradient calculation. This means the decay effect is influenced by the adaptive learning rates calculated by the optimizer (specifically, the squared gradient history, $v_t$). This coupling can lead to suboptimal performance: for weights with a large gradient history, the decay term is divided by a large $\sqrt{v_t}$, so the effective weight decay becomes smaller than intended.
The AdamW optimizer, proposed by Loshchilov & Hutter (2019), addresses this by decoupling the weight decay from the gradient update. Instead of adding the decay term to the gradient, AdamW applies the decay directly to the weights after the main optimization step.
Conceptually, a standard Adam update with L2 regularization looks like:
$$g_t' = \nabla f(w_t) + \lambda w_t$$
$$w_{t+1} = w_t - \eta \cdot \text{AdamUpdate}(g_t')$$
Whereas AdamW performs:
$$g_t = \nabla f(w_t)$$
$$w_{t+1}' = w_t - \eta \cdot \text{AdamUpdate}(g_t)$$
$$w_{t+1} = w_{t+1}' - \eta \lambda w_t$$
(Note: $\eta$ here is the base learning rate; Adam's per-parameter adaptive scaling is contained in $\text{AdamUpdate}$.) The important difference is that the decay term $\eta \lambda w_t$ is applied separately and is not scaled by the adaptive learning rate terms within $\text{AdamUpdate}$.
In PyTorch, using AdamW is straightforward:
```python
import torch
import torch.nn as nn
import torch.optim as optim

# Assume model is defined; a small placeholder is used here so the snippet runs
model = nn.Linear(10, 2)

# Use AdamW instead of Adam
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

# Training loop remains the same
# ...
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
# ...
```
Choosing an appropriate `weight_decay` value for AdamW often requires experimentation, but values between `1e-2` and `1e-1` are common starting points, typically higher than those used with standard Adam's coupled L2 regularization. AdamW has become a standard choice for training many modern architectures, particularly Transformers.
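The difference between coupled and decoupled decay can be observed in a single optimizer step. The following is a minimal sketch, assuming a hypothetical two-element parameter whose loss gradient is set to zero so that only the decay term acts; the helper name `one_step` and the values chosen are illustrative, not part of any library.

```python
import torch

def one_step(optimizer_cls, weight_decay):
    # Two weights of different magnitude (hypothetical values).
    w = torch.nn.Parameter(torch.tensor([1.0, 5.0]))
    opt = optimizer_cls([w], lr=0.1, weight_decay=weight_decay)
    # Zero loss gradient: any change to w comes from weight decay alone.
    w.grad = torch.zeros_like(w)
    opt.step()
    return w.detach().clone()

# Adam folds the decay term into the gradient (coupled L2).
coupled = one_step(torch.optim.Adam, weight_decay=0.1)
# AdamW shrinks the weights directly (decoupled decay).
decoupled = one_step(torch.optim.AdamW, weight_decay=0.1)

print("Adam  + weight_decay:", coupled)
print("AdamW + weight_decay:", decoupled)
```

Under these assumptions, the coupled update moves both weights by roughly the learning rate (to about 0.9 and 4.9), because Adam's normalization cancels the magnitude of the decay gradient. The decoupled update instead shrinks each weight by the factor $1 - \eta\lambda = 0.99$, which is proportional to the weight as intended.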
Classification models are typically trained using cross-entropy loss with one-hot encoded target labels (e.g., `[0, 0, 1, 0]`). This encourages the model to produce extremely high confidence for the correct class and very low confidence for incorrect classes. While this drives the model towards fitting the training data perfectly, it can lead to overconfidence and poor generalization. The model might become overly sensitive to specific features in the training data, failing to capture broader patterns.
Label Smoothing Regularization (LSR) addresses this by slightly softening the target labels. Instead of requiring the model to predict exactly 1 for the correct class and 0 for others, it encourages the model to assign a small probability mass ϵ (epsilon) to the incorrect classes.
For a classification problem with $K$ classes, if the original one-hot label for a sample is $y_k$ (1 for the true class, 0 otherwise), the smoothed label $y_k'$ becomes:
$$y_k' = y_k(1 - \epsilon) + \frac{\epsilon}{K}$$
Here, $y_k$ is 1 if $k$ is the true class index and 0 otherwise. The term $y_k(1 - \epsilon)$ reduces the target probability for the correct class from 1.0 to $1 - \epsilon$. The term $\epsilon / K$ distributes the remaining probability mass $\epsilon$ uniformly across all $K$ classes (including the correct one, although its main effect is adding probability to the incorrect ones).
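For example, with $K = 5$ classes and $\epsilon = 0.1$, a one-hot label `[0, 0, 1, 0, 0]` becomes `[0.02, 0.02, 0.92, 0.02, 0.02]`: every class receives $\epsilon / K = 0.02$, and the true class additionally keeps $1 - \epsilon = 0.9$, so the smoothed label still sums to 1.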
This encourages the model's output logits (before the final softmax) for the correct class to be less extreme relative to the logits for incorrect classes. It acts as a regularizer by preventing the model from becoming excessively confident.
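The smoothing formula can be written out in a few lines. This is a minimal sketch, assuming integer class indices as input; the helper name `smooth_one_hot` is hypothetical and introduced only for illustration.

```python
import torch
import torch.nn.functional as F

def smooth_one_hot(targets, num_classes, epsilon):
    """Return smoothed labels y' = y * (1 - epsilon) + epsilon / K."""
    one_hot = F.one_hot(targets, num_classes).float()
    return one_hot * (1.0 - epsilon) + epsilon / num_classes

# Hypothetical example: one sample whose true class is index 2, K = 5, epsilon = 0.1
smoothed = smooth_one_hot(torch.tensor([2]), num_classes=5, epsilon=0.1)
print(smoothed)
```

With these values the true class receives $1 - \epsilon + \epsilon / K = 0.92$ and each other class receives $0.02$, so every row remains a valid probability distribution.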
PyTorch's `torch.nn.CrossEntropyLoss` directly supports label smoothing via the `label_smoothing` argument:
```python
import torch
import torch.nn as nn

# Example usage
num_classes = 10
smoothing_factor = 0.1
criterion = nn.CrossEntropyLoss(label_smoothing=smoothing_factor)

# Inside training loop
# outputs = model(inputs)  # Shape: [batch_size, num_classes]
# targets = ...            # Shape: [batch_size], contains class indices
# loss = criterion(outputs, targets)
# loss.backward()
```
Typical values for ϵ are small, often in the range of 0.05 to 0.1.
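As a sanity check, the built-in `label_smoothing` argument can be compared against a manual computation that smooths the one-hot targets and then takes the cross-entropy. The sketch below assumes random logits and arbitrary class indices as stand-ins for real model outputs and labels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, epsilon = 10, 0.1
logits = torch.randn(4, num_classes)   # hypothetical model outputs
targets = torch.tensor([1, 0, 7, 3])   # hypothetical class indices

# Built-in label smoothing applied to integer class targets
builtin = nn.CrossEntropyLoss(label_smoothing=epsilon)(logits, targets)

# Manual version: smooth the one-hot targets, then average the cross-entropy
soft = F.one_hot(targets, num_classes).float() * (1 - epsilon) + epsilon / num_classes
manual = -(soft * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

print(builtin.item(), manual.item())
```

The two losses should agree up to floating-point error, which suggests that the built-in argument implements the same uniform smoothing described above.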
Dropout is a widely used regularization technique that randomly sets individual neuron activations to zero during training. Stochastic Depth, also known as DropPath, offers a different approach, particularly effective in networks with residual connections (like ResNets or Transformers). Instead of dropping individual neurons, Stochastic Depth randomly drops entire residual blocks or layers during training.
Consider a residual block where the output is $x_{l+1} = x_l + f_l(x_l)$. With Stochastic Depth, this transformation is modified during training:
$$x_{l+1} = x_l + b_l \cdot f_l(x_l)$$
Here, $b_l$ is a Bernoulli random variable that is either 0 or 1. It takes the value 0 (dropping the block) with probability $1 - p_l$ and 1 (keeping the block) with probability $p_l$. The probability $p_l$ is the survival probability of block $l$.
Often, the survival probability is decreased linearly for deeper layers in the network. For a network with $L$ blocks, the survival probability for block $l$ (where $l$ ranges from 1 to $L$) might be set as:
$$p_l = 1 - \frac{l}{L}(1 - p_L)$$
Here, $p_L$ is the target survival probability for the final block. This scheme means that earlier layers (closer to the input) are more likely to be kept, while deeper layers are more likely to be dropped.
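The linear schedule is easy to tabulate. The following sketch assumes a hypothetical network with 4 blocks and a final survival probability of 0.8; the function name `linear_survival_probs` is illustrative only.

```python
import torch

def linear_survival_probs(num_blocks, final_survival_prob):
    """Survival probability p_l = 1 - (l / L) * (1 - p_L) for l = 1..L."""
    l = torch.arange(1, num_blocks + 1, dtype=torch.float32)
    return 1.0 - (l / num_blocks) * (1.0 - final_survival_prob)

# Hypothetical configuration: L = 4 blocks, p_L = 0.8
survival = linear_survival_probs(num_blocks=4, final_survival_prob=0.8)
drop = 1.0 - survival  # corresponding drop probabilities

print(survival)
print(drop)
```

For this configuration the survival probabilities fall from 0.95 for the first block to 0.8 for the last, so the drop probabilities rise from 0.05 to 0.2.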
During inference, all blocks are kept. In the original formulation, each block's residual output is scaled by its survival probability $p_l$ at inference time to compensate for the fact that the block was present less often during training. Many implementations, including timm's DropPath with its default settings, instead divide the surviving outputs by $p_l$ during training, so no adjustment is needed at inference.
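The Bernoulli gate and the rescaling can be sketched in plain PyTorch. This is a minimal illustration under two assumptions that are choices made here rather than requirements of the method: the gate is drawn independently per sample in the batch, and surviving outputs are divided by the survival probability during training so that inference needs no scaling. The function name `stochastic_depth` is hypothetical.

```python
import torch

def stochastic_depth(residual, survival_prob, training):
    """Apply a per-sample Bernoulli gate b_l to the residual branch output f_l(x_l)."""
    assert 0.0 < survival_prob <= 1.0
    if not training or survival_prob == 1.0:
        # Inference (or no dropping): keep the branch unchanged.
        return residual
    # One Bernoulli draw per sample, broadcast over all remaining dimensions.
    mask_shape = (residual.shape[0],) + (1,) * (residual.dim() - 1)
    mask = torch.bernoulli(residual.new_full(mask_shape, survival_prob))
    # Rescale survivors so the expected output matches the inference output.
    return residual * mask / survival_prob

torch.manual_seed(0)
x = torch.ones(1000, 4)  # stand-in for the residual branch output
train_out = stochastic_depth(x, survival_prob=0.8, training=True)
eval_out = stochastic_depth(x, survival_prob=0.8, training=False)

print(train_out.mean().item())  # close to 1.0 on average
print(torch.equal(eval_out, x))
```

Each row of `train_out` is either all zeros (block dropped for that sample) or all 1.25 (block kept and rescaled by $1 / 0.8$), while the inference output is returned untouched. The caller would then add the result to the shortcut, as in $x_{l+1} = x_l + b_l \cdot f_l(x_l)$.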
Benefits of Stochastic Depth include:
Implementing Stochastic Depth typically involves using a specialized layer. Libraries like `timm` (PyTorch Image Models) provide a convenient `DropPath` module:
```python
# Example using timm's DropPath
# Ensure you have timm installed: pip install timm
import torch
import torch.nn as nn
# Older timm releases expose this as: from timm.models.layers import DropPath
from timm.layers import DropPath


class MyResidualBlock(nn.Module):
    def __init__(self, dim, drop_prob=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.linear1 = nn.Linear(dim, dim * 4)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(dim * 4, dim)
        # Stochastic Depth layer
        self.drop_path = DropPath(drop_prob) if drop_prob > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.norm1(x)
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        # Apply DropPath to the output of the residual function
        x = shortcut + self.drop_path(x)
        return x


# Example usage within a larger model (num_layers, embed_dim, and i are defined by that model)
# drop_probabilities = torch.linspace(0, 0.1, num_layers)  # Drop probability increases linearly with depth
# block = MyResidualBlock(dim=embed_dim, drop_prob=drop_probabilities[i].item())
```
The `drop_prob` passed to `DropPath` corresponds to $1 - p_l$. Choosing the range for the drop probability (e.g., linearly increasing from 0 to 0.1 or 0.2) is another hyperparameter to tune.
These advanced regularization techniques, often used in combination, provide powerful tools for improving the robustness and generalization capabilities of deep learning models trained with PyTorch. Experimentation is usually needed to find the optimal combination and hyperparameters for a specific task and architecture.
© 2025 ApX Machine Learning