By Sam G. on Apr 15, 2025
Transformer models possess a high capacity for learning complex patterns due to their architecture, but this capacity also makes them susceptible to overfitting. This is frequently observed when adapting large pre-trained models to smaller, specific datasets. Fortunately, several straightforward and effective regularization methods have been developed to mitigate this.
Dropout is a frequently used regularization technique in neural networks, including Transformers. Its main function is to prevent neurons from becoming overly reliant on each other (co-adaptation) by randomly setting a fraction of neuron outputs to zero during each training update.
In the architecture described in the original Transformer paper (Vaswani et al., 2017), dropout was incorporated in multiple locations: it is applied to the output of each sub-layer before the residual addition and layer normalization, and to the sums of the token embeddings and positional encodings in both the encoder and decoder stacks.
Selecting the dropout rate often involves experimentation. Rates between 0.1 and 0.3 are common starting points, varying based on model size and the dataset characteristics. Lower rates like 0.1 or even 0.05 might be preferred for large models to prevent excessive information loss (underfitting).
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, hidden_size, num_heads=8, dropout_rate=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_size, num_heads)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.ReLU(),
            nn.Linear(hidden_size * 4, hidden_size)
        )
        # Dropout applied to each sub-layer output before the residual addition
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # Self-attention sub-layer: residual connection, dropout, then layer norm
        attn_output, _ = self.attention(x, x, x)
        x = self.norm1(x + self.dropout(attn_output))
        # Feed-forward sub-layer: residual connection, dropout, then layer norm
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_output))
        return x
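A quick usage sketch (the sizes here are arbitrary; note that nn.MultiheadAttention expects input shaped as (sequence length, batch size, hidden size) by default):

block = TransformerBlock(hidden_size=512, dropout_rate=0.1)
x = torch.randn(128, 32, 512)   # (sequence length, batch size, hidden size)
out = block(x)                  # dropout is active because modules default to train mode
print(out.shape)                # torch.Size([128, 32, 512])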
Weight decay shrinks the model's weights toward zero at each update, discouraging overly large weight values, which are often associated with overfitting. With plain SGD it is equivalent to L2 regularization, i.e. adding a penalty to the loss proportional to the squared magnitude of the weights; with adaptive optimizers such as Adam, the decoupled form implemented in AdamW is generally preferred. It is also standard practice to exclude biases and LayerNorm parameters from decay, as in the parameter grouping below.
import torch.optim as optim

# Separate parameters that receive weight decay from those that do not:
# biases and LayerNorm weights go into the no-decay group.
decay_params = []
no_decay_params = []
for name, param in model.named_parameters():
    if param.requires_grad:
        if 'weight' in name and 'norm' not in name:
            decay_params.append(param)
        else:
            no_decay_params.append(param)

optimizer_grouped_parameters = [
    {'params': decay_params, 'weight_decay': 0.01},
    {'params': no_decay_params, 'weight_decay': 0.0}
]

optimizer = optim.AdamW(optimizer_grouped_parameters, lr=5e-5)
Label smoothing adjusts the target labels used during training: instead of a one-hot target, the correct class receives probability 1 - smoothing and the remaining mass is spread over the other classes. This prevents the model from becoming overly confident in its predictions on the training data, which can improve generalization.
import torch

def smooth_labels(labels, smoothing=0.1, num_classes=10000):
    # The correct class gets probability 1 - smoothing; the remaining mass is
    # spread uniformly over the other classes so each row still sums to 1.
    confidence = 1.0 - smoothing
    smoothed_labels = torch.full(size=(labels.size(0), num_classes),
                                 fill_value=smoothing / (num_classes - 1),
                                 device=labels.device)
    smoothed_labels.scatter_(1, labels.unsqueeze(1), confidence)
    return smoothed_labels
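When the targets are plain class indices, PyTorch's built-in cross-entropy loss can apply label smoothing directly via its label_smoothing argument, avoiding the dense target tensor (it spreads the smoothing mass over all classes, including the correct one, so it differs slightly from the helper above):

import torch
import torch.nn as nn

# Smoothing handled inside the loss; no dense target distribution needed
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
logits = torch.randn(8, 10000)           # (batch size, num_classes)
targets = torch.randint(0, 10000, (8,))  # class indices
loss = criterion(logits, targets)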
Attention dropout targets the attention weights within the self-attention mechanism: after softmax converts the attention scores into weights, dropout is applied to those weights before they are used to combine the value vectors. This can improve robustness and encourage more distributed attention patterns.
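A minimal sketch of where this sits in a hand-written scaled dot-product attention is shown below (attention_with_dropout is an illustrative helper, not a library function); with nn.MultiheadAttention, the same effect is available through its dropout argument.

import math
import torch
import torch.nn.functional as F

def attention_with_dropout(q, k, v, dropout_p=0.1, training=True):
    # Scaled dot-product attention scores
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    weights = F.softmax(scores, dim=-1)
    # Attention dropout: randomly zero some attention weights during training
    weights = F.dropout(weights, p=dropout_p, training=training)
    return torch.matmul(weights, v)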
Stochastic depth randomly skips entire residual blocks during training, so a skipped block acts as an identity for those samples. This is especially useful in very deep Transformer architectures, where it can improve gradient flow and reduce overfitting.
import torch

def stochastic_depth_layer(x, drop_prob, is_training):
    # At evaluation time (or with drop_prob == 0) the branch is always kept.
    if drop_prob == 0.0 or not is_training:
        return x
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over the remaining dimensions
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    binary_tensor = torch.floor(random_tensor)
    # Scale by 1 / keep_prob so the expected output matches evaluation time
    output = x.div(keep_prob) * binary_tensor
    return output
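In a residual block, this helper is applied to the sub-layer output before the residual addition, so dropped samples pass through the block unchanged. A small usage sketch (the nn.Linear stands in for an arbitrary sub-layer):

import torch
import torch.nn as nn

sublayer = nn.Linear(512, 512)  # placeholder for an attention or feed-forward branch
x = torch.randn(32, 512)
x = x + stochastic_depth_layer(sublayer(x), drop_prob=0.1, is_training=True)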
Some research papers have experimented with applying dropout after layer normalization. This method is not widely adopted but can be explored in settings where traditional methods underperform.
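As a purely illustrative sketch of that variant (assuming a hidden size of 512), the dropout module simply follows the normalization instead of preceding it:

import torch.nn as nn

# Variant: dropout applied to the normalized output rather than the raw sub-layer output
norm_then_dropout = nn.Sequential(
    nn.LayerNorm(512),
    nn.Dropout(0.1),
)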
Applying regularization is a standard part of training or fine-tuning Transformer models effectively. The best combination and configuration of these techniques depend heavily on the specific application. In practice, careful use of dropout, label smoothing, and correctly configured weight decay consistently helps improve generalization, while stochastic depth and attention dropout offer further options, particularly for larger or deeper models.
Effective regularization often distinguishes between a model that merely memorizes the training set and one that performs well on new, unseen data.