Sometimes, you don't need to save the entire model structure, especially if your model architecture is already defined in your code. You might only be interested in preserving the learned parameters, the weights and biases, which represent the knowledge captured during training. Saving only the weights is also useful if you want to load these parameters into a different but architecturally compatible model, perhaps for transfer learning or fine-tuning. It generally results in smaller file sizes compared to saving the entire model.
TensorFlow, through Keras, provides a straightforward way to save and load just the model's weights.
model.save_weights
The model.save_weights(filepath) method allows you to save the current values of all variables (weights and biases) associated with the model. The filepath argument specifies the location and filename for the weights. TensorFlow supports two main formats for saving weights:
1. TensorFlow Checkpoint format (the default). TensorFlow uses this format whenever the filepath does not end in .h5 or .keras (it might add a .ckpt suffix if you provide a path like my_model_weights.weights). It saves weights in a TensorFlow-native format. When using this format, TensorFlow might create multiple files (e.g., .weights.ckpt.index, .weights.ckpt.data-00000-of-00001). You typically only need to provide the base path (my_model_weights.weights.ckpt) when loading. This format is generally recommended as it integrates well with the TensorFlow ecosystem.
2. HDF5 format. This format stores all the weights in a single file (e.g., my_model_weights.weights.h5). You can explicitly request it using save_format='h5' or by ending the filepath with .h5 or .keras. While widely used previously, the TensorFlow Checkpoint format is often preferred for better compatibility within TensorFlow tools.

Let's assume you have trained a simple model:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
# Define a simple Sequential model
def build_simple_model():
model = keras.Sequential([
layers.Dense(64, activation='relu', input_shape=(784,)),
layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
return model
model = build_simple_model()
# Generate dummy data
(x_train, y_train), _ = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255.0
y_train = y_train.astype('int32')
# Train the model (briefly for demonstration)
print("Training model...")
model.fit(x_train[:1000], y_train[:1000], epochs=1, batch_size=32, verbose=0)
print("Training complete.")
# Save weights in TensorFlow Checkpoint format (default)
print("Saving weights in TF Checkpoint format...")
model.save_weights('my_model_weights.weights.ckpt')
print("Weights saved.")
# Save weights explicitly in HDF5 format
print("Saving weights in HDF5 format...")
model.save_weights('my_model_weights.weights.h5', save_format='h5')
print("Weights saved.")
After running this code, you will find files related to my_model_weights.weights.ckpt (such as the .index and .data-... files) and a single my_model_weights.weights.h5 file in your directory.
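If you want to confirm what ended up on disk, a quick check with Python's standard glob module works. This is an optional sanity check, not part of the saving workflow:

import glob

# TF Checkpoint format: expect an .index file and one or more .data-* shards
# (TensorFlow also writes a small 'checkpoint' bookkeeping file in the same directory)
print(glob.glob('my_model_weights.weights.ckpt*'))

# HDF5 format: a single file
print(glob.glob('my_model_weights.weights.h5'))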
model.load_weights
To load saved weights, you first need an instance of a model with the exact same architecture as the model from which the weights were saved. This includes the same layers, in the same order, with the same configurations (number of units, activation functions, etc.). If the architectures don't match, TensorFlow won't know how to map the saved weights to the model's layers and will raise an error.
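To make this compatibility requirement concrete, here is a minimal sketch that tries to load the weights saved above into a model whose hidden layer has a different size (the 32-unit layer is an arbitrary choice for illustration):

# Same layout, but the hidden layer has 32 units instead of 64,
# so the saved (784, 64) kernel cannot be mapped onto it
mismatched_model = keras.Sequential([
    layers.Dense(32, activation='relu', input_shape=(784,)),
    layers.Dense(10, activation='softmax')
])

try:
    mismatched_model.load_weights('my_model_weights.weights.ckpt')
except Exception as e:
    # TensorFlow reports the shape mismatch rather than silently
    # loading partial weights
    print(f"Loading failed: {type(e).__name__}")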
Once you have a compatible model instance, you can load the weights using model.load_weights(filepath). The filepath should point to the same path prefix (for TF Checkpoint) or filename (for HDF5) used during saving.
# Build a new instance of the same model architecture
new_model = build_simple_model()
# Verify performance before loading weights (should be random)
print("\nEvaluating new model before loading weights:")
loss, acc = new_model.evaluate(x_train[:1000], y_train[:1000], verbose=0)
print(f"Untrained model accuracy: {acc:.4f}")
# Load the previously saved weights (from TF Checkpoint format)
print("Loading weights from TF Checkpoint file...")
new_model.load_weights('my_model_weights.weights.ckpt')
print("Weights loaded.")
# Verify performance after loading weights (should match the trained model)
print("\nEvaluating new model after loading weights:")
loss, acc = new_model.evaluate(x_train[:1000], y_train[:1000], verbose=0)
print(f"Model accuracy after loading weights: {acc:.4f}")
# You can also load from the HDF5 file
# new_model.load_weights('my_model_weights.weights.h5')
The output will show that new_model initially performs poorly (like an untrained model) but achieves the same accuracy as the originally trained model after loading the weights.
It's important to reiterate: model.load_weights() requires that the model receiving the weights has a compatible structure. You cannot load weights saved from a convolutional neural network into a simple dense network, for instance. TensorFlow matches weights based on layer order and naming. This method assumes the model's architecture is defined elsewhere (such as in your Python script) and only transfers the learned parameters.
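For transfer learning between models that share only some layers, the HDF5 format additionally supports matching by layer name via load_weights(..., by_name=True). A minimal sketch, assuming you give the shared layer the same name in both models (the layer names and the source.weights.h5 filename here are illustrative):

# Source model with explicitly named layers
source = keras.Sequential([
    layers.Dense(64, activation='relu', input_shape=(784,), name='shared_dense'),
    layers.Dense(10, activation='softmax', name='head')
])
source.save_weights('source.weights.h5', save_format='h5')

# Target model reuses only the first layer; its output head differs
target = keras.Sequential([
    layers.Dense(64, activation='relu', input_shape=(784,), name='shared_dense'),
    layers.Dense(5, activation='softmax', name='new_head')
])

# by_name=True copies weights only into layers whose names match the file;
# 'new_head' has no match and keeps its fresh initialization.
# Name-based loading is only supported for the HDF5 format.
target.load_weights('source.weights.h5', by_name=True)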
This approach is particularly convenient when you are iterating on your training process, or in deployment scenarios where the model structure is defined in code and you just need to load the trained parameters. It's also the underlying mechanism used by the ModelCheckpoint callback when configured with save_weights_only=True, allowing you to automatically save the best-performing weights during a long training run without saving the entire model structure repeatedly.
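A short sketch of that pattern, reusing the model and data from above; the best.weights.ckpt filepath and the monitoring settings are illustrative choices:

# Save only the weights of the best-performing model seen during training
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    filepath='best.weights.ckpt',
    save_weights_only=True,   # store weights only, not the full model
    save_best_only=True,      # overwrite only when the monitored metric improves
    monitor='val_accuracy',
    mode='max'
)

model.fit(x_train[:1000], y_train[:1000],
          validation_split=0.2,
          epochs=5,
          batch_size=32,
          callbacks=[checkpoint_cb],
          verbose=0)

# Restore the best weights into a compatible model instance
best_model = build_simple_model()
best_model.load_weights('best.weights.ckpt')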