Training large language models (LLMs) is not a quick process. Unlike smaller models that might train in hours, pre-training a state-of-the-art LLM often requires operating large clusters of accelerators (GPUs or TPUs) continuously for days, weeks, or even months. Consider a hypothetical training run consuming 1024 high-end GPUs for 30 days. That is 1024 × 30 × 24 = 737,280 accelerator-hours. Such prolonged, large-scale operations significantly increase the exposure to potential interruptions.
The reality of large-scale distributed systems is that failures happen. Over extended periods, the probability of encountering an issue approaches certainty. These interruptions can stem from various sources: hardware faults such as failed GPUs or memory errors, network outages, software bugs in the training stack, power or cooling incidents, and planned or unplanned cluster maintenance.
Without a mechanism to save progress periodically, any such interruption forces the entire training process to restart from the very beginning. The consequences are severe: days or weeks of accelerator time are thrown away, the associated compute cost must be paid again, and project timelines slip accordingly.
This is where checkpointing becomes indispensable. Checkpointing is the practice of periodically saving the complete state of the training job to persistent storage (like a distributed file system or cloud storage). This state includes not just the model's parameters (weights), but also everything needed to resume training exactly where it left off, such as the optimizer's state, the learning rate scheduler's state, the current training iteration or epoch number, and potentially the state of the data loaders.
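To make the idea of "complete state" concrete, here is a minimal sketch of what a fully resumable checkpoint might bundle together. The build_checkpoint_state helper name is hypothetical, and model, optimizer, and scheduler are assumed to come from your own training setup; data loader state is often handled separately and is omitted here.

# Hypothetical sketch: what a fully resumable checkpoint might contain.
import random
import numpy as np
import torch

def build_checkpoint_state(model, optimizer, scheduler, step):
    # model, optimizer, and scheduler are assumed to be created by your training setup.
    return {
        'step': step,                                    # current training iteration
        'model_state_dict': model.state_dict(),          # model parameters (weights)
        'optimizer_state_dict': optimizer.state_dict(),  # e.g. AdamW moment estimates
        'scheduler_state_dict': scheduler.state_dict(),  # learning rate schedule position
        'torch_rng_state': torch.get_rng_state(),        # CPU RNG state
        'cuda_rng_state': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
        'python_rng_state': random.getstate(),           # Python RNG state
        'numpy_rng_state': np.random.get_state(),        # NumPy RNG state (e.g. data augmentation)
    }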
Imagine a simplified training loop:
# Example WITHOUT checkpointing
import torch
import torch.optim as optim

# MyLargeModel, data_loader, calculate_loss, some_failure_condition,
# TOTAL_TRAINING_STEPS, and LOG_INTERVAL are assumed to be defined elsewhere.
model = MyLargeModel()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

# Assume data_loader provides enough batches for the full run.
data_iter = iter(data_loader)

for step in range(TOTAL_TRAINING_STEPS):
    # --- Potential failure point ---
    if some_failure_condition():
        print("Failure occurred! Restarting from step 0.")
        # All progress up to 'step' is lost.
        # Need to re-initialize model, optimizer, and start from step 0.
        raise SystemExit("Training failed")

    batch = next(data_iter)
    outputs = model(batch['input_ids'])
    loss = calculate_loss(outputs, batch['labels'])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % LOG_INTERVAL == 0:
        print(f"Step: {step}, Loss: {loss.item()}")

print("Training finished successfully!")  # Only reached if no failures occur
If a failure occurs at step 500,000 in a million-step training run, all the work done for those 500,000 steps is wasted. Checkpointing introduces save points:
# Example WITH checkpointing
import os
import torch
import torch.optim as optim

CHECKPOINT_DIR = "/path/to/persistent/storage/checkpoints"
CHECKPOINT_FREQ = 1000  # Save every 1000 steps

def save_checkpoint(model, optimizer, step, filename="checkpoint.pt"):
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"step_{step}_{filename}")
    state = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Add scheduler state, random states, etc.
    }
    torch.save(state, checkpoint_path)
    print(f"Saved checkpoint at step {step} to {checkpoint_path}")

def load_checkpoint(model, optimizer):
    # Logic to find the latest checkpoint
    latest_checkpoint_path = find_latest_checkpoint(CHECKPOINT_DIR)
    if latest_checkpoint_path:
        checkpoint = torch.load(latest_checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_step = checkpoint['step'] + 1
        print(f"Resumed from checkpoint; continuing at step {start_step}")
        return start_step
    else:
        print("No checkpoint found, starting from scratch.")
        return 0

model = MyLargeModel()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

start_step = load_checkpoint(model, optimizer)  # Try to resume
# Note: the data loader position is not restored here; a real setup would also
# checkpoint or fast-forward the data pipeline.
data_iter = iter(data_loader)

for step in range(start_step, TOTAL_TRAINING_STEPS):
    # --- Potential failure point ---
    try:
        batch = next(data_iter)
        outputs = model(batch['input_ids'])
        loss = calculate_loss(outputs, batch['labels'])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % LOG_INTERVAL == 0:
            print(f"Step: {step}, Loss: {loss.item()}")

        # --- Save progress periodically ---
        if step % CHECKPOINT_FREQ == 0 and step > 0:
            save_checkpoint(model, optimizer, step)

    except Exception as e:
        print(f"Failure occurred at step {step}: {e}")
        print("Exiting. Rerun script to resume from the latest checkpoint.")
        raise SystemExit("Training interrupted")

print("Training finished successfully!")
In this revised loop, if a failure occurs, the load_checkpoint function (whose implementation details we'll discuss later) can restore the state from the most recently saved checkpoint, allowing training to resume from, for example, step 500,001 instead of step 0. This drastically reduces the amount of wasted computation.
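Pending that discussion, one possible sketch of the find_latest_checkpoint helper assumed above might look like the following; it relies on the step_<N>_checkpoint.pt naming pattern produced by save_checkpoint.

# A minimal sketch of the find_latest_checkpoint helper referenced above.
# Assumes the step_<N>_checkpoint.pt naming pattern used by save_checkpoint.
import os
import re

def find_latest_checkpoint(checkpoint_dir):
    if not os.path.isdir(checkpoint_dir):
        return None
    latest_step = -1
    latest_path = None
    for name in os.listdir(checkpoint_dir):
        match = re.match(r"step_(\d+)_checkpoint\.pt$", name)
        if match:
            step = int(match.group(1))
            if step > latest_step:
                latest_step = step
                latest_path = os.path.join(checkpoint_dir, name)
    return latest_path

Parsing the step number out of the filename (rather than sorting names lexicographically or by modification time) avoids picking the wrong file when step counts have different digit lengths or when checkpoints are copied between storage systems.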
While the primary motivation for checkpointing is fault tolerance against unexpected failures, it also provides operational flexibility. Checkpoints allow for planned shutdowns, perhaps for scheduled cluster maintenance or to reconfigure the training job. They also enable resuming training on different hardware or exploring different training paths by branching off from an intermediate state.
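As a rough illustration of branching, the hypothetical snippet below loads only the model weights from an intermediate checkpoint and starts a fresh optimizer and output directory, leaving the original run's checkpoints untouched. The paths and hyperparameters are placeholders, not values from any particular run.

# Hypothetical sketch: branching a new run off an intermediate checkpoint.
import torch
import torch.optim as optim

# Placeholder path to an existing checkpoint from the original run.
branch_source = "/path/to/persistent/storage/checkpoints/step_500000_checkpoint.pt"
checkpoint = torch.load(branch_source, map_location="cpu")

model = MyLargeModel()
model.load_state_dict(checkpoint['model_state_dict'])

# Start a fresh optimizer with a different learning rate to explore a new training path.
optimizer = optim.AdamW(model.parameters(), lr=5e-5)
start_step = checkpoint['step'] + 1

# Write subsequent checkpoints to a separate directory so the original run is preserved.
CHECKPOINT_DIR = "/path/to/persistent/storage/checkpoints-branch-a"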
Given the significant investments in time and resources required for LLM training, robust checkpointing is not merely a convenience; it is a fundamental requirement for successfully completing these demanding projects. The subsequent sections will detail the components that need saving, strategies for managing checkpoints in distributed settings, and best practices for implementation.