Training deep learning models can often be a time-consuming process, potentially taking hours or even days depending on the complexity of the model and the size of the dataset. It's impractical to restart training from scratch every time you need to stop, whether due to system interruptions, the need to fine-tune later, or simply wanting to use the trained model for predictions. This is where saving and loading model checkpoints becomes essential.
A checkpoint captures the state of your training process at a specific moment, allowing you to restore it later. This section covers how to effectively save and load the necessary components of your PyTorch models and training state.
When saving a checkpoint, you need to decide what information is necessary for your purposes. Generally, you'll want to save at least the model's parameters. If you intend to resume training, you should also save the state of the optimizer and potentially other metadata like the current epoch number and the latest validation loss.
PyTorch models have an internal state dictionary, accessed via `model.state_dict()`, which contains all the learned parameters (weights and biases) of the model's layers. This is the recommended way to save the model's learned information.
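To make this concrete, here is a minimal sketch that inspects the `state_dict` of a small model; the architecture is hypothetical and purely for illustration:

```python
import torch.nn as nn

# A tiny illustrative model (hypothetical, for demonstration only)
model = nn.Sequential(
    nn.Linear(10, 32),
    nn.ReLU(),
    nn.Linear(32, 2),
)

# state_dict() maps each parameter name to its tensor
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
# 0.weight (32, 10)
# 0.bias   (32,)
# 2.weight (2, 32)
# 2.bias   (2,)
```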
Why save the `state_dict` instead of the entire model object (e.g., `torch.save(model, PATH)`)? Saving the `state_dict` is more flexible and less prone to issues. Pickling the entire model object stores references to the specific class and module structure that existed at the time of saving. If you refactor or change the model's class definition later, loading the pickled object might fail or lead to unexpected behavior. Saving just the state dictionary separates the learned parameters from the code structure, making loading more reliable.
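As a quick illustration of the contrast, assuming a `YourModelClass` definition and a trained `model` instance exist, the two approaches look like this (a sketch, not a prescription):

```python
# Recommended: save only the learned parameters
torch.save(model.state_dict(), "weights_only.pth")

# Fragile: pickle the whole object, which records references to the
# class and module layout that existed at save time
torch.save(model, "entire_model.pth")

# Loading the state_dict still needs the class definition, but the
# saved parameters stay valid even if surrounding code is refactored
model = YourModelClass(*args, **kwargs)
model.load_state_dict(torch.load("weights_only.pth"))
```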
Similarly, optimizers like Adam or SGD also have internal states (e.g., momentum buffers, adaptive learning rates) that evolve during training. To precisely resume training, you should save the optimizer's state using `optimizer.state_dict()`.
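The sketch below, using a hypothetical one-layer model, shows what an Adam optimizer accumulates after a single step; the exact buffer names vary by optimizer:

```python
import torch
import torch.nn as nn

# Hypothetical setup: a tiny model and an Adam optimizer
model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Take one training step so the optimizer builds internal state
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

state = optimizer.state_dict()
print(state.keys())              # dict_keys(['state', 'param_groups'])
print(state['state'][0].keys())  # for Adam: 'step', 'exp_avg', 'exp_avg_sq'
```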
### Saving Checkpoints with `torch.save()`

PyTorch uses `torch.save()` to serialize and save objects. To save a checkpoint, you typically create a dictionary containing the model's state dictionary, the optimizer's state dictionary, and any other relevant information, then save this dictionary.
Here's a common pattern for saving a checkpoint within your training loop:
```python
# Assume model, optimizer, epoch, and loss are defined
# PATH = "path/to/your/checkpoint.pth"  # Define your save path

checkpoint = {
    'epoch': epoch + 1,  # Save the next epoch number to start from
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,  # Or maybe validation loss
    # Add any other metrics or info you want to save
    # 'validation_accuracy': val_acc,
}

torch.save(checkpoint, PATH)
print(f"Checkpoint saved at epoch {epoch} to {PATH}")
```
You might save a checkpoint periodically (e.g., every 10 epochs) or whenever the model achieves a new best performance on the validation set.
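One way to combine both policies is sketched below; `validate()` and `num_epochs` are hypothetical stand-ins for your own validation routine and training length:

```python
best_val_loss = float('inf')

for epoch in range(num_epochs):
    # ... training steps ...
    val_loss = validate(model)  # hypothetical: returns mean validation loss

    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': val_loss,
    }

    # Periodic checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        torch.save(checkpoint, f"checkpoint_epoch_{epoch + 1}.pth")

    # Separate "best so far" checkpoint whenever validation improves
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(checkpoint, "best_checkpoint.pth")
```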
### Loading Checkpoints with `torch.load()` and `load_state_dict()`

To load a checkpoint, you first use `torch.load()` to deserialize the saved dictionary from the file. Then, you load the state dictionaries back into your model and optimizer instances.
**Important:** You must create instances of your model and optimizer before you can load their states. The `load_state_dict()` method loads the parameters into an existing object; it doesn't recreate the object itself.
If you only need the model for making predictions (inference) and don't plan to resume training, you typically only need to load the `model_state_dict`.
```python
import torch

# First, instantiate your model structure
model = YourModelClass(*args, **kwargs)

# Define the path to your saved checkpoint
PATH = "path/to/your/checkpoint.pth"

# Load the checkpoint dictionary
checkpoint = torch.load(PATH)

# Load the model state dictionary from the checkpoint
model.load_state_dict(checkpoint['model_state_dict'])

# Set the model to evaluation mode
model.eval()

# Now the model is ready for inference
# with torch.no_grad():
#     outputs = model(inputs)
```
Setting `model.eval()` is important: it disables layers like Dropout and makes Batch Normalization layers use their stored running statistics instead of batch statistics, which is the correct behavior during inference.
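You can verify this behavior directly; the small sketch below uses a standalone Dropout layer purely for illustration:

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

drop.train()
print(drop(x))  # roughly half the entries zeroed, survivors scaled to 2.0

drop.eval()
print(drop(x))  # identical to x: dropout is a no-op in eval mode
```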
If you want to continue training from where you left off, you need to load the states for both the model and the optimizer, as well as retrieve other saved metadata like the epoch number.
```python
import os
import torch

# Assumes YourModelClass, its constructor arguments, and learning_rate
# are defined elsewhere

# Instantiate model and optimizer first
model = YourModelClass(*args, **kwargs)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  # Or your chosen optimizer

# Define the path
PATH = "path/to/your/checkpoint.pth"

start_epoch = 0
best_loss = float('inf')  # Example: initialize best loss

# Check if a checkpoint exists to load from
if os.path.exists(PATH):
    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    best_loss = checkpoint['loss']  # Load previous loss
    print(f"Checkpoint loaded. Resuming training from epoch {start_epoch}")

# Set the model to training mode
model.train()

# Now you can continue your training loop, starting from start_epoch
# for epoch in range(start_epoch, num_epochs):
#     # ... training steps ...
```
Setting `model.train()` ensures layers like Dropout and Batch Normalization behave correctly during training.
Sometimes you might save a model trained on a GPU and later need to load it on a machine with only a CPU, or vice versa. By default, `torch.load()` will try to load tensors onto the device where they were saved, which can cause errors if that device isn't available. To handle this, use the `map_location` argument of `torch.load()`.
```python
# Load a GPU-trained model onto a CPU
checkpoint = torch.load(PATH, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])

# Load any model onto the currently available device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load(PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Remember to also move your model to the device
model.to(device)
```
Saving and loading checkpoints is a fundamental part of the deep learning workflow. By mastering these techniques using `torch.save`, `torch.load`, and `load_state_dict`, you ensure your progress is safe, your models are reusable, and your training process is robust against interruptions. Remember to save the `state_dict` for both the model and optimizer, along with relevant metadata, for maximum flexibility and reliability.