Training a model with model.fit() is a powerful but often lengthy process. During this time, you might want to perform actions based on the training progress, such as saving your model periodically, stopping early if performance plateaus, or adjusting the learning rate. Keras provides a clean mechanism for this through callbacks.
Callbacks are objects that you can pass to the model.fit() method. They are called by the framework at specific points during the training process (e.g., at the beginning or end of an epoch, before or after processing a batch). This allows you to automate tasks that would otherwise require manual intervention or complex custom training loops.
Let's explore some of the most useful built-in callbacks provided by Keras.
Training deep learning models can take hours or even days. It's essential to save your model's state, especially if the training process might be interrupted. Furthermore, you often want to save the version of the model that performed best on the validation set, not necessarily the one from the very last epoch, as models can start to overfit.
The ModelCheckpoint callback handles this. You can configure it to save the model (or just its weights) periodically, or whenever its performance on a monitored metric improves.
import keras
# Define the callback
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath='best_model_epoch_{epoch:02d}_val_loss_{val_loss:.2f}.keras',  # Path where to save the model
    monitor='val_loss',       # Metric to monitor
    save_best_only=True,      # Only save if `monitor` shows improvement
    save_weights_only=False,  # Save the entire model (architecture + weights + optimizer state)
    mode='min',               # 'min' because we want to minimize val_loss
    verbose=1                 # Print messages when saving
)
# Later, when training:
# history = model.fit(..., callbacks=[model_checkpoint_callback])
Key Parameters:
- filepath: A string defining the path and filename for the saved model. You can include formatting placeholders such as {epoch:02d} to insert the epoch number (zero-padded) or {val_loss:.2f} to include the value of the monitored metric (formatted to two decimal places). Common file extensions are .keras (the preferred format for saving the entire model) or .weights.h5 if save_weights_only=True.
- monitor: The metric to watch (e.g., 'val_loss', 'val_accuracy'). This metric must be available from the model's compilation step or evaluated on the validation data provided to fit().
- save_best_only: If True, the callback only saves the model when the monitored metric improves compared to the previous best value. If False, it saves the model at the end of every epoch.
- save_weights_only: If True, only the model's weights are saved. If False, the entire model (architecture, weights, and optimizer state) is saved, allowing you to fully resume training later.
- mode: Determines whether improvement means minimizing ('min') or maximizing ('max') the monitored metric. Keras can often infer the mode from the metric name (e.g., loss implies 'min', accuracy implies 'max'), but it's good practice to set it explicitly.
- verbose: Set to 1 to see messages when the callback saves the model.

Using ModelCheckpoint with save_best_only=True ensures you retain the model state that achieved the best validation performance, guarding against overfitting in later epochs.
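Because save_weights_only=False stores the full model, you can reload a checkpoint later to evaluate it or resume training. Here is a minimal sketch, assuming a file named best_model.keras was written by the callback in a previous run:

import keras

# Reload the checkpointed model: architecture, weights, and optimizer state.
# Assumes 'best_model.keras' exists from an earlier training run.
best_model = keras.models.load_model('best_model.keras')

# The restored model can be evaluated or trained further:
# best_model.evaluate(x_val, y_val)
# best_model.fit(x_train, y_train, epochs=10, validation_data=(x_val, y_val))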
As discussed in the section on overfitting, a model might reach optimal performance on the validation set and then start to perform worse as training continues. Training beyond this point wastes computational resources and can lead to a less generalizable model.
The EarlyStopping callback monitors a specified metric and automatically stops the training process if the metric stops improving for a defined number of epochs.
import keras
# Define the callback
early_stopping_callback = keras.callbacks.EarlyStopping(
    monitor='val_loss',         # Metric to monitor
    patience=10,                # Epochs with no improvement after which training stops
    min_delta=0.001,            # Minimum change to qualify as an improvement
    mode='min',                 # 'min' because we want to minimize val_loss
    restore_best_weights=True,  # Restore weights from the epoch with the best monitored value
    verbose=1                   # Print a message when stopping
)
# Later, when training:
# history = model.fit(..., callbacks=[early_stopping_callback])
Key Parameters:
- monitor: The metric to watch, typically a validation metric like 'val_loss' or 'val_accuracy'.
- patience: The number of consecutive epochs to wait without improvement before stopping training. For example, with patience=10, training stops if the monitored metric hasn't improved for 10 epochs in a row.
- min_delta: The minimum change in the monitored quantity to qualify as an improvement. This prevents stopping due to tiny, insignificant fluctuations. For example, with monitor='val_loss' and min_delta=0.001, a decrease smaller than 0.001 does not count as an improvement.
- mode: As with ModelCheckpoint, specify 'min' or 'max' depending on the metric being monitored.
- restore_best_weights: If True, the model's weights are restored to those from the epoch that achieved the best value of the monitored metric. This is highly recommended: it ensures your final model state corresponds to the best performance observed during training, even if training stopped several epochs later due to the patience setting.
- verbose: Set to 1 to see a message when training is stopped by the callback.

EarlyStopping is a valuable tool for both preventing overfitting and saving computation time by stopping training once the model ceases to learn effectively on the validation data.
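Once fit() returns, the callback object itself records when and why training stopped. A short sketch using the stopped_epoch and best attributes that EarlyStopping maintains (this assumes the commented fit() call above has already run):

# After history = model.fit(..., callbacks=[early_stopping_callback]):
if early_stopping_callback.stopped_epoch > 0:
    print(f"Stopped early at epoch {early_stopping_callback.stopped_epoch + 1}")
    print(f"Best val_loss observed: {early_stopping_callback.best:.4f}")
else:
    print("Training ran for the full number of epochs.")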
Sometimes, during training, the model's improvement on the validation metric might slow down and plateau. The current learning rate might be too large to allow for fine-tuning the weights into a better minimum of the loss function. In such cases, reducing the learning rate can help the optimizer take smaller steps and potentially escape the plateau.
The ReduceLROnPlateau callback monitors a metric and reduces the learning rate by a specified factor if no improvement is seen for a given number of epochs (patience).
import keras
# Define the callback
reduce_lr_callback = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',  # Metric to monitor
    factor=0.2,          # Factor by which the learning rate is reduced: new_lr = lr * factor
    patience=5,          # Epochs with no improvement after which the learning rate is reduced
    min_lr=0.00001,      # Lower bound on the learning rate
    mode='min',          # 'min' because we want to minimize val_loss
    verbose=1            # Print messages when reducing the learning rate
)
# Later, when training:
# history = model.fit(..., callbacks=[reduce_lr_callback])
Key Parameters:
- monitor: The metric whose stagnation triggers the learning rate reduction.
- factor: The factor by which the learning rate is reduced (e.g., factor=0.2 means new_lr = current_lr * 0.2). Must be less than 1.0.
- patience: The number of epochs with no improvement after which the learning rate is reduced. Note that this patience count resets after each learning rate reduction.
- min_lr: A lower bound on the learning rate. The callback will not reduce the learning rate below this value.
- min_delta: Threshold for measuring a new optimum, so that only significant changes count as improvement.
- mode: 'min' or 'max', as with the other callbacks.
- verbose: Set to 1 to see messages when the learning rate is reduced.

This callback allows for a dynamic learning rate schedule, adapting to the training dynamics without requiring manual intervention.
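You can confirm when reductions happened by inspecting the training history after fit() returns. A small sketch; note that the key under which the learning rate is logged depends on your Keras version ('learning_rate' in recent Keras, 'lr' in older tf.keras):

# After history = model.fit(..., callbacks=[reduce_lr_callback]):
lr_key = 'learning_rate' if 'learning_rate' in history.history else 'lr'
for epoch, lr in enumerate(history.history[lr_key], start=1):
    print(f"Epoch {epoch}: learning rate = {lr:.6f}")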
You rarely use just one callback. It's common practice to combine several callbacks to manage different aspects of the training process. To use multiple callbacks, simply pass a list of callback objects to the callbacks argument in model.fit().
import keras
# Instantiate all desired callbacks
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath='best_model.keras',
    monitor='val_loss',
    save_best_only=True,
    mode='min'
)

early_stopping_callback = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=15,
    mode='min',
    restore_best_weights=True
)

reduce_lr_callback = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.1,
    patience=7,
    min_lr=0.00001,
    mode='min'
)

# Create a list of callbacks
callback_list = [
    model_checkpoint_callback,
    early_stopping_callback,
    reduce_lr_callback
]
# Pass the list to model.fit()
# history = model.fit(
# x_train, y_train,
# epochs=100,
# batch_size=32,
# validation_data=(x_val, y_val),
# callbacks=callback_list
# )
In this example, the training process will:
- Monitor val_loss throughout training.
- Save the best model to best_model.keras whenever val_loss improves (ModelCheckpoint).
- Reduce the learning rate by a factor of 0.1 if val_loss doesn't improve for 7 consecutive epochs (ReduceLROnPlateau).
- Stop training if val_loss doesn't improve for 15 consecutive epochs, and restore the weights from the best epoch (EarlyStopping).

Keras includes other built-in callbacks, such as TensorBoard (which we cover in detail in the "Introduction to TensorBoard" section) for logging metrics and visualizing graphs, and CSVLogger for streaming epoch results to a CSV file.
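CSVLogger in particular needs almost no configuration; a minimal sketch:

import keras

# Stream per-epoch metrics (loss, val_loss, etc.) to a CSV file.
# append=True keeps the log across resumed training runs.
csv_logger = keras.callbacks.CSVLogger('training_log.csv', append=True)

# history = model.fit(..., callbacks=[csv_logger])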
Furthermore, Keras allows you to create your own custom callbacks by inheriting from the keras.callbacks.Callback base class and overriding methods like on_epoch_end, on_batch_begin, and so on. This provides great flexibility for implementing specialized behaviors during training, although the built-in callbacks cover most common use cases.
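As an illustration, here is a minimal custom callback that prints a message whenever the validation loss improves. The class name and logic are our own, not part of Keras:

import keras

class ImprovementLogger(keras.callbacks.Callback):
    """Prints a message whenever val_loss improves on its best value so far."""

    def __init__(self):
        super().__init__()
        self.best_val_loss = float('inf')

    def on_epoch_end(self, epoch, logs=None):
        # logs holds this epoch's metrics, e.g. 'loss', 'val_loss'.
        logs = logs or {}
        val_loss = logs.get('val_loss')
        if val_loss is not None and val_loss < self.best_val_loss:
            print(f"\nEpoch {epoch + 1}: val_loss improved to {val_loss:.4f}")
            self.best_val_loss = val_loss

# history = model.fit(..., callbacks=[ImprovementLogger()])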
By effectively using callbacks like ModelCheckpoint, EarlyStopping, and ReduceLROnPlateau, you can significantly improve your training workflow, making it more robust, efficient, and less prone to overfitting. They are indispensable tools for practical deep learning development.