Once you've saved your model's weights or its entire state, the next logical step is to load them back. This is essential for several common workflows: using a trained model for predictions (inference), fine-tuning a model on a new dataset, or simply resuming a training process that was interrupted. TensorFlow and Keras provide straightforward functions to handle these scenarios.
Often, you might have the Python code that defines your model's architecture but need to load previously trained weights into it. This is common when you've used `model.save_weights()` or the `ModelCheckpoint` callback configured to save weights only (e.g., `save_weights_only=True`).
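As a concrete illustration of the weights-only saving path, here is a minimal sketch of configuring the `ModelCheckpoint` callback. The filepath is a hypothetical example; note that recent Keras versions require the filepath to end in `.weights.h5` when `save_weights_only=True`.

```python
import tensorflow as tf
from tensorflow import keras

# Configure ModelCheckpoint to save only the weights during training.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    filepath='my_model_weights.weights.h5',  # hypothetical path
    save_weights_only=True,   # save weights, not the full model
    save_best_only=True,      # keep only the best checkpoint seen so far
    monitor='val_loss',       # metric used to decide which is "best"
)

# The callback is then passed to fit() like any other callback:
# model.fit(x_train, y_train, validation_data=(x_val, y_val),
#           callbacks=[checkpoint_cb])
```

Saving only weights keeps checkpoint files small and is usually what you want when you already version the model-definition code separately.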
The primary method for this is `tf.keras.Model.load_weights()`. It takes the path to the saved weights file as its main argument.
Important Consideration: For `load_weights()` to succeed, the model architecture defined in your code must exactly match the architecture of the model from which the weights were saved. This includes the number of layers, the type of layers, the order of layers, and the configuration (like units, activation functions, etc.) of each layer. If the architectures don't match, TensorFlow will typically raise an error because it won't know how to map the saved weight tensors to the layers in your current model.
Let's assume you have a simple Sequential model defined and weights saved in a file named `'my_model_weights.weights.h5'`.
```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# 1. Define the *exact* same model architecture
def build_simple_model():
    model = keras.Sequential([
        layers.Dense(64, activation='relu', input_shape=(784,)),
        layers.Dropout(0.2),
        layers.Dense(10, activation='softmax')
    ])
    # Note: You typically compile the model *after* loading weights
    # if you are only doing inference. If resuming training, compile before.
    return model

# Create an instance of the model
model = build_simple_model()

# Print summary to see architecture (optional)
# model.summary()

# Path to the saved weights file
weights_filepath = 'my_model_weights.weights.h5'

# 2. Load the weights
try:
    model.load_weights(weights_filepath)
    print(f"Successfully loaded weights from {weights_filepath}")

    # Now the model instance has the trained parameters.
    # You can proceed with evaluation or prediction:
    # model.compile(...)  # Compile if needed for evaluation/training
    # loss, acc = model.evaluate(x_test, y_test, verbose=0)
    # print(f"Restored model accuracy: {acc:.2f}")
except Exception as e:
    print(f"Error loading weights: {e}")
    print("Ensure the model architecture matches the saved weights.")
```
This approach is flexible because it separates the model definition (code) from the learned parameters (weights file).
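To see this separation in action, the following self-contained sketch saves weights from one model instance and loads them into a second, architecturally identical instance built from the same code. The file name `demo.weights.h5` is a hypothetical example.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

def build_model():
    # The same definition code produces the same architecture each time.
    return keras.Sequential([
        keras.Input(shape=(3,)),
        layers.Dense(4, activation='relu'),
        layers.Dense(2),
    ])

# Save weights from one instance (randomly initialized here for brevity)...
model_a = build_model()
model_a.save_weights('demo.weights.h5')

# ...and load them into a freshly built, identical instance.
model_b = build_model()
model_b.load_weights('demo.weights.h5')

# The two models now hold identical parameters.
for wa, wb in zip(model_a.get_weights(), model_b.get_weights()):
    assert np.array_equal(wa, wb)
```

The weights file carries no architecture information, which is exactly why the definition code must match: the tensors are mapped onto whatever layers your code builds.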
If you saved the entire model using `model.save(filepath)`, you saved not just the weights but also the model's architecture and, if the model was compiled, its optimizer state. Depending on the filepath, this is stored in TensorFlow's SavedModel format (a directory), a legacy HDF5 file (`.h5`), or the native Keras format (`.keras`).
To load such a complete model, you use the top-level function `tf.keras.models.load_model()`.
```python
import tensorflow as tf

# Path to the saved model directory or file
saved_model_path = 'my_full_model_savedmodel'  # Or 'my_full_model.h5' / 'my_full_model.keras'

try:
    # Load the entire model
    loaded_model = tf.keras.models.load_model(saved_model_path)
    print(f"Successfully loaded model from {saved_model_path}")

    # The loaded_model object is a compiled Keras model, ready to use.
    loaded_model.summary()

    # You can immediately use it for evaluation or prediction:
    # loss, acc = loaded_model.evaluate(x_test, y_test, verbose=0)
    # print(f"Loaded model accuracy: {acc:.2f}")

    # Or even continue training if the optimizer state was saved:
    # history = loaded_model.fit(x_train, y_train, epochs=5, ...)
except Exception as e:
    print(f"Error loading model: {e}")
```
The significant advantage here is that you don't need the original Python code that defined the model architecture. `load_model` reconstructs the model from the information stored in the saved file or directory. It also restores the training configuration (optimizer, loss, metrics) if it was saved, making it easy to resume training or use the exact settings intended by the model's creator.
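The full round trip can be sketched end to end. This minimal example (the file name `demo_full_model.keras` is a hypothetical path, and the `.keras` format assumes a recent TensorFlow version) builds and compiles a small model, saves everything, and restores it without touching the definition code again:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Build and compile a small model, then save everything with model.save().
model = keras.Sequential([keras.Input(shape=(3,)), layers.Dense(2)])
model.compile(optimizer='adam', loss='mse')
model.save('demo_full_model.keras')  # architecture + weights + compile config

# Load it back; no build_model() function is needed here.
restored = keras.models.load_model('demo_full_model.keras')

# Same architecture, same weights: identical predictions.
x = np.ones((1, 3), dtype='float32')
assert np.allclose(model.predict(x, verbose=0),
                   restored.predict(x, verbose=0))
```

Because the compile configuration travels with the file, `restored.fit(...)` could pick up training with the same optimizer and loss that were originally specified.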
Handling Custom Objects: If your saved model includes custom layers, custom loss functions, or custom activation functions that aren't part of the standard TensorFlow/Keras library, `load_model` might fail because it doesn't know how to interpret these custom components. To handle this, you can pass a `custom_objects` dictionary to `load_model`, mapping the saved names of your custom objects to their corresponding Python classes or functions.
```python
# Assume you have a custom layer class defined: MyCustomLayer
# When loading:
# loaded_model = tf.keras.models.load_model(
#     saved_model_path,
#     custom_objects={'MyCustomLayer': MyCustomLayer}
# )
```
Alternatively, you can decorate your custom classes or functions with `tf.keras.utils.register_keras_serializable` to register them globally, so they can be found during loading without a `custom_objects` dictionary.
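A minimal sketch of the registration approach follows. The package name `'MyLayers'` and the `ScaleLayer` class are illustrative assumptions, not part of Keras; the point is that once the class is registered and implements `get_config()`, the saved model loads with no `custom_objects` argument.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

@keras.utils.register_keras_serializable(package='MyLayers')
class ScaleLayer(keras.layers.Layer):
    """Hypothetical custom layer that multiplies its input by a constant."""

    def __init__(self, factor=2.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor

    def get_config(self):
        # Include constructor arguments so they are serialized with the model.
        config = super().get_config()
        config.update({'factor': self.factor})
        return config

# Because ScaleLayer is registered, no custom_objects dict is needed:
inputs = keras.Input(shape=(2,))
outputs = ScaleLayer(factor=3.0)(inputs)
keras.Model(inputs, outputs).save('scale_demo.keras')  # hypothetical path
restored = keras.models.load_model('scale_demo.keras')
```

Registration is usually the cleaner option when the custom class lives in an importable module, since every loading site then works without extra arguments.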
`load_weights` vs. `load_model`: which to use?

Use `model.load_weights(filepath)` when:

- You have the Python code that defines the model architecture.
- You saved only the weights (e.g., with `save_weights()` or a weights-only `ModelCheckpoint`).
- You want control over how the model is compiled after loading.

Use `tf.keras.models.load_model(filepath)` when:

- You saved the entire model with `model.save()`.
- You don't have, or don't want to depend on, the original model-definition code.
- You want to restore the optimizer state and training configuration, for example to resume training exactly where it left off.
Successfully loading a pre-trained model allows you to build upon previous work, saving significant time and computational resources. It's a fundamental skill for applying deep learning effectively. As we'll see briefly in the next section, platforms like TensorFlow Hub further streamline the process of accessing and reusing models trained by the wider community.
© 2025 ApX Machine Learning