When implementing checkpointing for long-running training jobs, a primary consideration is how the saving process interacts with the ongoing training computation. Does training pause completely while the checkpoint is written, or can it continue concurrently? This leads to two main approaches: synchronous and asynchronous checkpointing. The choice between them involves balancing simplicity, consistency, and performance overhead.
Synchronous checkpointing is the most straightforward approach. When a checkpoint trigger occurs (e.g., after a certain number of training steps or a set amount of elapsed time), the training process pauses all computation. It then gathers the necessary state components: model parameters, optimizer state, learning rate scheduler state, the current epoch or step number, and potentially the state of the data loader iterators. Once collected, this state is serialized and written to persistent storage (such as a distributed file system or cloud object storage). Only after the write operation completes successfully does the training process resume computation.
In a distributed training setting, synchronous checkpointing requires coordination among all participating workers. Typically, a barrier synchronization is used before saving to ensure all workers have reached the same point. One worker (often rank 0) might be designated to collect state from the others, or each worker might save its own shard of the state. A second barrier after saving ensures all workers wait until the checkpoint is fully written before proceeding.
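As a minimal sketch of the per-shard option (the helper name, argument list, and directory layout here are illustrative, not a particular library's API), each rank can write its own shard to a rank-specific file, with barriers on either side of the save:

import os
import torch
import torch.distributed as dist

def save_sharded_checkpoint(rank, world_size, model, optimizer, step, ckpt_dir):
    """Each rank writes its own shard of the state; barriers keep ranks aligned."""
    os.makedirs(ckpt_dir, exist_ok=True)
    if world_size > 1:
        dist.barrier()  # everyone pauses at the same step
    shard = {
        'step': step,
        # Under a sharded setup (e.g., FSDP with a sharded state-dict config),
        # each rank's state_dict() holds only its portion of the parameters.
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(shard, os.path.join(ckpt_dir, f"shard_rank{rank}.pt"))
    if world_size > 1:
        dist.barrier()  # resume only after every shard is on disk

At restore time each rank loads back its own shard, so no single worker ever has to hold the full consolidated state in memory.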
Advantages:
- Simplicity: the logic is easy to implement and reason about, since nothing else runs while the checkpoint is written.
- Consistency: the saved state corresponds to a single, well-defined training step on every worker.
- Straightforward error handling: if the save fails, the error surfaces immediately in the training process.
Disadvantages:
- Training stalls for the entire duration of the save, leaving accelerators idle.
- The pause grows with model size and with storage bandwidth, and in large distributed jobs that idle time multiplies across many expensive GPUs.
- Frequent checkpointing directly reduces training throughput, which discourages saving as often as reliability might warrant.
Here's a representation of a synchronous checkpoint within a training loop using PyTorch-like syntax in a distributed context:
import torch

# Assume setup with torch.distributed initialized
def save_synchronous_checkpoint(
    rank, world_size, model, optimizer, scheduler, step, path
):
    # Ensure all processes reach this point before saving
    if world_size > 1:
        torch.distributed.barrier()

    if rank == 0:  # Rank 0 handles saving the consolidated state
        print(
            f"Rank {rank}: Starting synchronous checkpoint save at step {step}..."
        )
        # In a real scenario, state might be gathered from other ranks,
        # or each rank saves its shard (e.g., with DeepSpeed/FSDP helpers)
        state = {
            'step': step,
            'model_state_dict': model.state_dict(),
            # Or model.module.state_dict() if the model is wrapped (e.g., DDP)
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            # Potentially add dataloader state, RNG states, etc.
        }
        torch.save(state, path)
        print(f"Rank {rank}: Finished synchronous checkpoint save to {path}.")
    else:
        # Other ranks wait for rank 0 to finish saving
        pass

    # Ensure saving is complete on rank 0 before anyone proceeds
    if world_size > 1:
        torch.distributed.barrier()


# --- Inside the training loop ---
model.train()
for step, batch in enumerate(data_loader):
    # Forward pass, backward pass, optimizer step...
    outputs = model(batch['input_ids'])
    loss = calculate_loss(outputs, batch['labels'])
    loss.backward()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

    # Checkpoint periodically
    if step % checkpoint_interval == 0 and step > 0:
        checkpoint_path = f"/path/to/checkpoints/step_{step}.pt"
        # --- Blocking Save Operation ---
        save_synchronous_checkpoint(
            rank,
            world_size,
            model,
            optimizer,
            scheduler,
            step,
            checkpoint_path,
        )
        # --- Training resumes only after save is complete ---

    # ... rest of the loop (logging, evaluation, etc.)
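For completeness, resuming later is the mirror image of saving. A minimal sketch, assuming the same state layout and variable names as above:

checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_step = checkpoint['step'] + 1  # continue from the step after the save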
The diagram below illustrates the blocking nature of synchronous checkpointing.
Training halts completely while the checkpoint is saved synchronously across all ranks.
Asynchronous checkpointing aims to mitigate the performance overhead of synchronous saves. The core idea is to decouple the slow I/O of writing the checkpoint to storage from the main training loop.
When a checkpoint trigger occurs, the main training process initiates the save operation but does not wait for it to complete. This is often achieved by:
- Copying the relevant state (typically to CPU memory) and handing it to a background thread that performs the serialization and write.
- Spawning a separate saver process that receives the copied state and writes it, isolating the I/O from the interpreter running the training loop.
- Using framework-provided asynchronous checkpointing utilities, where available, which implement this copy-and-offload pattern internally.
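As a minimal sketch of the separate-process option (the queue protocol and paths here are illustrative), a saver process can receive already-copied CPU state over a queue and write it while training continues:

import torch
import torch.multiprocessing as mp

def saver_process(queue):
    """Runs in a separate process: pull (state, path) items and write them."""
    while True:
        item = queue.get()
        if item is None:  # sentinel value signals shutdown
            break
        state, path = item
        torch.save(state, path)

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    queue = ctx.Queue()
    saver = ctx.Process(target=saver_process, args=(queue,), daemon=True)
    saver.start()

    # Inside the training loop: snapshot to CPU, enqueue, keep training
    # state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    # queue.put((state, f"/tmp/ckpt/step_{step}.pt"))

    queue.put(None)  # after training: tell the saver to exit
    saver.join()

Because the tensors are cloned to CPU before being enqueued, the training process can keep updating the GPU parameters without affecting what the saver writes.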
Advantages:
- The training loop only pauses for the short time needed to snapshot the state, not for the full write, so GPU utilization stays high.
- Checkpoint frequency can be increased without a proportional hit to throughput.
- Slow or variable storage bandwidth has far less impact on step time.
Disadvantages:
- Higher implementation complexity: background threads or processes, state copies, and coordination with the next checkpoint must all be managed correctly.
- Extra memory is needed to hold the copied state while the save is in flight.
- Failures during the background write are harder to detect and handle, since they occur away from the main loop.
- If the state is not copied carefully before training continues, the saved checkpoint can mix values from different steps.
Implementing asynchronous checkpointing often involves using threading or multiprocessing libraries.
import os
import threading
import time

import torch

# Assume torch.distributed is initialized (rank, world_size)
# Placeholder model, optimizer, and scheduler for the example
class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

model = DummyModel()
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=10, gamma=0.1
)

# Global variable to track the background saving thread
checkpoint_thread = None

def background_save_task(state, path):
    """Function executed by the background thread."""
    print(f"Background Saver: Starting async save to {path}...")
    try:
        time.sleep(5)  # Simulate slow I/O
        # Create directory if it doesn't exist
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save(state, path)
        print(f"Background Saver: Finished async save to {path}.")
    except Exception as e:
        print(f"Background Saver: Error during checkpointing: {e}")

def save_asynchronous_checkpoint(
    rank, world_size, model, optimizer, scheduler, step, path
):
    global checkpoint_thread

    # Ensure the previous background save is complete before starting a new one
    if checkpoint_thread is not None and checkpoint_thread.is_alive():
        print(f"Rank {rank}: Waiting for previous async checkpoint to finish...")
        checkpoint_thread.join()

    if rank == 0:  # Rank 0 initiates and manages the save thread
        print(
            f"Rank {rank}: Initiating asynchronous checkpoint save at step {step}..."
        )
        # --- Quickly snapshot state ---
        # state_dict() returns references to the live parameter tensors, so
        # clone them (to CPU) before training continues and mutates them.
        state = {
            'step': step,
            'model_state_dict': {
                k: v.detach().cpu().clone()
                for k, v in model.state_dict().items()
            },
            # In a real job, large optimizer state (e.g., Adam moments)
            # should be copied the same way.
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        }
        # --- Launch background thread ---
        checkpoint_thread = threading.Thread(
            target=background_save_task, args=(state, path)
        )
        checkpoint_thread.start()
        print(f"Rank {rank}: Background save launched. Training continues.")
    else:
        # Other ranks generally continue training. A barrier *before* the
        # snapshot is only needed if the copied state must be strictly
        # consistent across ranks.
        pass

    # --- Training continues immediately on all ranks (no barrier here) ---


# --- Inside the training loop ---
step = 0
checkpoint_interval = 5  # Example: checkpoint every 5 steps
max_steps = 20

print("Starting Mock Training Loop...")
while step < max_steps:
    step += 1
    print(f"Main Loop: Training Step {step}")
    # Simulate training work: model(...), loss.backward(), optimizer.step()...
    time.sleep(0.5)

    if step % checkpoint_interval == 0:
        # --- Non-Blocking Save Initiation ---
        checkpoint_path = f"/tmp/async_checkpoints/step_{step}.pt"
        # Assuming rank 0, world_size 1 for simplicity
        save_asynchronous_checkpoint(
            0, 1, model, optimizer, scheduler, step, checkpoint_path
        )

# Wait for the final checkpoint thread to finish after the loop exits
if checkpoint_thread is not None and checkpoint_thread.is_alive():
    print("Main Loop: Waiting for final checkpoint to complete...")
    checkpoint_thread.join()

print("Mock Training Loop Finished.")
The diagram below illustrates how asynchronous checkpointing overlaps I/O with computation.
The main training thread only pauses briefly to copy state, then continues computation while the actual save happens in a background thread.
The best approach depends on the specific training setup and priorities:
- Synchronous checkpointing is a good fit when checkpoints are small or infrequent, storage is fast, or simplicity and guaranteed consistency matter more than a brief pause.
- Asynchronous checkpointing pays off when the save time is significant relative to the step time, when accelerator idle time is expensive, or when checkpoints must be taken frequently.
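One practical way to make this decision is to measure how long a blocking save actually takes relative to a training step; the sketch below assumes a state dictionary and path like those used earlier:

import time
import torch

def measure_blocking_save(state, path):
    """Time a synchronous torch.save to estimate the per-checkpoint stall."""
    start = time.perf_counter()
    torch.save(state, path)
    return time.perf_counter() - start

# Example: if a save takes 90 s and a step takes 2 s, checkpointing every
# 1000 steps costs roughly 90 / (1000 * 2) = 4.5% of wall-clock time when
# done synchronously, overhead that asynchronous saving would largely hide.
# save_seconds = measure_blocking_save(state, "/tmp/ckpt_probe.pt")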
In practice, for very large models where checkpointing time can be substantial (minutes or longer), asynchronous checkpointing is often favored to maximize the utilization of expensive GPU resources, despite the added implementation complexity. Careful implementation and testing are necessary to ensure the reliability of the asynchronous saving process.
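Regardless of which approach is used, a common safeguard is to write each checkpoint to a temporary file and atomically rename it only after the write succeeds, so a crash mid-save never leaves a corrupted file in place of the latest usable checkpoint. This is a general pattern rather than a feature of a specific library; a minimal sketch:

import os
import torch

def save_atomically(state, final_path):
    """Write to a temporary file, then atomically replace the target path."""
    tmp_path = final_path + ".tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, final_path)  # atomic rename on the same filesystem

# Usage inside either the synchronous save or the background saver task:
# save_atomically(state, f"/path/to/checkpoints/step_{step}.pt")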