Training deep learning models can often be a time-consuming process, sometimes taking hours, days, or even weeks depending on the complexity of the model and the size of the dataset. During such long training runs, various issues can arise: power outages, system crashes, or simply the need to pause and resume later. Without a mechanism to save progress, you risk losing valuable computation time and potentially the best version of your model achieved during training.
This is where saving checkpoints becomes indispensable. Checkpointing involves saving the model's state at regular intervals or when specific conditions are met during the training process. TensorFlow's Keras API provides a convenient way to implement this using callbacks.
The ModelCheckpoint Callback

The primary tool for saving checkpoints in Keras is the tf.keras.callbacks.ModelCheckpoint callback. A callback is an object that can perform actions at various stages of training (e.g., at the start or end of an epoch, before or after processing a batch). The ModelCheckpoint callback specifically monitors the training process and saves the model based on configured criteria.

You integrate this callback by creating an instance of it and passing it to the callbacks list argument in the model.fit() method.
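As a minimal sketch of that integration (assuming a compiled model and training arrays x_train and y_train already exist, and using an illustrative filename), the callback is constructed once and handed to fit():

import tensorflow as tf

# Minimal integration sketch: with the defaults used here, a weights-only
# checkpoint is written at the end of every epoch, overwriting the same file.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath='ckpt.weights.h5',   # weights-only checkpoints use the .weights.h5 suffix
    save_weights_only=True)

# model.fit(x_train, y_train, epochs=5, callbacks=[checkpoint_cb])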
Let's look at its main configuration options:
- filepath: This is the path where the checkpoint file(s) will be saved. You can include formatting options in the filename to make it unique for each save, incorporating the epoch number and values of monitored metrics. For example:
  - 'model_checkpoint.weights.h5': Saves to a single file (overwritten each time unless save_best_only=True).
  - 'checkpoints/epoch_{epoch:02d}-val_loss_{val_loss:.2f}.weights.h5': Creates files like epoch_01-val_loss_0.54.weights.h5, epoch_02-val_loss_0.51.weights.h5, etc. This saves checkpoints from multiple epochs.
- monitor: Specifies the metric to monitor. Common choices include 'val_loss' (validation loss) or 'val_accuracy' (validation accuracy). The callback will use this metric's value to decide whether the current model is "better" than the previous best. If not specified, the callback operates without considering performance metrics (e.g., saving every epoch regardless of performance).
- save_best_only: If True, the callback only saves a checkpoint when the monitored metric shows improvement compared to the best value seen so far in training. This is extremely useful for keeping only the single best performing model checkpoint. If False, it saves the model at the end of every period defined by save_freq.
- save_weights_only: If True, only the model's weights (the values of its learnable parameters) are saved. This results in smaller checkpoint files. If False, the entire model is saved, including its architecture, weights, and the optimizer's state. Saving the entire model allows you to recreate the model and resume training exactly where you left off.
- mode: Determines whether improvement means minimizing or maximizing the monitored metric. Options are 'min', 'max', or 'auto'. If monitor is set to 'val_loss', 'auto' will correctly infer 'min'. If set to 'val_accuracy', 'auto' will infer 'max'. Explicitly setting it can avoid ambiguity.
- save_freq: Defines how often checkpoints are saved.
  - 'epoch' (default): Saves at the end of each epoch.
  - An integer (e.g., 1000): Saves after every specified number of batches.

Let's illustrate how to use ModelCheckpoint to save only the weights of the best model observed so far, based on validation loss.
import tensorflow as tf
import numpy as np
# Assume 'model' is a compiled Keras model
# Assume 'x_train', 'y_train', 'x_val', 'y_val' are your training and validation data
# Define the path for saving checkpoints
checkpoint_filepath = 'best_model.weights.h5'
# Create the ModelCheckpoint callback
# Monitor validation loss, save only the best weights
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_filepath,
save_weights_only=True,
monitor='val_loss',
mode='min',
save_best_only=True)
# Train the model, including the callback
print("Starting training with checkpoint callback...")
history = model.fit(x_train, y_train,
epochs=50,
batch_size=32,
validation_data=(x_val, y_val),
callbacks=[model_checkpoint_callback]) # Pass callback to training
print(f"Training finished. Best model weights saved to {checkpoint_filepath}")
# Later, you can load these weights into a model with the same architecture
# model.load_weights(checkpoint_filepath)
# print("Model weights loaded from checkpoint.")
In this example:

- We define checkpoint_filepath, where the best weights will be stored.
- We create the ModelCheckpoint instance with:
  - save_weights_only=True: We only save the weights.
  - monitor='val_loss': We watch the validation loss.
  - mode='min': Improvement means the validation loss decreases.
  - save_best_only=True: Only the checkpoint corresponding to the lowest val_loss seen so far will be saved/overwritten.
- We pass the callback instance in a list to the callbacks argument of model.fit().

During training, Keras will evaluate the validation loss at the end of each epoch. If the loss has improved (decreased) compared to all previous epochs, the current model weights are saved to best_model.weights.h5, overwriting any previous version. If the loss does not improve, no file is saved for that epoch. At the end of training, best_model.weights.h5 will contain the weights from the epoch that achieved the lowest validation loss.
Flow diagram illustrating the ModelCheckpoint behavior with save_best_only=True and monitor='val_loss'. Checkpoints are saved only when the validation loss improves.
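Conceptually, with these settings the callback performs a comparison like the following sketch at the end of each epoch. This is a simplified illustration of the idea, not the actual Keras implementation; best_val_loss stands for the best monitored value seen so far.

import math

best_val_loss = math.inf  # best validation loss observed so far

def maybe_save_checkpoint(logs, model, filepath):
    """Simplified view of ModelCheckpoint with monitor='val_loss',
    mode='min', save_best_only=True, and save_weights_only=True."""
    global best_val_loss
    current = logs['val_loss']        # monitored metric for this epoch
    if current < best_val_loss:       # 'min' mode: lower is better
        best_val_loss = current
        model.save_weights(filepath)  # overwrite the previous best weights
    # otherwise: keep the existing file and continue training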
If you set save_weights_only=False (the default), Keras saves the entire model: its architecture, its weights, the optimizer state, and the training configuration passed to model.compile(). The checkpoint is written in the native Keras format (if the filepath ends in .keras) or the older Keras HDF5 format (if the filepath ends in .h5). The native .keras format is generally recommended.
# Example for saving the entire model periodically
checkpoint_filepath_full = 'checkpoints/model_epoch_{epoch:02d}.keras' # The .keras extension selects the native Keras format
full_model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_filepath_full,
save_weights_only=False, # Save entire model
monitor='val_accuracy', # Monitor validation accuracy
mode='max', # Improvement means accuracy increases
save_best_only=False, # Save at the end of every epoch
save_freq='epoch' # Explicitly state saving frequency
)
# Train the model with this callback
# history = model.fit(..., callbacks=[full_model_checkpoint_callback])
This configuration saves the complete model at the end of every epoch to a single file (a .keras archive here, or an .h5 file if you use that extension), named according to the epoch number. This uses more disk space than weights-only checkpoints but captures everything needed to restore the model or resume training.
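If one checkpoint per epoch is too coarse, save_freq also accepts an integer number of batches, as noted in the option list above. The following sketch saves weights every 500 training batches; the interval and filename are illustrative assumptions, not values from the examples above.

import tensorflow as tf

# Save the latest weights every 500 training batches instead of once per epoch.
batch_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='checkpoints/latest.weights.h5',  # overwritten at each save
    save_weights_only=True,
    save_freq=500)  # an integer is interpreted as a number of batches

# history = model.fit(..., callbacks=[batch_checkpoint_callback])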
Choosing between saving only the weights and saving the entire model depends on your needs. If you only need the learned parameters for inference or fine-tuning, saving weights is sufficient and more efficient. If you need to resume training or deploy the model with its full configuration, saving the entire model is preferable.
Using checkpoints effectively ensures that your training efforts are preserved, allowing you to recover from interruptions and retain the best performing versions of your models. The next sections will cover how to load these saved weights and models.