Training a model using model.fit() can take significant time, especially with large datasets or complex architectures running for many epochs. During this process, you might want to perform actions automatically at specific points, such as saving your model periodically, stopping training early if performance plateaus, or adjusting the learning rate. This is where Keras callbacks come in.
Callbacks are objects that you can pass to model.fit() (in the callbacks argument, which takes a list) that perform predefined actions at various stages of training (e.g., at the start or end of an epoch, or before or after a single batch). They provide a powerful mechanism to customize and control the training loop without modifying the core fit method.
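To make these hook points concrete, here is a minimal sketch of a custom callback; the class name ValLossPrinter is made up for illustration, but the pattern of subclassing tf.keras.callbacks.Callback and overriding on_epoch_end is how custom callbacks are written:

import tensorflow as tf

# Hypothetical example: print the validation loss at the end of every epoch
class ValLossPrinter(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        val_loss = logs.get('val_loss')
        if val_loss is not None:
            print(f"Epoch {epoch + 1}: val_loss = {val_loss:.4f}")

# Used like any built-in callback:
# history = model.fit(train_dataset, epochs=5, validation_data=validation_dataset,
#                     callbacks=[ValLossPrinter()])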
Let's look at some of the most commonly used callbacks provided by tf.keras.callbacks.
Imagine training a model for 100 epochs. Perhaps the best performance on the validation set occurred at epoch 75, but by epoch 100, the model started overfitting, and the validation performance degraded. If you only saved the model at the very end, you would miss the best version!
The ModelCheckpoint callback addresses this by periodically saving the model during training. You can configure it to save only the model's weights or the entire model (architecture, weights, and optimizer state). Critically, you can instruct it to save only the best model observed so far, based on a monitored metric such as validation loss or accuracy.
Here's how you might configure it to save the best model based on validation loss:
import tensorflow as tf

# Assume 'model' is a compiled Keras model

# Define the callback
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='best_model_{epoch:02d}_{val_loss:.2f}.keras',  # File path with formatting options
    save_weights_only=False,  # Save the entire model
    monitor='val_loss',       # Metric to monitor
    mode='min',               # We want to minimize loss
    save_best_only=True)      # Only save if 'val_loss' improves

# Now, pass it to model.fit() (assuming you have training and validation data)
# history = model.fit(train_dataset, epochs=100, validation_data=validation_dataset,
#                     callbacks=[model_checkpoint_callback])
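After training finishes, the saved checkpoint can be loaded back for evaluation or inference. A brief sketch, assuming a file such as 'best_model_75_0.43.keras' was produced by the run (the actual filename depends on the epoch and loss values reached):

# Filename is illustrative; it depends on the epoch and val_loss of your best run
# best_model = tf.keras.models.load_model('best_model_75_0.43.keras')
# results = best_model.evaluate(validation_dataset)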
Key arguments for ModelCheckpoint:

- filepath: Path to save the model file. You can use formatting options like {epoch:02d} to include the epoch number (padded with zeros) or {val_loss:.2f} to include the monitored metric value (formatted to 2 decimal places). The recommended extension for saving the entire model is .keras; use .weights.h5 if save_weights_only=True.
- monitor: The quantity to monitor (e.g., 'val_loss', 'val_accuracy'). Training metrics (e.g., 'loss', 'accuracy') can also be monitored, but monitoring validation metrics is generally preferred for selecting the best model for generalization.
- mode: One of {'auto', 'min', 'max'}. If monitor is 'val_loss', mode should be 'min'; if monitor is 'val_accuracy', mode should be 'max'. In 'auto' mode, Keras infers the direction based on the metric name.
- save_best_only: If True, the model is saved only when the monitored quantity improves compared to the best value seen so far.
- save_weights_only: If True, only the model's weights are saved (model.save_weights()). If False, the entire model is saved (model.save()), including the architecture and optimizer state.

Training for too long can lead to overfitting, where the model performs well on the training data but poorly on unseen validation data. The EarlyStopping callback monitors a specified metric (usually a validation metric) and stops the training process if the metric stops improving for a defined number of epochs.
# Define the callback
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',         # Metric to monitor
    patience=10,                # Epochs with no improvement after which training stops
    mode='min',                 # We want to minimize loss
    restore_best_weights=True   # Restore weights from the epoch with the best 'val_loss'
)
# Pass it to model.fit(), potentially along with ModelCheckpoint
# history = model.fit(train_dataset, epochs=100, validation_data=validation_dataset,
#                     callbacks=[model_checkpoint_callback, early_stopping_callback])
Key arguments for EarlyStopping:

- monitor: Quantity to be monitored (e.g., 'val_loss').
- patience: Number of epochs with no improvement after which training will be stopped. For example, if patience=10, training will halt if the monitored metric hasn't improved for 10 consecutive epochs.
- min_delta: Minimum change in the monitored quantity to qualify as an improvement. Defaults to 0. Setting a small positive value can prevent stopping due to negligible improvements.
- mode: One of {'auto', 'min', 'max'}. Determines whether improvement means a decrease ('min') or an increase ('max').
- restore_best_weights: If True, after training stops, the model weights are rolled back to those from the epoch that achieved the best value of the monitored quantity. This is highly recommended, as training might stop a few epochs after the best performance was observed.

While ModelCheckpoint saves your model and EarlyStopping controls the training duration, the TensorBoard callback logs metrics, graphs, and other information during training. This logged data can then be visualized with the TensorBoard tool (which we will explore in the next section) to gain insight into the training process, diagnose issues, and compare different runs.
import datetime

# Define the callback with a timestamped log directory
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1  # Log histogram visualizations every epoch
)
# Pass it during fit
# history = model.fit(train_dataset, epochs=100, validation_data=validation_dataset,
#                     callbacks=[model_checkpoint_callback,
#                                early_stopping_callback,
#                                tensorboard_callback])
The primary argument here is log_dir, which specifies the directory where TensorBoard logs will be written. Using a timestamp in the directory name helps keep logs from different runs separate. We will cover how to use these logs in the next section.
You can use multiple callbacks simultaneously by passing them as a list to the callbacks argument of model.fit(). Keras will execute each callback at the appropriate stage of the training loop.
# Example using all three discussed callbacks
callbacks_list = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath='best_model.keras', monitor='val_loss', save_best_only=True
    ),
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=15, restore_best_weights=True
    ),
    tf.keras.callbacks.TensorBoard(
        log_dir="logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    )
]
# history = model.fit(train_dataset, epochs=100, validation_data=validation_dataset,
#                     callbacks=callbacks_list)
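Once such a run completes, you can check what the callbacks did; for example, EarlyStopping records the epoch at which it halted training in its stopped_epoch attribute (0 if it never triggered). A brief sketch, assuming the fit() call above was executed:

# Assuming the fit() call above was run:
# early_stopping = callbacks_list[1]  # the EarlyStopping instance
# print("Epochs actually run:", len(history.history['loss']))
# print("Stopped at epoch:", early_stopping.stopped_epoch)  # 0 means it never triggered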
Other useful callbacks exist, such as LearningRateScheduler for custom learning rate schedules or ReduceLROnPlateau, which decreases the learning rate when a metric stops improving. Callbacks provide a flexible way to add custom behavior and control to your Keras training workflows, helping you manage long training runs, prevent overfitting, and save optimal model states.
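As an illustration, a ReduceLROnPlateau configuration might look like the following sketch (the argument values are example choices, not recommendations from this text): it multiplies the learning rate by 0.5 whenever 'val_loss' has not improved for 5 epochs.

# Illustrative configuration; factor, patience, and min_lr are example values
reduce_lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',  # Metric to watch
    factor=0.5,          # Multiply the learning rate by this factor when triggered
    patience=5,          # Epochs with no improvement before reducing the rate
    min_lr=1e-6          # Lower bound on the learning rate
)

# It can be added to the same callbacks list passed to model.fit():
# callbacks_list.append(reduce_lr_callback)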