Training machine learning models can be a time-consuming process, especially with large datasets or complex architectures. Interruptions, whether due to system crashes, resource limitations, or simply the need to pause and resume later, are common. Checkpointing is the practice of saving the state of your model and training process at regular intervals. This ensures that you can resume training from the last saved point, preventing loss of progress and computational resources. It's also invaluable for accessing intermediate versions of your model, which might perform better on validation data than the final one, especially if overfitting occurs.
A comprehensive checkpoint should allow you to restore the training process as accurately as possible. Merely saving the model's weights (the state_dict) is often insufficient for seamless resumption. Consider saving the following:

- Model state_dict: This dictionary contains all the learnable parameters (weights and biases) of your model. You obtain it via model.state_dict().
- Optimizer state_dict: Optimizers like Adam or SGD also have internal states (e.g., momentum buffers, per-parameter learning rates). Saving this via optimizer.state_dict() allows the optimizer to pick up exactly where it left off.
- Learning rate scheduler state: If you use a learning rate scheduler (torch.optim.lr_scheduler), its state should also be saved using scheduler.state_dict() to ensure the learning rate continues its schedule correctly upon resumption.

A checkpoint file typically bundles the model's parameters, optimizer state, current epoch, loss values, and optionally, the learning rate scheduler's state.
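Putting this together, a checkpoint is commonly assembled as a plain Python dictionary. The following is a minimal sketch using a toy linear model, optimizer, and scheduler purely for illustration; the epoch and loss values are hypothetical placeholders.

import torch
import torch.nn as nn

# Toy model, optimizer, and scheduler, purely for illustration
model = nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

epoch = 4            # hypothetical current epoch (0-indexed)
current_loss = 0.42  # hypothetical training loss for this epoch

# Bundle everything needed for a clean resumption into one dictionary
checkpoint = {
    'epoch': epoch + 1,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'loss': current_loss,
}
torch.save(checkpoint, 'checkpoint.pth')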
PyTorch gives you full control over how and when to save checkpoints. This contrasts with TensorFlow's Keras API, where checkpointing is often handled by callbacks like tf.keras.callbacks.ModelCheckpoint. While Keras callbacks offer convenience, PyTorch's manual approach provides greater flexibility.
Here are common strategies:
This is a straightforward strategy where you save a checkpoint every fixed number of epochs. It ensures you have reasonably frequent backups.
import os
import torch

# Inside your training loop
# Assume model, optimizer, epoch, and current_loss are defined
SAVE_EVERY_N_EPOCHS = 10

# ... after an epoch completes ...
if (epoch + 1) % SAVE_EVERY_N_EPOCHS == 0:
    os.makedirs('./checkpoints', exist_ok=True)  # make sure the directory exists
    checkpoint_path = f'./checkpoints/model_epoch_{epoch+1}.pth'
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': current_loss,
        # 'scheduler_state_dict': scheduler.state_dict(),  # if you use a scheduler
    }, checkpoint_path)
    print(f"Checkpoint saved at epoch {epoch+1} to {checkpoint_path}")
The main trade-off here is disk space versus the granularity of your backups. Saving too frequently can consume significant storage, while saving too infrequently risks losing more progress.
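If storage becomes a constraint, one option is to keep only the most recent few periodic checkpoints and delete older ones. The helper below is a minimal sketch that assumes the model_epoch_{N}.pth naming pattern from the example above; the function name and KEEP_LAST_K limit are hypothetical.

import glob
import os
import re

KEEP_LAST_K = 3  # hypothetical retention limit

def prune_old_checkpoints(checkpoint_dir='./checkpoints', keep_last_k=KEEP_LAST_K):
    # Find periodic checkpoints that follow the model_epoch_{N}.pth pattern
    paths = glob.glob(os.path.join(checkpoint_dir, 'model_epoch_*.pth'))

    def epoch_of(path):
        # Extract the epoch number from the filename for sorting
        match = re.search(r'model_epoch_(\d+)\.pth$', path)
        return int(match.group(1)) if match else -1

    paths.sort(key=epoch_of)
    # Delete everything except the newest keep_last_k checkpoints
    for old_path in paths[:-keep_last_k]:
        os.remove(old_path)

Calling such a helper right after each periodic save keeps disk usage bounded while still leaving a few recent restore points.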
Often, the primary goal is to save the model that performs best on a validation dataset. This way, you keep a strong checkpoint even if later epochs overfit.
# Initialize outside the training loop
best_val_loss = float('inf')

# ...
# Inside your training loop, after the validation phase
# Assume val_loss is calculated
if val_loss < best_val_loss:
    best_val_loss = val_loss
    best_model_path = './checkpoints/best_model.pth'
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),  # optional for best model, but good for fine-tuning
        'val_loss': best_val_loss,
    }, best_model_path)
    print(f"New best model saved at epoch {epoch+1} with validation loss: {best_val_loss:.4f}")
You might choose to save only the model's state_dict for the "best model" if its primary purpose is inference, but including the optimizer state and epoch can be useful if you later decide to fine-tune from this best state.
In addition to saving the best model or periodic checkpoints, it's often useful to save the very latest state of the model. This is typically overwritten at each save interval and is used for immediate resumption if training is interrupted.
# Inside your training loop, potentially every epoch or every few epochs
latest_checkpoint_path = './checkpoints/latest_checkpoint.pth'
torch.save({
    'epoch': epoch + 1,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': current_loss,  # or current_val_loss
    # 'scheduler_state_dict': scheduler.state_dict(),
}, latest_checkpoint_path)
print(f"Latest checkpoint saved at epoch {epoch+1}")
You can combine these strategies. For example, save latest_checkpoint.pth every epoch, a model_epoch_{N}.pth every N epochs, and best_model.pth whenever validation performance improves.
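One way to combine them is a single helper called once per epoch. The function below is just one possible arrangement, not a library API; its name and arguments are hypothetical. It writes all three kinds of checkpoints from the same state dictionary.

import os
import torch

def save_checkpoints(state, epoch, val_loss, best_val_loss,
                     checkpoint_dir='./checkpoints', every_n=10):
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Rolling "latest" checkpoint, overwritten on every call
    torch.save(state, os.path.join(checkpoint_dir, 'latest_checkpoint.pth'))
    # Periodic snapshot every N epochs
    if epoch % every_n == 0:
        torch.save(state, os.path.join(checkpoint_dir, f'model_epoch_{epoch}.pth'))
    # Best-so-far model based on validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(state, os.path.join(checkpoint_dir, 'best_model.pth'))
    return best_val_loss

A typical call after the validation phase would be best_val_loss = save_checkpoints(checkpoint_dict, epoch + 1, val_loss, best_val_loss).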
Good organization is important, especially for long experiments.
- Descriptive filenames: Including the epoch number (e.g., checkpoint_epoch_050.pth) or validation metric (e.g., model_val_acc_0.92.pth) can be very helpful.
- Separate directories: Use a dedicated directory (e.g., checkpoints/my_experiment/) for each experiment's checkpoints.

To resume training, you need to load the saved states back into your model, optimizer, and other relevant variables.
import os
import torch

# Before starting the training loop, or at the beginning of your script
# Define the model, optimizer, and (optionally) scheduler first
# model = YourModelClass(*args)
# optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

start_epoch = 0

# Path to the checkpoint you want to load
checkpoint_to_load_path = './checkpoints/latest_checkpoint.pth'  # or a specific epoch's checkpoint

if os.path.exists(checkpoint_to_load_path):
    checkpoint = torch.load(checkpoint_to_load_path)  # add map_location if necessary
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    last_loss = checkpoint.get('loss', float('inf'))  # use .get for optional keys
    # if 'scheduler_state_dict' in checkpoint and scheduler is not None:
    #     scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    print(f"Resuming training from epoch {start_epoch} with loss {last_loss:.4f}")
else:
    print("No checkpoint found, starting training from scratch.")

# Your training loop then starts from start_epoch
# for epoch in range(start_epoch, NUM_EPOCHS):
#     # ... training logic ...
When loading a checkpoint, especially if you might move between environments (e.g., GPU-trained model to CPU for inference, or vice-versa), use the map_location argument in torch.load():

- Use torch.load(PATH, map_location=torch.device('cpu')) to load a GPU-trained model onto a CPU.
- Use torch.load(PATH, map_location='cuda:0') to load onto a specific GPU.

After loading the state dicts, remember to call model.train() if you are resuming training, or model.eval() if you are loading the model for inference, to set the appropriate mode for layers like dropout and batch normalization.
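For example, loading a GPU-trained checkpoint for CPU inference could look like this, assuming the checkpoint dictionary format used earlier in this section.

device = torch.device('cpu')
checkpoint = torch.load('./checkpoints/best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()  # inference mode for dropout and batch normalization layers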
When repeatedly overwriting the same file (such as latest_checkpoint.pth), it's safer to save to a temporary file first and then atomically rename it to the final path. This prevents file corruption if the script crashes during the save operation.
temp_path = checkpoint_path + ".tmp"
torch.save(state, temp_path)
os.replace(temp_path, checkpoint_path) # os.replace is atomic
By implementing checkpointing strategies, you can make your PyTorch training workflows more resilient and manageable, ensuring that valuable computation time is not lost and that you can always access the most promising versions of your models. This is a fundamental aspect of model development, allowing for more experimentation and safer execution of long-running training jobs.