Let's apply the concepts covered in this chapter: compiling a model, training it with `model.fit()`, evaluating its performance, making predictions, and using callbacks for better training management and monitoring. We'll build and train a simple classifier on a standard dataset, Fashion-MNIST, integrating `ModelCheckpoint`, `EarlyStopping`, and `TensorBoard`.
This practice assumes you have TensorFlow installed and are familiar with basic NumPy operations. We'll use the Fashion-MNIST dataset, which is conveniently available through `tf.keras.datasets`.
First, load the Fashion-MNIST dataset and preprocess it. Preprocessing typically involves normalization (scaling pixel values) and reshaping the data if necessary for the model's input layer. We also need to one-hot encode the labels for categorical crossentropy loss.
```python
import tensorflow as tf
import numpy as np
import os
import datetime

# Load the Fashion-MNIST dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

# Normalize pixel values to be between 0 and 1
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# For Dense layers, flatten each 28x28 image into a vector of 784 values
# (a channel dimension would be added instead if we used Conv layers later)
x_train_flat = x_train.reshape((x_train.shape[0], 28 * 28))
x_test_flat = x_test.reshape((x_test.shape[0], 28 * 28))

# One-hot encode the labels
num_classes = 10
y_train_cat = tf.keras.utils.to_categorical(y_train, num_classes)
y_test_cat = tf.keras.utils.to_categorical(y_test, num_classes)

print(f"x_train shape: {x_train_flat.shape}")  # (60000, 784)
print(f"y_train shape: {y_train_cat.shape}")   # (60000, 10)
print(f"x_test shape: {x_test_flat.shape}")    # (10000, 784)
print(f"y_test shape: {y_test_cat.shape}")     # (10000, 10)
```
We'll define a straightforward sequential model with a few dense layers.
```python
def build_simple_model(input_shape, num_classes):
    model = tf.keras.models.Sequential([
        tf.keras.layers.Input(shape=input_shape),        # Explicit Input layer for clarity
        tf.keras.layers.Dense(128, activation='relu', name='dense_1'),
        tf.keras.layers.Dropout(0.3, name='dropout_1'),  # Dropout for regularization
        tf.keras.layers.Dense(64, activation='relu', name='dense_2'),
        tf.keras.layers.Dense(num_classes, activation='softmax', name='output')
    ])
    return model

input_shape = (28 * 28,)  # Flattened image shape
model = build_simple_model(input_shape, num_classes)
model.summary()
```
Next, compile the model, specifying the optimizer, loss function, and metrics to track during training.
```python
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=['accuracy'])

print("Model compiled.")
```
Now, let's set up the callbacks we discussed:

- `ModelCheckpoint`: saves the best model weights observed during training, based on validation accuracy.
- `EarlyStopping`: halts training if the validation loss stops improving, preventing overfitting.
- `TensorBoard`: logs metrics and the graph structure for visualization.

```python
# Define a log directory for TensorBoard (unique for each run)
log_dir = os.path.join("logs", "fit", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
print(f"TensorBoard log directory: {log_dir}")

# Define the checkpoint path and filename
checkpoint_filepath = 'models/best_fashion_mnist_model.weights.h5'
os.makedirs(os.path.dirname(checkpoint_filepath), exist_ok=True)
print(f"Model checkpoints will be saved to: {checkpoint_filepath}")

# Create the callbacks
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,     # Save only the weights, not the full model
    monitor='val_accuracy',     # Track validation accuracy
    mode='max',                 # Higher validation accuracy is better
    save_best_only=True)        # Only overwrite when a new best is reached

early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',         # Track validation loss
    patience=10,                # Stop after 10 epochs with no improvement
    verbose=1,                  # Print a message when training is stopped
    restore_best_weights=True)  # Restore the weights from the best epoch

callbacks_list = [tensorboard_callback, model_checkpoint_callback, early_stopping_callback]
```
Note: `restore_best_weights=True` in `EarlyStopping` is convenient because the model object automatically holds the best weights when training stops. If it is set to `False`, you would typically load the weights saved by `ModelCheckpoint` manually after training.
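If you did set `restore_best_weights=False`, that manual loading step might look like the sketch below. Because `save_weights_only=True` stores only the weights, the architecture has to be rebuilt (here with the `build_simple_model` helper defined earlier) before loading them.

```python
# Sketch: recover the checkpointed weights after training when
# restore_best_weights is False. The architecture must match the one
# that was checkpointed, so we rebuild it first.
restored_model = build_simple_model(input_shape, num_classes)
restored_model.load_weights(checkpoint_filepath)

# Compile before evaluating or resuming training with the restored weights
restored_model.compile(optimizer='adam',
                       loss=tf.keras.losses.CategoricalCrossentropy(),
                       metrics=['accuracy'])
```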
We are now ready to train the model using `model.fit()`. We'll pass the training data, specify the number of epochs, batch size, validation data (using a split of the training set), and our list of callbacks.
print("Starting model training...")
batch_size = 64
epochs = 50 # Set a higher number, EarlyStopping will likely stop it sooner
history = model.fit(x_train_flat, y_train_cat,
epochs=epochs,
batch_size=batch_size,
validation_split=0.2, # Use 20% of training data for validation
callbacks=callbacks_list,
verbose=1) # Set verbose=1 or 2 to see progress per epoch
print("Model training finished.")
During training, you'll see output for each epoch, including loss and accuracy for both the training and validation sets. Pay attention to messages from `EarlyStopping` if it halts the training run. `ModelCheckpoint` saves the best weights silently whenever validation accuracy improves.
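The `History` object returned by `model.fit()` records the same per-epoch metrics, so you can also inspect them programmatically after training. A minimal sketch (the key names follow from the loss and metrics compiled above):

```python
# Per-epoch metrics recorded by model.fit()
print(history.history.keys())  # e.g. ['loss', 'accuracy', 'val_loss', 'val_accuracy']

# Epoch (0-indexed) with the best validation accuracy
best_epoch = int(np.argmax(history.history['val_accuracy']))
print(f"Best val_accuracy: {history.history['val_accuracy'][best_epoch]:.4f} at epoch {best_epoch}")
```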
While the model is training (or after it finishes), you can launch TensorBoard to visualize the metrics. Open your terminal or command prompt, navigate to the directory containing the `logs` folder (or provide the full path), and run:
```bash
tensorboard --logdir logs/fit
```
TensorBoard will typically start a web server at http://localhost:6006. Open this URL in your browser. You should see:

- Scalar plots of training and validation loss and accuracy under the Scalars tab.
- The model's graph structure under the Graphs tab.
- Weight histograms under the Histograms tab (since `histogram_freq` was set).

Hypothetical training and validation loss curves, as might be seen in TensorBoard, show validation loss starting to plateau or increase while training loss continues to decrease, indicating potential overfitting. Early stopping would halt training around epoch 10-12.
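If you would rather reproduce a similar loss plot directly in Python, you can draw it from the `History` object; the short sketch below assumes matplotlib is installed (it is not otherwise required by this walkthrough).

```python
import matplotlib.pyplot as plt

# Plot training and validation loss per epoch, similar to the TensorBoard view
plt.plot(history.history['loss'], label='Training loss')
plt.plot(history.history['val_loss'], label='Validation loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()
```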
After training, evaluate the model's performance on the unseen test set using `model.evaluate()`. Since we used `restore_best_weights=True` in `EarlyStopping`, the `model` object already contains the weights from the epoch with the best validation loss. If you hadn't used that option, or wanted to load weights specifically from the `ModelCheckpoint` file, you would first load the weights:
```python
# Optional: load the best weights saved by ModelCheckpoint
# model.load_weights(checkpoint_filepath)
# print("Loaded best weights from checkpoint.")

print("Evaluating model on test data...")
test_loss, test_acc = model.evaluate(x_test_flat, y_test_cat, verbose=0)

print(f"\nTest Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")
```
This gives you the final performance metrics on data the model never saw during training or validation.
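Overall accuracy can hide differences between classes. As an optional follow-up, you could compute a confusion matrix from predictions on the full test set; the sketch below uses `tf.math.confusion_matrix` together with the original integer labels in `y_test`.

```python
# Optional: per-class breakdown of test-set performance
test_probs = model.predict(x_test_flat, verbose=0)
test_preds = np.argmax(test_probs, axis=1)

# Rows are true classes, columns are predicted classes
conf_matrix = tf.math.confusion_matrix(y_test, test_preds, num_classes=num_classes)
print(conf_matrix.numpy())
```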
Finally, use `model.predict()` to get predictions on new data (we'll use a few examples from the test set here). The output of the softmax layer gives probabilities for each class.
```python
# Get predictions for the first 5 test images
predictions = model.predict(x_test_flat[:5])

# Print the predicted class probabilities for each image
print("\nPredictions (probabilities) for the first 5 test images:")
print(predictions)

# Get the class with the highest probability for each image
predicted_classes = np.argmax(predictions, axis=1)
print("\nPredicted Classes:", predicted_classes)

# Get the actual classes for comparison
actual_classes = y_test[:5]
print("Actual Classes:   ", actual_classes)
```
The output shows the probability distribution across the 10 classes for each of the 5 input images, followed by the index (class label) with the highest probability.
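The integer labels correspond to the standard Fashion-MNIST class names. To make the output easier to read, you can map the indices to names, as in this short sketch:

```python
# Fashion-MNIST class names, indexed by label (0-9)
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

for i, (pred, actual) in enumerate(zip(predicted_classes, actual_classes)):
    print(f"Image {i}: predicted '{class_names[pred]}', actual '{class_names[actual]}'")
```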
This walkthrough demonstrates the standard workflow for training a Keras model in TensorFlow. You compiled the model with the necessary components, used `model.fit()` along with callbacks to manage the training process efficiently (saving the best model, stopping early, logging for visualization), evaluated the final performance, and made predictions. These steps form the core process you'll adapt for various machine learning tasks.