Full parameter fine-tuning often involves training for extended periods, potentially hours, days, or even weeks, consuming significant computational resources. Interruptions during these long runs are almost inevitable, whether due to hardware failures, cluster job preemption, network issues, or simply the need to pause and restart. Without a mechanism to save and restore progress, such interruptions could lead to a complete loss of the training performed up to that point, wasting valuable time and compute budget. Checkpointing provides this essential safety net.
Effective checkpointing involves periodically saving the complete state of the training process to persistent storage. This allows you to resume training precisely from where it left off, minimizing wasted effort.
To resume training seamlessly, you need to save more than just the model's parameters (θ). A comprehensive checkpoint should capture:
- Model weights: the current parameter values, obtained via model.state_dict(). Saving this ensures you retain the learned knowledge.
- Optimizer state: optimizers such as AdamW track per-parameter statistics (for example, momentum and variance estimates); saving optimizer.state_dict() is necessary to preserve this momentum.
- Learning rate scheduler state: the scheduler's current position in the schedule, captured by scheduler.state_dict(). Failing to do so would reset the schedule, leading to incorrect learning rates upon resumption.
- Training progress counters: the current epoch and global step, so the loop can resume from the right point.
- Random number generator (RNG) states: for exact reproducibility, saving the RNG states (torch.get_rng_state(), numpy.random.get_state(), random.getstate()) can be beneficial. Restoring these ensures the data pipeline behaves identically after resumption.

Checkpointing logic is typically integrated directly into the training loop. You need to decide on a checkpointing frequency. Common strategies include:

- Step-based: save every N optimizer steps (for example, every 500 or 1,000 steps).
- Epoch-based: save at the end of every epoch.
- Time-based: save at fixed wall-clock intervals, which can be useful when step durations vary.
To manage storage, you might keep only the latest checkpoint, the best checkpoint, or a rolling window of the last few checkpoints.
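As an illustration of how the frequency and retention decisions might sit inside the loop, here is a sketch only: save_checkpoint is a hypothetical wrapper around the saving logic shown in the next example, manage_checkpoints is sketched after it, and save_every_steps and checkpoint_dir are assumed configuration values.

# Sketch: step- and epoch-based checkpoint frequency inside the training loop.
# Assume model, optimizer, scheduler, train_dataloader, and num_epochs are defined,
# and that save_checkpoint(...) wraps the torch.save logic shown in the next example.
save_every_steps = 500   # assumed configuration value
checkpoint_dir = "./"    # assumed checkpoint directory
global_step = 0

for epoch in range(num_epochs):
    for batch in train_dataloader:
        # ... forward pass, loss.backward(), optimizer.step(), scheduler.step() ...
        global_step += 1
        if global_step % save_every_steps == 0:  # step-based strategy
            save_checkpoint(model, optimizer, scheduler, epoch, global_step)
            manage_checkpoints(checkpoint_dir, keep_last=3)  # rolling window of recent checkpoints
    # Epoch-based strategy: also save at the end of every epoch
    save_checkpoint(model, optimizer, scheduler, epoch, global_step)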
Here's a conceptual PyTorch-like example of saving a checkpoint:
import os
import torch

# Assume model, optimizer, scheduler, epoch, global_step are defined
checkpoint_path = f"./model_checkpoint_epoch_{epoch}_step_{global_step}.pt"

save_state = {
    'epoch': epoch,
    'global_step': global_step,
    'model_state_dict': model.state_dict(),  # Or model.module.state_dict() if using DDP
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    # Optionally add RNG states, config, loss history etc.
    # 'rng_state': torch.get_rng_state(),
    # 'config': training_args,
}

# Best practice: save to temporary file then rename for atomicity
temp_path = checkpoint_path + ".tmp"
torch.save(save_state, temp_path)
os.replace(temp_path, checkpoint_path)  # Atomic rename/replace on the same filesystem
print(f"Checkpoint saved to {checkpoint_path}")
# Optionally, manage older checkpoints (e.g., keep only last 3)
# manage_checkpoints(checkpoint_dir, keep_last=3)
Using a temporary file and renaming ensures that if the saving process is interrupted, you don't end up with a corrupted partial checkpoint file.
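The manage_checkpoints call left commented out above is not provided by PyTorch; one possible sketch, assuming checkpoints follow the model_checkpoint_epoch_{epoch}_step_{step}.pt naming used here, keeps only the most recent files by global step:

import os
import re

def manage_checkpoints(checkpoint_dir: str, keep_last: int = 3) -> None:
    """Keep only the `keep_last` most recent checkpoints (by global step) in `checkpoint_dir`."""
    # Assumes the naming pattern used in the saving example above
    pattern = re.compile(r"model_checkpoint_epoch_(\d+)_step_(\d+)\.pt$")
    found = []
    for name in os.listdir(checkpoint_dir):
        match = pattern.match(name)
        if match:
            found.append((int(match.group(2)), os.path.join(checkpoint_dir, name)))
    found.sort()  # oldest (smallest global step) first
    for _, path in found[:-keep_last]:
        os.remove(path)  # delete everything except the newest keep_last files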
When starting a training run, your script should check if a valid checkpoint exists. If so, it should load the saved state before commencing training.
Here's how you might load the state:
import torch

# Assume model, optimizer, scheduler are initialized
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Determine the checkpoint path to load from (e.g., the latest one)
checkpoint_path = find_latest_checkpoint("./")  # Function to find the checkpoint file (see sketch below)

if checkpoint_path:
    print(f"Resuming training from checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')  # Load to CPU first
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch']
    global_step = checkpoint['global_step']
    # Optionally restore RNG states, etc.
    # torch.set_rng_state(checkpoint['rng_state'])

    # Important: move optimizer state tensors to the correct device(s) if using GPU.
    # This step might be needed depending on how the optimizer state was saved
    # and because the checkpoint was loaded to CPU initially.
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.to(device)
    print(f"Resumed from epoch {start_epoch}, global step {global_step}")
else:
    print("No checkpoint found, starting training from scratch.")
    start_epoch = 0
    global_step = 0
# --- Start training loop from start_epoch, tracking global_step ---
Loading the optimizer and scheduler states is essential for maintaining the training dynamics. Loading the checkpoint to the CPU first (map_location='cpu') before loading into the model/optimizer can prevent GPU memory issues if the checkpoint is large. If using GPUs, ensure the optimizer state tensors are moved to the correct device after loading.
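The find_latest_checkpoint helper in the loading example is likewise a stand-in rather than a PyTorch API; a minimal version, again assuming the naming pattern from the saving example, could simply pick the file with the highest global step:

import os
import re
from typing import Optional

def find_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
    """Return the checkpoint path with the highest global step, or None if none exist."""
    pattern = re.compile(r"model_checkpoint_epoch_(\d+)_step_(\d+)\.pt$")
    latest_step, latest_path = -1, None
    for name in os.listdir(checkpoint_dir):
        match = pattern.match(name)
        if match and int(match.group(2)) > latest_step:
            latest_step = int(match.group(2))
            latest_path = os.path.join(checkpoint_dir, name)
    return latest_path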
When using distributed training frameworks like PyTorch's Distributed Data Parallel (DDP), checkpointing requires slight modifications:

- Save from a single process (typically rank 0) so that multiple workers do not write the same file, and access the underlying model's parameters through the wrapper (use model.module.state_dict() for DDP).
- Have every process load the checkpoint when resuming, and add a synchronization barrier (torch.distributed.barrier()) to ensure all processes have loaded the state before proceeding with training. Both adjustments appear in the sketch after the diagram caption below.

Diagram: Training loop incorporating checkpoint loading at the start and periodic saving during the process. Interruptions trigger a restart, which then loads the most recent checkpoint.
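A minimal sketch of these DDP adjustments, assuming the process group is already initialized, model is DDP-wrapped, and checkpoint_path, optimizer, and scheduler are defined as in the earlier examples:

import torch
import torch.distributed as dist

# Sketch only: assumes torch.distributed is initialized and `model` is wrapped in DDP.
rank = dist.get_rank()

# Save from a single process so that workers do not write the same file concurrently
if rank == 0:
    torch.save({
        'model_state_dict': model.module.state_dict(),  # unwrap the DDP-wrapped model
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }, checkpoint_path)
dist.barrier()  # other ranks wait until rank 0 has finished writing

# Every process loads the same checkpoint when resuming
checkpoint = torch.load(checkpoint_path, map_location='cpu')  # load to CPU first, as before
model.module.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
dist.barrier()  # ensure all processes have loaded the state before training continues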
Mastering checkpointing and resumption is not just a convenience but a necessity for reliably executing the resource-intensive process of full parameter fine-tuning on large language models. It ensures that progress is preserved against interruptions, making long training runs feasible and manageable.