Training modern deep learning models often pushes the limits of your hardware, both in terms of computation time and memory. One effective technique to alleviate these pressures is mixed precision training. This approach combines the use of lower-precision floating-point numbers (like 16-bit floats) for certain parts of your model with higher-precision numbers (like 32-bit floats) for others, aiming to speed up training and reduce memory usage without a significant loss in model accuracy.
Traditionally, most neural network training has been performed using 32-bit floating-point numbers (FP32 or single-precision). While FP32 offers a wide dynamic range and good precision, computations with it can be slower and require more memory compared to lower-precision formats.
Mixed precision training intelligently uses a combination of:

- FP16 (half precision) for operations where it is fast and memory-efficient, typically most of the forward and backward pass.
- FP32 (single precision) for values and operations that need the extra range and precision, such as the weights themselves and numerically sensitive layers.
The goal is to reap the benefits of FP16 (speed, memory) while mitigating its potential downsides, such as a smaller representable range which can lead to overflow (numbers becoming too large) or underflow (gradients becoming zero).
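To see why overflow and underflow are a concern, you can compare the representable ranges of the two formats directly with torch.finfo. This is a small standalone check (no AMP involved); the comments show the approximate output.

import torch

# Compare the numeric limits of half (FP16) and single (FP32) precision
for dtype in (torch.float16, torch.float32):
    info = torch.finfo(dtype)
    print(f"{dtype}: max={info.max:.3e}, smallest normal={info.tiny:.3e}")
# torch.float16: max=6.550e+04, smallest normal=6.104e-05
# torch.float32: max=3.403e+38, smallest normal=1.175e-38

Anything larger than roughly 65,504 overflows in FP16, and values much below 6e-5 quickly lose precision or vanish entirely, which is exactly the failure mode mixed precision training has to guard against.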
torch.cuda.amp
PyTorch provides convenient tools for automatic mixed precision (AMP) training, primarily through the torch.cuda.amp module. This module automates most of the process, making it relatively straightforward to integrate into your existing training scripts. The two main components you'll interact with are autocast and GradScaler.
torch.cuda.amp.autocast
The autocast context manager is the workhorse for selecting which operations run in FP16 and which remain in FP32. When you enable autocast for a region of your code (typically the forward pass), it automatically casts inputs of eligible PyTorch operations to FP16.

Which operations are eligible? PyTorch maintains internal lists: compute-bound operations that are fast and numerically safe in FP16, such as matrix multiplications and convolutions, are cast down, while operations that need FP32's range or precision, such as reductions and normalization layers (e.g., BatchNorm), are often kept in FP32 to maintain precision. autocast handles these conversions on the fly. For example:
import torch

# Assume MyModel, loss_fn, target, and the batch shape (N, C, H, W) are defined;
# model and data are on CUDA
model = MyModel().cuda()
input_data = torch.randn(N, C, H, W, device="cuda")

# Enable autocast for the forward pass
with torch.cuda.amp.autocast():
    output = model(input_data)
    loss = loss_fn(output, target)  # loss computation also under autocast

# During backward, many intermediate gradients will also be computed in FP16
# loss.backward()  # (we'll see how GradScaler modifies this)
Inside the autocast block, operations like model(input_data) will internally use FP16 for many computations if input_data is a CUDA tensor and the operations are deemed safe and efficient in FP16 by PyTorch. The model's output will typically be FP16, while the loss may be FP16 or FP32 depending on the loss function (many loss functions are kept in FP32 by autocast).
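You can verify this by checking tensor dtypes directly. The following is a minimal, self-contained sketch (it assumes a CUDA device is available; the nn.Linear model is just a stand-in for your own network):

import torch
import torch.nn as nn

model = nn.Linear(128, 10).cuda()   # parameters are created in FP32
x = torch.randn(32, 128, device="cuda")
labels = torch.randint(0, 10, (32,), device="cuda")

with torch.cuda.amp.autocast():
    out = model(x)
    loss = nn.functional.cross_entropy(out, labels)

print(model.weight.dtype)  # torch.float32 -- parameters are not converted
print(out.dtype)           # torch.float16 -- the linear layer (a matmul) ran in FP16
print(loss.dtype)          # torch.float32 -- cross_entropy is kept in FP32 under autocast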
torch.cuda.amp.GradScaler
While autocast handles the forward pass, using FP16 for gradients during the backward pass can lead to underflow. Gradients, especially for deep networks or small parameter updates, can become very small. If these small values fall below the minimum representable positive number in FP16, they become zero, effectively halting learning for those parameters.
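A quick numeric illustration of the problem and the remedy (the scale factor of 2**16 here is just an example; GradScaler chooses and adapts its own factor):

import torch

grad = torch.tensor(1e-8)                   # a typical "tiny" gradient value
print(grad.to(torch.float16))               # tensor(0., dtype=torch.float16) -- underflow

scale = 2.0 ** 16
scaled = (grad * scale).to(torch.float16)   # scale first, then cast
print(scaled)                               # representable in FP16 (about 6.55e-4)
print(scaled.float() / scale)               # ~1e-8 again after unscaling in FP32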
GradScaler helps prevent this by scaling the loss before the backward pass. Here's the process:

1. The loss produced in the forward pass (under autocast) is multiplied by a large scaling factor. This inflates the magnitude of the subsequent gradients.
2. The backward() call is performed on this scaled loss. The resulting gradients are also scaled. Because they are larger, they are less likely to underflow in FP16.
3. Before the optimizer updates the parameters, GradScaler un-scales the gradients by dividing them by the same scaling factor. This restores them to their correct magnitude.
4. GradScaler dynamically adjusts the scaling factor. If gradients overflow (become inf or NaN) during a step, it means the scaling factor was too high, so it's reduced for the next iteration, and the optimizer step for the current iteration is skipped. If no overflows occur for a certain number of steps, the scaling factor may be increased to further reduce the risk of underflow.

Let's see how to modify a standard PyTorch training loop to use AMP.
Typical Training Loop (FP32):
import torch
import torch.nn as nn
import torch.optim as optim

# Assume model, data_loader, loss_fn are defined, for example:
# model = MyModel().cuda()
# optimizer = optim.Adam(model.parameters(), lr=1e-3)
# loss_fn = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    for input_batch, target_batch in data_loader:
        input_batch, target_batch = input_batch.cuda(), target_batch.cuda()

        optimizer.zero_grad()
        outputs = model(input_batch)
        loss = loss_fn(outputs, target_batch)
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1} completed.")
Training Loop with PyTorch AMP:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler  # Import AMP components

# Assume model, data_loader, loss_fn are defined, for example:
# model = MyModel().cuda()
# optimizer = optim.Adam(model.parameters(), lr=1e-3)
# loss_fn = nn.CrossEntropyLoss()

scaler = GradScaler()  # Initialize GradScaler

for epoch in range(num_epochs):
    for input_batch, target_batch in data_loader:
        input_batch, target_batch = input_batch.cuda(), target_batch.cuda()

        optimizer.zero_grad()

        # Forward pass with autocasting
        with autocast():
            outputs = model(input_batch)
            loss = loss_fn(outputs, target_batch)

        # Scale loss and call backward() on scaled loss
        scaler.scale(loss).backward()

        # Unscale gradients and call optimizer.step()
        scaler.step(optimizer)

        # Update the scale for next iteration
        scaler.update()

    print(f"Epoch {epoch+1} completed with AMP.")
The main changes are:

1. Initialize a GradScaler().
2. Wrap the forward pass and loss computation in a with autocast(): block.
3. Call scaler.scale(loss).backward() instead of just loss.backward().
4. Call scaler.step(optimizer) instead of optimizer.step().
5. Call scaler.update() after scaler.step(optimizer).

The optimizer.zero_grad() call can be placed before the autocast block or at the beginning of the loop as usual. PyTorch recommends setting gradients to None instead of zeroing them for a small performance benefit, which can be done via optimizer.zero_grad(set_to_none=True).
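For example, the inner loop body with the recommended zero_grad(set_to_none=True) call might look like the sketch below. It also shows one common (unofficial) way to notice skipped steps by watching the scale factor via scaler.get_scale(); the names are assumed to match the loop above.

optimizer.zero_grad(set_to_none=True)   # set gradients to None instead of zeroing

with autocast():
    outputs = model(input_batch)
    loss = loss_fn(outputs, target_batch)

scaler.scale(loss).backward()
scale_before = scaler.get_scale()
scaler.step(optimizer)
scaler.update()

# If the scale decreased, inf/NaN gradients were found and the optimizer step was skipped
if scaler.get_scale() < scale_before:
    print("Gradient overflow detected; optimizer step skipped this iteration.")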
The primary benefits of AMP are reduced training time and memory usage. While exact numbers vary by model, GPU, and batch size, the improvements can be substantial.
Chart: potential reductions in training time and peak memory usage when switching from standard FP32 training to automatic mixed precision (AMP).
If you've used TensorFlow, you might be familiar with tf.keras.mixed_precision. The concept is very similar:

- TensorFlow uses a Policy (e.g., mixed_precision.set_global_policy('mixed_float16')) to define how layers should handle mixed precision, somewhat analogous to how autocast implicitly determines types.
- A LossScaleOptimizer wraps an existing optimizer to perform loss scaling, akin to PyTorch's GradScaler.

Both frameworks aim to simplify the adoption of mixed precision by automating type casting and loss scaling. The underlying principles are the same, though the specific API calls and implementation details differ. PyTorch's autocast and GradScaler offer a flexible way to apply mixed precision, often requiring only a few lines of code modification.
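For reference, a minimal sketch of the Keras-side setup described above (assuming TensorFlow 2.x; the optimizer choice is arbitrary):

import tensorflow as tf
from tensorflow.keras import mixed_precision

# Compute in float16 where safe, while variables stay in float32
mixed_precision.set_global_policy('mixed_float16')

# Wrap the optimizer so it applies dynamic loss scaling, analogous to GradScaler
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
optimizer = mixed_precision.LossScaleOptimizer(optimizer)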
A few additional points are worth keeping in mind when using AMP:

- BFloat16: autocast can also target BF16 if dtype=torch.bfloat16 is specified. BF16 has a similar range to FP32 but less precision, often offering a good balance without requiring loss scaling as frequently as FP16.
- Numerically sensitive layers: layers such as BatchNorm typically maintain their weights and perform computations in FP32 even within an autocast block to ensure stability. PyTorch handles this automatically.
- Gradient clipping: if you clip gradients, unscale them first with scaler.unscale_(optimizer) after the backward pass through GradScaler but before optimizer.step(), as shown below:
scaler.scale(loss).backward()

# Unscale the gradients of the optimizer's assigned params in-place
scaler.unscale_(optimizer)

# Clip gradients after unscaling (max_norm is whatever threshold you choose, e.g. 1.0)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

scaler.step(optimizer)
scaler.update()
- Checkpointing: when saving and loading checkpoints, also save and restore the state of the GradScaler:
# Saving
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scaler_state_dict': scaler.state_dict(),
    # ... other things
}
torch.save(checkpoint, 'my_checkpoint.pth')

# Loading
checkpoint = torch.load('my_checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scaler.load_state_dict(checkpoint['scaler_state_dict'])
- Debugging numerical issues: if you encounter instability with AMP, consider disabling autocast for certain parts of your model or explicitly casting specific operations to FP32; a minimal sketch of this pattern follows.
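The nested-region pattern below illustrates the idea; the names model.encoder, sensitive_block, and model.head are hypothetical placeholders for parts of your own network.

with torch.cuda.amp.autocast():
    hidden = model.encoder(x)            # runs in FP16 where eligible

    # Temporarily disable autocast for a numerically sensitive block
    with torch.cuda.amp.autocast(enabled=False):
        # Inputs may be FP16 here, so cast them explicitly to FP32
        hidden = sensitive_block(hidden.float())

    output = model.head(hidden)          # back under autocast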
By leveraging PyTorch's torch.cuda.amp, you can often achieve significant training speedups and memory savings with minimal code changes, allowing you to train larger models or iterate faster on your existing ones. This is a valuable tool for any TensorFlow developer transitioning to PyTorch and looking to optimize their training workflows.