Training deep generative models like diffusion models often involves navigating complex optimization landscapes over extended periods. While sophisticated architectures and loss functions lay the groundwork, achieving stable convergence and optimal performance frequently requires additional techniques to manage the training dynamics. Two widely adopted methods for enhancing training stability are gradient clipping and Exponential Moving Average (EMA) of model weights.
Deep networks, including the U-Nets or Transformers used in diffusion models, can sometimes suffer from the exploding gradient problem. During backpropagation, gradients can accumulate and become excessively large, leading to drastic updates in model weights. These large updates can destabilize training, causing the loss to diverge or oscillate wildly, preventing the model from converging to a good solution.
Gradient clipping directly addresses this by limiting the magnitude of the gradients before they are used to update the model parameters. The most common method is clipping by norm, typically the L2 norm (Euclidean norm).
Let $g = \nabla_\theta L$ be the gradient of the loss $L$ with respect to the model parameters $\theta$. We calculate the L2 norm of the entire gradient vector across all parameters: $\|g\| = \sqrt{\sum_i g_i^2}$. If this norm exceeds a predefined threshold $c$, the gradient vector is rescaled to have a norm equal to $c$:
$$
g \leftarrow
\begin{cases}
g & \text{if } \|g\| \le c \\
\dfrac{c}{\|g\|}\, g & \text{if } \|g\| > c
\end{cases}
$$

This ensures that the magnitude of the weight update step is bounded, preventing extreme changes caused by outlier gradients.
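To make the rescaling concrete, here is a minimal sketch that applies the rule above by hand, assuming the gradients have already been populated by a backward pass. It is only an illustration; library utilities such as torch.nn.utils.clip_grad_norm_ handle edge cases (for example, non-finite norms) more carefully.

# Sketch: manual clipping by global L2 norm (illustrative only)
import torch

def clip_gradients_by_norm(parameters, c=1.0):
    grads = [p.grad for p in parameters if p.grad is not None]
    # Global L2 norm over all parameters, treated as one long vector
    total_norm = torch.sqrt(sum((g ** 2).sum() for g in grads))
    if total_norm > c:
        # Rescale every gradient by c / ||g|| so the global norm equals c
        for g in grads:
            g.mul_(c / total_norm)
    return total_norm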
In PyTorch, you typically apply this with torch.nn.utils.clip_grad_norm_, called after loss.backward() and before optimizer.step().

# Example snippet (PyTorch)
optimizer.zero_grad()
loss = compute_loss(model, batch)
loss.backward()
# Clip gradients before the optimizer step
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # max_norm is 'c'
optimizer.step()
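Note that clip_grad_norm_ returns the total gradient norm computed before any rescaling, so logging this value is a cheap way to monitor gradient magnitudes and see how often clipping actually kicks in.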
Gradient clipping acts as a safety mechanism, particularly useful during the initial phases of training or when using high learning rates.
Stochastic Gradient Descent (SGD) and its variants introduce noise into the training process due to the mini-batch nature of updates. While this stochasticity helps escape poor local minima, it also means the model parameters can fluctuate significantly around an optimal point, even late in training. Using the parameters from the very last training step might not yield the best possible model performance, as it could represent a transient, noisy state.
Exponential Moving Average (EMA) provides a way to obtain a more stable and often better-performing set of model weights by averaging the parameters over recent training history. It maintains a "shadow" copy of the model parameters, which are updated slowly based on the current training parameters.
Let $\theta_t$ be the model parameters after the update at step $t$, and let $\theta'_t$ be the corresponding EMA parameters. The EMA parameters are updated using a decay factor $\beta$:

$$
\theta'_t = \beta\, \theta'_{t-1} + (1 - \beta)\, \theta_t
$$

Here, $\beta$ is a hyperparameter typically set close to 1 (e.g., 0.99, 0.999, or even 0.9999). A higher $\beta$ means the EMA weights change more slowly and incorporate a longer history of parameters (more smoothing). A lower $\beta$ makes the EMA weights track the current training weights more closely.
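As a rough rule of thumb, an EMA with decay $\beta$ averages the parameters over approximately the last $1/(1-\beta)$ steps, so $\beta = 0.999$ corresponds to an effective window of about 1,000 updates and $\beta = 0.9999$ to about 10,000. This gives a practical way to match the decay to the length of your training run.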
Diagram illustrating the relationship between regular training weights ($\theta_t$, blue nodes) updated via optimization steps, and the EMA weights ($\theta'_t$, green nodes) updated as a weighted average of the previous EMA weight and the current training weight. Dashed lines indicate the influence of the current training weight on the next EMA weight.
# Snippet (often part of a callback or utility class)
ema_decay = 0.999

# Initialize the EMA ("shadow") copy once, before training starts
ema_model_params = [p.clone().detach() for p in model.parameters()]

# Inside the training loop, after optimizer.step()
with torch.no_grad():
    for ema_p, current_p in zip(ema_model_params, model.parameters()):
        # ema_p = ema_decay * ema_p + (1 - ema_decay) * current_p
        ema_p.mul_(ema_decay).add_(current_p, alpha=1 - ema_decay)

# For inference/sampling, load the ema_model_params into the model
# ... load ema params ...
# model.eval()
# generate_samples(model, ...)
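One simple way to use the EMA weights for sampling is to copy them into the model, as sketched below. This assumes ema_model_params was created in the same order as model.parameters() (as in the snippet above); if training will continue afterwards, back up the training weights first and restore them when done.

# Sketch: temporarily swap the EMA weights into the model for sampling
import copy

backup_state = copy.deepcopy(model.state_dict())  # keep the training weights
with torch.no_grad():
    for p, ema_p in zip(model.parameters(), ema_model_params):
        p.copy_(ema_p)

model.eval()
# samples = generate_samples(model, ...)  # placeholder sampling call from the snippet above

# Restore the training weights before resuming optimization
model.load_state_dict(backup_state)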
Gradient clipping and EMA are often used together. Clipping prevents catastrophic divergence due to isolated large gradients, while EMA smooths out the general stochasticity of the optimization process, leading to a more robust final model state. By employing these techniques, you can significantly improve the reliability of training advanced diffusion models and often achieve superior generative performance.
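Putting both techniques together, a single training step might look like the sketch below. It simply combines the earlier snippets and assumes model, optimizer, compute_loss, batch, ema_model_params, and ema_decay are defined as above; the ordering is what matters: clip after loss.backward(), call optimizer.step(), then update the EMA copy.

# Sketch: one training step with gradient clipping and an EMA update
optimizer.zero_grad()
loss = compute_loss(model, batch)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # bound the update size
optimizer.step()
with torch.no_grad():
    for ema_p, current_p in zip(ema_model_params, model.parameters()):
        ema_p.mul_(ema_decay).add_(current_p, alpha=1 - ema_decay)  # smooth the weights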