This section provides a hands-on walkthrough for implementing Automatic Mixed Precision (AMP) training using `torch.cuda.amp`. As discussed earlier in the chapter, AMP allows you to leverage lower-precision floating-point formats (like `float16`) for certain operations, leading to significant speedups and a reduced GPU memory footprint, often with minimal impact on model accuracy. We will convert a standard PyTorch training loop to use AMP, demonstrating the necessary changes.

This practical exercise assumes you have access to a CUDA-enabled NVIDIA GPU with a compute capability of 7.0 or higher (required for efficient `float16` Tensor Core operations) and a reasonably up-to-date PyTorch installation (version 1.6 or later).
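If you are unsure whether your setup qualifies, a quick check along these lines can help. This is a small sketch; the printed messages are illustrative, and the 7.0 threshold simply reflects the requirement stated above.

```python
import torch

# Report the PyTorch version and the GPU's compute capability before running the examples.
print(f"PyTorch version: {torch.__version__}")
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    print(f"GPU: {torch.cuda.get_device_name(0)} (compute capability {major}.{minor})")
    if major < 7:
        print("Warning: compute capability below 7.0; float16 Tensor Core support is limited.")
else:
    print("No CUDA device found; torch.cuda.amp will not provide GPU speedups.")
```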
Let's start with a simplified standard training loop in full precision (`float32`). We'll use a basic convolutional network and random data for demonstration purposes.
```python
import torch
import torch.nn as nn
import torch.optim as optim
import time
import contextlib  # Used for the timing context manager

# 1. Define a simple model
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # Flatten features for the linear layer.
        # Two 2x2 poolings reduce 32x32 inputs to 8x8, giving 64 * 8 * 8 features.
        self.fc = nn.Linear(64 * 8 * 8, num_classes)

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # Flatten all dimensions except batch
        x = self.fc(x)
        return x

# 2. Setup: device, model, data, loss, optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Dummy data parameters
batch_size = 64
img_size = 32
num_batches = 100

# Simple timer context manager
@contextlib.contextmanager
def measure_time():
    start_time = time.time()
    yield
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # Wait for queued GPU work so the timing is accurate
    end_time = time.time()
    print(f"Elapsed time: {end_time - start_time:.4f} seconds")

# Simple memory usage reporter
def report_memory(stage=""):
    if torch.cuda.is_available():
        print(f"{stage} - Peak memory allocated: {torch.cuda.max_memory_allocated(device) / 1e6:.2f} MB")
        torch.cuda.reset_peak_memory_stats(device)  # Reset peak counter for the next measurement

# 3. Standard training loop (FP32)
print("\n--- Standard FP32 Training ---")
report_memory("Before training")

model.train()
with measure_time():
    for i in range(num_batches):
        # Generate dummy data on the fly
        inputs = torch.randn(batch_size, 3, img_size, img_size, device=device)
        labels = torch.randint(0, 10, (batch_size,), device=device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        if (i + 1) % 20 == 0:
            print(f"Batch [{i+1}/{num_batches}], Loss: {loss.item():.4f}")

report_memory("After training")
print("--- Standard FP32 Training Complete ---")
```
Run this code (if you have a suitable GPU). Take note of the reported elapsed time and peak memory usage. This serves as our baseline.
torch.cuda.amp
Now, let's modify the loop to incorporate AMP. This requires two primary components from `torch.cuda.amp`:
- `autocast`: A context manager that enables automatic casting of tensor operations to lower-precision types (`float16` by default on compatible GPUs) where it is beneficial and safe. Operations like convolutions and linear layers often see significant speedups, while others, such as reductions, may remain in `float32` for numerical stability.
- `GradScaler`: Since `float16` has a much smaller numerical range than `float32`, gradients computed during the backward pass can become very small (underflow) and get flushed to zero, hindering training. `GradScaler` helps prevent this by scaling the loss value upwards before the backward pass, which scales the resulting gradients into the representable range of `float16`. Before the optimizer updates the weights, `GradScaler` unscales the gradients back to their original values. If any non-finite gradients (NaN or Inf) are detected during unscaling (which can happen with unstable training or a high loss-scaling factor), the optimizer step for that batch is skipped. `GradScaler` also dynamically adjusts the scaling factor over time.

Here's how we modify the training loop:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import time
import contextlib  # Used for the timing context manager
from torch.cuda.amp import GradScaler, autocast

# --- Re-initialize the model and optimizer for a fair comparison ---
model = SimpleCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# --- Keep the criterion and dummy data parameters the same ---
criterion = nn.CrossEntropyLoss()
batch_size = 64
img_size = 32
num_batches = 100

# --- Use the same timer and memory reporter ---
@contextlib.contextmanager
def measure_time():
    start_time = time.time()
    yield
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # Wait for queued GPU work so the timing is accurate
    end_time = time.time()
    print(f"Elapsed time: {end_time - start_time:.4f} seconds")

def report_memory(stage=""):
    if torch.cuda.is_available():
        print(f"{stage} - Peak memory allocated: {torch.cuda.max_memory_allocated(device) / 1e6:.2f} MB")
        torch.cuda.reset_peak_memory_stats(device)

print("\n--- Mixed-Precision (AMP) Training ---")

# 1. Initialize GradScaler
scaler = GradScaler()

report_memory("Before training")

model.train()
with measure_time():
    for i in range(num_batches):
        inputs = torch.randn(batch_size, 3, img_size, img_size, device=device)
        labels = torch.randint(0, 10, (batch_size,), device=device)

        optimizer.zero_grad()

        # 2. Wrap the forward pass with autocast.
        # Operations inside this context run in lower precision (FP16) where supported.
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        # 3. Scale the loss before backward().
        # scaler.scale multiplies the loss by the current scale factor.
        scaler.scale(loss).backward()

        # 4. scaler.step() unscales gradients and calls optimizer.step().
        # It first unscales the .grad attributes of the optimizer's parameters;
        # if any gradient is not finite (NaN/Inf), optimizer.step() is skipped.
        scaler.step(optimizer)

        # 5. Update the scale factor for the next iteration.
        # Decreases the scale if NaNs/Infs were found, otherwise possibly increases it.
        scaler.update()

        if (i + 1) % 20 == 0:
            # Note: loss.item() is still the unscaled loss value
            print(f"Batch [{i+1}/{num_batches}], Loss: {loss.item():.4f}")

report_memory("After training")
print("--- Mixed-Precision (AMP) Training Complete ---")
```
If you run both versions on a compatible GPU (especially one with Tensor Cores, such as a V100, T4, A100, H100, or an RTX 20-series or later card), you should observe:

- Reduced memory usage: `float16` tensors require half the memory bandwidth and storage compared to `float32`. Activations stored for the backward pass also consume less memory, allowing for larger batch sizes or models.
- Faster training: because convolutions and linear layers execute in `float16` on Tensor Cores, the AMP loop typically reports a noticeably shorter elapsed time than the FP32 baseline.
- Minimal code changes: the only modifications were initializing a `GradScaler`, wrapping the forward pass with `autocast`, and changing the `backward()` and `optimizer.step()` calls to go through the `scaler`.
- Stable convergence: with `GradScaler`, the training process usually remains numerically stable, converging similarly to the FP32 baseline. You might observe slightly different loss values batch-to-batch due to the precision change, but the overall training dynamics are generally preserved.

Here's a conceptual visualization of typical results comparing FP32 and AMP training:
Illustrative comparison showing typical speedup and memory reduction when using AMP compared to standard FP32 training. Actual results will vary based on hardware and model.
A couple of additional points are worth keeping in mind. First, `bfloat16`: on newer hardware (e.g., Ampere-architecture GPUs like the A100, or newer TPUs), you might prefer `torch.bfloat16`. It has the same exponent range as `float32` but lower mantissa precision, which often makes it more resilient to underflow/overflow issues and can eliminate the need for `GradScaler`. You can enable it via `autocast(dtype=torch.bfloat16)`; check your hardware documentation for optimal settings. A sketch of this variant follows below.
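The loop below is a minimal, illustrative sketch of the `bfloat16` variant, not code from the text above. It assumes a GPU with `bfloat16` support and reuses the `model`, `criterion`, `optimizer`, and dummy-data parameters defined earlier; note that no `GradScaler` is used.

```python
from torch.cuda.amp import autocast

# Illustrative bfloat16 training loop: autocast to bfloat16, no loss scaling.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    model.train()
    for i in range(num_batches):
        inputs = torch.randn(batch_size, 3, img_size, img_size, device=device)
        labels = torch.randint(0, 10, (batch_size,), device=device)

        optimizer.zero_grad()
        with autocast(dtype=torch.bfloat16):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        # bfloat16 shares float32's exponent range, so gradients rarely underflow
        # and a plain backward()/step() is usually sufficient.
        loss.backward()
        optimizer.step()
```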
Second, gradient clipping: if you clip gradients, `scaler.unscale_(optimizer)` can be called before `torch.nn.utils.clip_grad_norm_` or `torch.nn.utils.clip_grad_value_`, so that clipping operates on the gradients' true (unscaled) values. A sketch of this pattern follows below.
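The helper below sketches one way to arrange this; the function name, argument list, and `max_norm` value are illustrative rather than taken from the text above.

```python
import torch
from torch.cuda.amp import autocast

def amp_step_with_clipping(model, inputs, labels, criterion, optimizer, scaler, max_norm=1.0):
    """One AMP training step with gradient-norm clipping (illustrative)."""
    optimizer.zero_grad()
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, labels)

    scaler.scale(loss).backward()

    # Unscale the gradients in-place so clipping sees their true magnitudes.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)

    # scaler.step() detects that these gradients were already unscaled and will
    # not unscale them again; it still skips the update if NaN/Inf is found.
    scaler.step(optimizer)
    scaler.update()
    return loss
```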
This practical exercise demonstrates how easily you can integrate Automatic Mixed Precision into your PyTorch training loops. By leveraging `torch.cuda.amp`, you can significantly accelerate training and reduce memory consumption, making it possible to train larger models or use larger batch sizes on existing hardware.