Now that you're familiar with the different strategies for saving and preparing PyTorch models, it's time to put this knowledge into practice. This section will guide you through coding exercises demonstrating how to save and load model states, manage checkpoints, and convert models to TorchScript for more flexible deployment. We'll use a simple neural network as our test case, allowing us to focus on the mechanics of model persistence and serialization.
First, let's set up our environment and define the model we'll be working with throughout these examples.
import torch
import torch.nn as nn
import torch.optim as optim
import os
# Define a directory to save our model files
MODEL_SAVE_DIR = "saved_models_practice"
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
# Define a simple neural network
class SimpleNet(nn.Module):
    def __init__(self, input_size=10, hidden_size=5, output_size=2):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
# Helper function to print model parameters (for verification)
def print_model_parameters(model):
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"{name}: {param.data.numpy().sum():.4f}")  # Print sum of weights for brevity
This SimpleNet is a basic two-layer fully connected network. We'll use it to illustrate various saving and loading techniques.
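As a quick sanity check before moving on, a minimal sketch like the following (the batch size of 4 is arbitrary) confirms the network maps inputs to outputs of the expected shape:
# Quick sanity check: push a random batch through SimpleNet
sanity_model = SimpleNet()
dummy_batch = torch.randn(4, 10)  # batch of 4 samples, 10 features each
print(sanity_model(dummy_batch).shape)  # expected: torch.Size([4, 2])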
The recommended approach for persisting PyTorch models is to save the model's state_dict. This dictionary object maps each layer to its learnable parameters (weights and biases). It's lightweight and more robust to code changes compared to saving the entire model object.
Saving the state_dict
Let's instantiate our model and save its parameters.
# Instantiate the model
model_state_dict_example = SimpleNet()
print("Initial parameters (sum):")
print_model_parameters(model_state_dict_example)
# Define a path for saving the state_dict
STATE_DICT_PATH = os.path.join(MODEL_SAVE_DIR, "simplenet_state_dict.pth")
# Save the model's state_dict
torch.save(model_state_dict_example.state_dict(), STATE_DICT_PATH)
print(f"\nModel state_dict saved to {STATE_DICT_PATH}")
The state_dict() method returns a Python dictionary containing all the weights and biases. torch.save() then serializes this dictionary to disk. The .pth extension is a common convention for PyTorch model files.
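To see exactly what was stored, you can iterate over the dictionary before or after saving; a minimal sketch using the model instance from above:
# Inspect the state_dict: keys are "<attribute>.<parameter>" names, values are tensors
for key, tensor in model_state_dict_example.state_dict().items():
    print(f"{key}: shape {tuple(tensor.shape)}")
# Expected keys for SimpleNet: fc1.weight, fc1.bias, fc2.weight, fc2.bias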
Loading the state_dict
To load the parameters, you first need an instance of your model structure. Then, you load the saved state_dict and apply it to this model instance.
# Create a new instance of the model
loaded_model_state_dict = SimpleNet()
print("\nParameters of new model instance (before loading, sum):")
print_model_parameters(loaded_model_state_dict)
# Load the saved state_dict
state_dict = torch.load(STATE_DICT_PATH)
# Apply the loaded state_dict to the model
loaded_model_state_dict.load_state_dict(state_dict)
print("\nParameters of model after loading state_dict (sum):")
print_model_parameters(loaded_model_state_dict)
# Remember to call model.eval() if you are using the model for inference
loaded_model_state_dict.eval()
Notice how the parameters of loaded_model_state_dict match those of the original model_state_dict_example after loading. The model architecture defined in your script must match the architecture used when the state_dict was saved.
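If the architectures disagree, load_state_dict fails with a RuntimeError listing the offending keys or shapes. A minimal sketch (the mismatched hidden_size here is arbitrary):
# Loading into a model with a different architecture fails loudly
mismatched_model = SimpleNet(hidden_size=8)  # original used hidden_size=5
try:
    mismatched_model.load_state_dict(state_dict)
except RuntimeError as e:
    print(f"Load failed as expected: {e}")
# Passing strict=False skips missing or unexpected keys (useful for partial
# loading, e.g. transfer learning), though shape mismatches still raise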
PyTorch also allows you to save the entire model object using Python's pickle module. While this is convenient, it can be less portable because it ties the saved file to the specific class structure and directory path used during saving. If you refactor your code or move files, you might encounter issues loading the model.
Saving the Entire Model
# Instantiate another model
model_full_save_example = SimpleNet(input_size=10, hidden_size=8, output_size=3) # Different architecture for clarity
model_full_save_example.fc1.weight.data.fill_(0.5) # Modify weights for differentiation
print("\nParameters of model for full save (sum):")
print_model_parameters(model_full_save_example)
# Define a path for saving the full model
FULL_MODEL_PATH = os.path.join(MODEL_SAVE_DIR, "simplenet_full_model.pth")
# Save the entire model
torch.save(model_full_save_example, FULL_MODEL_PATH)
print(f"\nEntire model saved to {FULL_MODEL_PATH}")
Loading the Entire Model
Loading a fully saved model is straightforward, as torch.load() directly returns the model object.
# Load the entire model
# Note: since PyTorch 2.6, torch.load defaults to weights_only=True, which
# rejects pickled model objects; pass weights_only=False for files you trust
loaded_full_model = torch.load(FULL_MODEL_PATH, weights_only=False)
print("\nParameters of fully loaded model (sum):")
print_model_parameters(loaded_full_model)
# Set to evaluation mode
loaded_full_model.eval()
# You can directly use it for inference
dummy_input = torch.randn(1, 10) # Original input_size was 10 for this model
# If loaded_full_model was SimpleNet(input_size=10, hidden_size=8, output_size=3)
# dummy_input should have 10 features
# output = loaded_full_model(dummy_input)
# print(f"\nOutput from fully loaded model: {output}")
This method is simpler but, as mentioned, less flexible for long-term storage or sharing across different projects or Python environments. The class definition for SimpleNet must be available and accessible in the environment where you load the model.
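A related practical point: files saved on one device can be remapped to another at load time with map_location. A minimal sketch, reusing FULL_MODEL_PATH (the weights_only=False argument mirrors the loading note above):
# Remap storages to CPU at load time, regardless of the device used when saving
cpu_model = torch.load(FULL_MODEL_PATH, map_location="cpu", weights_only=False)
cpu_model.eval()
# The same map_location argument also works when loading a plain state_dict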
Checkpointing involves saving the model's state (and potentially other training information like optimizer state and epoch number) at various points during a lengthy training process. This allows you to resume training if it's interrupted or to retrieve the model at its best performing state.
Let's simulate a basic checkpointing mechanism.
# Setup for a mock training loop
model_for_checkpointing = SimpleNet()
optimizer = optim.Adam(model_for_checkpointing.parameters(), lr=0.001)
num_epochs_mock = 5
current_epoch = 0
CHECKPOINT_PATH = os.path.join(MODEL_SAVE_DIR, "simplenet_checkpoint.pth")
print("\nSimulating training and checkpointing...")
for epoch in range(num_epochs_mock):
    current_epoch = epoch
    # Simulate some training (e.g., updating weights manually for this example)
    for param in model_for_checkpointing.parameters():
        if param.requires_grad:
            param.data += 0.01 * (epoch + 1)  # Simple modification
    mock_loss = 1.0 / (epoch + 1)
    print(f"Epoch {epoch+1}, Mock Loss: {mock_loss:.4f}")

    # Save a checkpoint every 2 epochs
    if (epoch + 1) % 2 == 0:
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model_for_checkpointing.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': mock_loss,
        }
        torch.save(checkpoint, CHECKPOINT_PATH)
        print(f"Checkpoint saved at epoch {epoch + 1} to {CHECKPOINT_PATH}")
print("\nSimulated training finished.")
print("Parameters of model after simulated training (sum):")
print_model_parameters(model_for_checkpointing)
Now, let's see how to resume from this checkpoint.
# To resume, create new instances of model and optimizer
resumed_model = SimpleNet()
resumed_optimizer = optim.Adam(resumed_model.parameters(), lr=0.001)  # Ensure optimizer params are consistent

print("\nParameters of new model before loading checkpoint (sum):")
print_model_parameters(resumed_model)

# Load the checkpoint
if os.path.exists(CHECKPOINT_PATH):
    checkpoint = torch.load(CHECKPOINT_PATH)
    resumed_model.load_state_dict(checkpoint['model_state_dict'])
    resumed_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    last_loss = checkpoint['loss']

    print(f"\nCheckpoint loaded. Resuming from epoch {start_epoch}.")
    print("Parameters of model after loading checkpoint (sum):")
    print_model_parameters(resumed_model)

    # Set model to training mode if you plan to continue training
    resumed_model.train()
    # Or evaluation mode if for inference
    # resumed_model.eval()
else:
    print("\nNo checkpoint found to resume from.")
Saving the optimizer's state_dict is important because it contains internal buffers and hyperparameters (such as Adam's running averages and the learning rate) that are updated or configured during training; without it, a resumed run would restart the optimizer from scratch.
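A common extension of this pattern is keeping a separate "best" checkpoint that is overwritten only when the monitored metric improves. A minimal sketch; BEST_CHECKPOINT_PATH, best_loss, and save_if_best are illustrative names, not part of the example above:
# Hypothetical best-model tracking for use inside a training loop
BEST_CHECKPOINT_PATH = os.path.join(MODEL_SAVE_DIR, "simplenet_best.pth")
best_loss = float('inf')

def save_if_best(model, optimizer, epoch, loss):
    """Overwrite the best checkpoint only when the loss improves."""
    global best_loss
    if loss < best_loss:
        best_loss = loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
        }, BEST_CHECKPOINT_PATH)
        print(f"New best loss {loss:.4f}; checkpoint updated.")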
TorchScript provides a way to create serializable and optimizable representations of your PyTorch models that can run independently of Python, for instance, in C++ environments or in settings where Python overhead is undesirable. There are two primary ways to convert a PyTorch model to TorchScript: tracing and scripting.
Tracing (torch.jit.trace): You provide an example input, and TorchScript records the operations performed on this input as it passes through the model. This works well for models with straightforward, data-independent control flow.
Scripting (torch.jit.script): TorchScript directly analyzes your model's Python source code (including control flow like if statements and loops) and translates it into the TorchScript intermediate representation. This is more suitable for models with complex control flow.
For this practical, we'll focus on tracing, which is often simpler to get started with; a brief scripting sketch follows below for contrast.
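As that point of contrast, here is a minimal scripting sketch. GatedNet is a hypothetical model (not part of the examples above) whose forward pass branches on the input values, which is exactly the situation scripting handles and tracing does not:
# Hypothetical model with data-dependent control flow
class GatedNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        # The branch depends on input values, so scripting (not tracing) is appropriate
        if x.sum() > 0:
            return self.fc(x)
        return self.fc(-x)

scripted_model = torch.jit.script(GatedNet())
print(scripted_model.code)  # shows the translated TorchScript, branch included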
Tracing a Model
# Instantiate a model for tracing
model_to_trace = SimpleNet()
model_to_trace.eval() # Important: set model to evaluation mode for tracing
# Create an example input tensor with the correct shape
# Our SimpleNet expects input_size=10
example_input = torch.randn(1, 10) # Batch size 1, 10 features
# Trace the model
try:
    traced_model = torch.jit.trace(model_to_trace, example_input)
    print("\nModel successfully traced.")
    # You can inspect the traced model's graph (optional)
    # print(traced_model.graph)
    # And its code (optional)
    # print(traced_model.code)
except Exception as e:
    print(f"\nError during tracing: {e}")
    traced_model = None
The traced_model is now a torch.jit.ScriptModule object. It has captured the sequence of operations that model_to_trace performed when given example_input.
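A caveat worth demonstrating: tracing records only the operations executed for the particular example input, so data-dependent branches get frozen into the trace. A minimal sketch reusing the hypothetical GatedNet from the scripting example above:
# Tracing bakes in the branch taken by the example input
gated = GatedNet()
pos_input = torch.randn(1, 10).abs()  # sum > 0, so tracing records the first branch
traced_gated = torch.jit.trace(gated, pos_input)  # PyTorch emits a TracerWarning here
neg_input = -pos_input  # sum < 0: the Python model takes the other branch
# The traced model still follows the recorded branch, so outputs diverge
print(torch.allclose(traced_gated(neg_input), gated(neg_input)))  # False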
Saving a Traced Model
Traced models have their own save method.
TRACED_MODEL_PATH = os.path.join(MODEL_SAVE_DIR, "simplenet_traced.pt")  # .pt is common for TorchScript

if traced_model is not None:
    traced_model.save(TRACED_MODEL_PATH)
    print(f"Traced model saved to {TRACED_MODEL_PATH}")
Loading and Using a Traced Model
You can load a TorchScript model using torch.jit.load(). A significant advantage is that you don't need the original Python model class definition (SimpleNet in this case) to load and run the traced model.
if os.path.exists(TRACED_MODEL_PATH):
    loaded_traced_model = torch.jit.load(TRACED_MODEL_PATH)
    print("\nTraced model loaded successfully.")

    # You can now use the loaded traced model for inference
    # Ensure the input tensor has the correct shape and type
    test_input = torch.randn(1, 10)
    with torch.no_grad():  # Always good practice for inference
        output = loaded_traced_model(test_input)
    print(f"Output from loaded traced model: {output.numpy()}")

    # Verify output with original model (optional, if available)
    # original_output = model_to_trace(test_input)
    # print(f"Output from original Python model: {original_output.detach().numpy()}")
else:
    print("\nTraced model file not found.")
This ability to run without the original Python code makes TorchScript models highly portable and suitable for various deployment scenarios.
These exercises have covered the fundamental techniques for model persistence and an introduction to TorchScript. As you work with more complex models and deployment requirements, you'll build upon these foundations. For instance, you might explore TorchScript scripting for models with dynamic control flow or look into ONNX for interoperability with other machine learning frameworks. Experiment with these methods on your own models to solidify your understanding.