Deciding how often to save checkpoints and how to manage the resulting files involves balancing several competing factors: the risk of losing computation due to failures, the overhead introduced by the saving process itself, and the cost and availability of storage. Getting this balance right is important for efficient and reliable large-scale model training.
The core trade-off when choosing checkpoint frequency is between minimizing potential work lost upon failure and minimizing the overhead incurred during saving.
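One way to reason about this trade-off quantitatively is the well-known Young/Daly approximation, which suggests a checkpoint interval of roughly the square root of 2 × (time to write a checkpoint) × (mean time between failures). The snippet below is a minimal sketch of that calculation; the checkpoint_write_secs and mtbf_hours values are illustrative assumptions, not measurements from any particular cluster.

# Rough checkpoint-interval estimate using the Young/Daly approximation:
# interval ~ sqrt(2 * C * M), where C is the time to write one checkpoint
# and M is the mean time between failures (MTBF) of the whole job.
import math

checkpoint_write_secs = 120   # assumed time to write one checkpoint
mtbf_hours = 24               # assumed mean time between job failures

mtbf_secs = mtbf_hours * 3600
optimal_interval_secs = math.sqrt(2 * checkpoint_write_secs * mtbf_secs)

print(f"Suggested checkpoint interval: ~{optimal_interval_secs / 60:.0f} minutes")
# With these assumptions: sqrt(2 * 120 * 86400) is roughly 4550 s, about 76 minutes.

Treat the result as a starting point rather than a rule; in practice you still adjust around it based on measured overhead and observed failure rates.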
Several factors influence the optimal frequency for your specific setup: how often failures occur in your environment, how long a single checkpoint takes to write (and how much it slows training), the size of each checkpoint, and how much recomputed work you can tolerate after a restart.
Common strategies for triggering checkpoints include:

Every N training steps: saving at a fixed step interval provides predictable intervals in terms of training progress.
# Example within a PyTorch training loop
import torch
import os

SAVE_EVERY_N_STEPS = 1000
checkpoint_dir = "/path/to/checkpoints"
global_step = 0  # Assuming this counter increments each training step

# Inside your training loop...
# optimizer.step()
# scheduler.step()
global_step += 1

if global_step % SAVE_EVERY_N_STEPS == 0:
    # Construct checkpoint state dictionary
    state = {
        'step': global_step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Add scheduler state, random states, etc.
    }
    checkpoint_path = os.path.join(checkpoint_dir, f"step_{global_step}.pt")
    print(f"Saving checkpoint to {checkpoint_path} at step {global_step}")
    # In a real scenario, use a robust saving function, potentially async
    # torch.save(state, checkpoint_path)
    # Consider adding storage management logic here (see below)
    pass  # Placeholder for actual saving and management
Every H hours: saving on a fixed wall-clock interval is simple to implement but less predictable regarding the amount of training progress between checkpoints, as iteration speed can vary.

Hybrid: save every N steps or every H hours, whichever occurs first. This offers a safety net against both very slow progress and long periods without saves (a minimal sketch of such a trigger follows this list).

Experimentation is often needed. Start with a reasonable frequency (e.g., every 1000-5000 steps or every 1-2 hours) and adjust based on observed stability and measured overhead.
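As referenced above, here is a minimal sketch of a hybrid trigger that fires on either condition. The names should_checkpoint, SAVE_EVERY_SECONDS, last_save_step, last_save_time, and save_checkpoint are illustrative, not part of any particular framework.

# Minimal sketch of a hybrid checkpoint trigger: save every N steps OR
# every H hours, whichever comes first. Names here are illustrative.
import time

SAVE_EVERY_N_STEPS = 1000
SAVE_EVERY_SECONDS = 2 * 3600  # e.g., every 2 hours

last_save_step = 0
last_save_time = time.monotonic()

def should_checkpoint(global_step):
    """Returns True if either the step-based or time-based condition is met."""
    step_due = (global_step - last_save_step) >= SAVE_EVERY_N_STEPS
    time_due = (time.monotonic() - last_save_time) >= SAVE_EVERY_SECONDS
    return step_due or time_due

# Inside the training loop, after optimizer.step():
# if should_checkpoint(global_step):
#     save_checkpoint(...)            # your saving routine
#     last_save_step = global_step
#     last_save_time = time.monotonic()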
LLM checkpoints, containing model weights, optimizer states, and potentially gradient statistics (especially with ZeRO Stage 3), can be very large, ranging from gigabytes to terabytes depending on the model size and distributed training strategy. Storing every single checkpoint created during a long training run is usually impractical due to storage costs and capacity limits.
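To get a feel for these sizes, a rough back-of-the-envelope estimate for a standard mixed-precision setup with Adam (without optimizer state sharding) is about 2 bytes per parameter for fp16 weights plus roughly 12 bytes per parameter for fp32 master weights and the two Adam moments. The numbers below are illustrative assumptions for a hypothetical 7B-parameter model, not a measurement.

# Back-of-the-envelope checkpoint size estimate (assumptions, not measurements):
# fp16 weights (2 B/param) + fp32 master weights (4 B/param)
# + Adam first and second moments in fp32 (4 + 4 B/param) = ~14 B/param.
num_params = 7e9                  # hypothetical 7B-parameter model
bytes_per_param = 2 + 4 + 4 + 4   # weights + master copy + Adam moments

total_bytes = num_params * bytes_per_param
print(f"Estimated checkpoint size: ~{total_bytes / 1e9:.0f} GB")  # ~98 GB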
Storage Location Trade-offs: fast local or shared cluster filesystems keep save overhead low and suit frequent checkpoints, but node-local disks can be lost along with a failed node, while cheaper, durable object storage is better for long-term retention at the cost of higher transfer latency. A common pattern is to write to fast storage first and copy selected checkpoints to cheaper storage afterwards.
Retention Policies:
Because storing all checkpoints is infeasible, you need a strategy to decide which ones to keep and which to discard.
Keep the last K: retain the K most recent checkpoints. When saving checkpoint N, delete checkpoint N-K. Choose K appropriately (e.g., K=3 or K=5). This bounds storage use but is still primarily time-based, not necessarily performance-based.

Keep the top M based on validation: monitor a validation metric (e.g., perplexity) periodically and save the checkpoints associated with the M best validation scores observed so far. This often complements keeping the latest checkpoint(s); a common combined policy retains the last K checkpoints plus the M best-performing checkpoints based on validation metrics (a sketch of the top-M bookkeeping appears after the code example below).

Implementing retention often involves listing existing checkpoints in the storage location, sorting them based on the chosen criteria (step number, timestamp, validation score), and deleting the ones that fall outside the retention window.
# Example logic for keeping the last K checkpoints
import os
import glob
import re

checkpoint_dir = "/path/to/checkpoints"
KEEP_LAST_K = 3

def manage_checkpoints(checkpoint_dir, keep_last_k):
    """Removes older checkpoints, keeping only the specified number."""
    checkpoints = glob.glob(os.path.join(checkpoint_dir, "step_*.pt"))

    # Extract step numbers, handling potential non-matching files
    steps = []
    for ckpt in checkpoints:
        match = re.search(r"step_(\d+)\.pt$", os.path.basename(ckpt))
        if match:
            steps.append((int(match.group(1)), ckpt))

    # Sort by step number (descending)
    steps.sort(key=lambda x: x[0], reverse=True)

    # Identify checkpoints to delete
    if len(steps) > keep_last_k:
        checkpoints_to_delete = [ckpt_path for step, ckpt_path in steps[keep_last_k:]]
        print(f"Found {len(steps)} checkpoints. Deleting {len(checkpoints_to_delete)} older checkpoints.")
        for ckpt_path in checkpoints_to_delete:
            try:
                os.remove(ckpt_path)
                print(f"Deleted {ckpt_path}")
            except OSError as e:
                print(f"Error deleting {ckpt_path}: {e}")

# Call this function after successfully saving a new checkpoint
# manage_checkpoints(checkpoint_dir, KEEP_LAST_K)
Diagram: flow demonstrating a checkpoint retention policy check after saving a new checkpoint.
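As referenced earlier, the keep-last-K logic can be complemented by tracking the best checkpoints by validation score. The sketch below assumes a validation metric is recorded each time a checkpoint is saved; the names KEEP_TOP_M, best_checkpoints, update_best_checkpoints, and last_k_paths are illustrative choices, not a standard API.

# Sketch: additionally keep the M best checkpoints by validation loss.
# The bookkeeping structure and names here are illustrative assumptions.
KEEP_TOP_M = 2

# Maps checkpoint path -> validation loss (lower is better), updated each
# time validation runs right after a checkpoint is saved.
best_checkpoints = {}

def update_best_checkpoints(ckpt_path, validation_loss, keep_top_m=KEEP_TOP_M):
    """Records a new checkpoint's score and returns paths that fell out of the top M."""
    best_checkpoints[ckpt_path] = validation_loss
    # Sort by validation loss (ascending) and keep only the top M entries.
    ranked = sorted(best_checkpoints.items(), key=lambda item: item[1])
    to_drop = [path for path, _ in ranked[keep_top_m:]]
    for path in to_drop:
        del best_checkpoints[path]
    return to_drop  # caller may delete these, unless protected by keep-last-K

# Usage after saving and evaluating a checkpoint:
# deletable = update_best_checkpoints(checkpoint_path, val_loss)
# for path in deletable:
#     if path not in last_k_paths:  # don't delete checkpoints kept for recency
#         os.remove(path)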
Ultimately, the choice of frequency and storage management depends on a careful assessment of your training environment's stability, performance characteristics, computational costs, and tolerance for potential data loss. Using robust, potentially asynchronous saving mechanisms combined with a well-defined retention policy based on both recency and performance is a standard practice for large-scale LLM training.
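As one illustration of the asynchronous saving mentioned above, a common pattern is to take a consistent CPU copy of the state synchronously and hand the slow disk write to a background thread so the training loop can continue. The functions async_save and _to_cpu below are a minimal sketch under that assumption, not a production implementation; some training frameworks provide built-in asynchronous checkpointing that handles this more robustly.

# Minimal sketch of asynchronous checkpoint saving: snapshot state to CPU
# synchronously, then let a background thread perform the disk write.
import threading
import torch

def _to_cpu(obj):
    """Recursively detach tensors and copy them to CPU so the snapshot is stable."""
    if torch.is_tensor(obj):
        return obj.detach().to("cpu", copy=True)
    if isinstance(obj, dict):
        return {k: _to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(_to_cpu(v) for v in obj)
    return obj

def async_save(state, checkpoint_path):
    """Copy `state` to CPU, then write it to disk in a background thread."""
    cpu_state = _to_cpu(state)  # synchronous copy -> consistent snapshot
    thread = threading.Thread(
        target=torch.save, args=(cpu_state, checkpoint_path), daemon=True
    )
    thread.start()
    return thread  # keep a reference and join() before the job exits

# Usage (with `state` built as in the earlier step-based example):
# save_thread = async_save(state, checkpoint_path)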