After compiling your model with an optimizer, loss function, and metrics, the next step is to train it on your data. This is where the model.fit() method comes into play. It orchestrates the training process, iterating over your dataset, calculating the loss, computing gradients, and updating the model's weights using the chosen optimizer.
The model.fit() Method

The model.fit() function is the workhorse of training in Keras. It takes your training data, target labels, and various configuration parameters to manage the learning process. At its core, fit repeatedly performs the following steps for a specified number of iterations (epochs):

- Draw a mini-batch of training samples and their corresponding targets.
- Run the forward pass to compute predictions and calculate the loss.
- Compute the gradients of the loss with respect to the model's weights.
- Update the weights using the chosen optimizer, as shown in the sketch below.
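To make these steps concrete, here is a rough sketch of the kind of loop that fit() automates for you. The model, optimizer, loss, and data below are hypothetical stand-ins for illustration, not the actual internals of fit():

import tensorflow as tf

# Hypothetical stand-ins: a small model, optimizer, loss, and batched dataset
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.MeanSquaredError()
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal((128, 4)), tf.random.normal((128, 1)))
).batch(32)

for epoch in range(3):  # a few epochs for illustration
    for x_batch, y_batch in dataset:
        with tf.GradientTape() as tape:
            predictions = model(x_batch, training=True)  # forward pass
            loss = loss_fn(y_batch, predictions)         # compute the loss
        # Gradients of the loss with respect to the trainable weights
        gradients = tape.gradient(loss, model.trainable_variables)
        # Apply the update rule defined by the optimizer
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))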
Supplying Data to fit()
You can supply training data to model.fit() in several ways:

- NumPy arrays or TensorFlow tensors: the simplest option for data that fits in memory. You pass the input features (x) and target labels (y) directly.
- tf.data.Dataset objects: for large datasets that may not fit in memory, or when you need sophisticated input processing (like prefetching, caching, or complex transformations), using a tf.data.Dataset object is the recommended approach; a small construction sketch follows this list. We will cover tf.data in detail in the next chapter.
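For a first taste, here is a minimal sketch of building a batched, shuffled tf.data.Dataset from in-memory arrays. The array names, shapes, and values are hypothetical:

import numpy as np
import tensorflow as tf

# Hypothetical in-memory data: 1000 samples with 20 features each
x_train = np.random.rand(1000, 20).astype("float32")
y_train = np.random.randint(0, 2, size=(1000,)).astype("int32")

# Build a dataset of (features, labels) pairs, shuffle it, and batch it
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1000).batch(64)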
Let's look at the most important arguments for model.fit():

- x: Input data. Can be a NumPy array, a TensorFlow tensor, or a tf.data.Dataset. If it's a dataset, y should not be provided (labels are expected to be included in the dataset).
- y: Target data (labels). Should be a NumPy array or TensorFlow tensor if x is an array/tensor. Omit this if x is a tf.data.Dataset yielding tuples like (features, labels).
- batch_size: An integer specifying the number of samples per gradient update. Training is typically performed in mini-batches rather than processing the entire dataset at once. This improves computational efficiency and can help the optimization process generalize better. Common batch sizes range from 32 to 256, but the optimal value depends on the dataset size, model complexity, and available memory (larger batches require more memory). If you provide data as a tf.data.Dataset, batching should ideally be handled by the dataset itself, and you can leave batch_size=None in fit.
- epochs: An integer defining the number of times the learning algorithm will work through the entire training dataset. One epoch means that every sample in the training dataset has had an opportunity to update the internal model parameters. Training typically requires multiple epochs.
- validation_data: Data on which to evaluate the loss and any model metrics at the end of each epoch. This should typically be a separate validation set that the model does not train on. Providing validation data helps you monitor for overfitting. It's usually passed as a tuple (x_val, y_val) of NumPy arrays or tensors, or as a tf.data.Dataset object.
- validation_split: An alternative to validation_data. A float between 0 and 1. If specified, fit will automatically reserve this fraction of the training data for validation and won't train on it. The split occurs before shuffling. This is convenient for quick validation but less robust than using a dedicated validation set, especially if your data has inherent ordering. You cannot use both validation_data and validation_split simultaneously.
- shuffle: A boolean (defaulting to True when using array/tensor data) indicating whether to shuffle the training data before each epoch. Shuffling helps prevent the model from learning the order of the data and improves generalization. When using a tf.data.Dataset, shuffling should ideally be handled within the dataset pipeline (e.g., using dataset.shuffle()).
- callbacks: A list of keras.callbacks.Callback instances. Callbacks are utilities that can be invoked at different stages of the training process (e.g., end of epoch, start of batch) to perform actions like saving the model, stopping training early, or logging to TensorBoard. We'll discuss callbacks in more detail later in this chapter; a brief sketch using early stopping follows this list.
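As a brief illustration of validation_split and callbacks working together, here is a sketch that stops training once the validation loss stops improving. It assumes a compiled model and NumPy arrays x_train and y_train, as in the examples below:

import tensorflow as tf

# Assume model is already built and compiled
# EarlyStopping halts training when the monitored metric stops improving
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",         # watch validation loss
    patience=3,                 # allow 3 epochs without improvement
    restore_best_weights=True,  # roll back to the best epoch's weights
)

history = model.fit(
    x_train,
    y_train,
    epochs=50,
    validation_split=0.2,       # reserve 20% of the training data
    callbacks=[early_stop],
)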
Let's assume you have a compiled model, training data x_train, y_train, and validation data x_val, y_val as NumPy arrays. You can start the training like this:
import tensorflow as tf
# Assume model is already built and compiled
# Assume x_train, y_train, x_val, y_val are NumPy arrays
print("Starting training...")
history = model.fit(
    x_train,
    y_train,
    batch_size=64,
    epochs=20,
    validation_data=(x_val, y_val)
)
print("Training finished.")
# The 'history' object contains training logs
print("Validation accuracy per epoch:", history.history['val_accuracy'])
In this example, the model will train for 20 epochs using mini-batches of 64 samples. The training data will be shuffled before each epoch. After each epoch, the model's loss and metrics will be evaluated on (x_val, y_val).
Training with a tf.data.Dataset

If you have prepared your data using tf.data, assuming train_dataset yields (features, labels) tuples and is already batched, and you have a similar val_dataset:
import tensorflow as tf
# Assume model is already built and compiled
# Assume train_dataset and val_dataset are tf.data.Dataset objects
# where train_dataset.element_spec is (tf.TensorSpec(shape=(None, ...), dtype=tf.float32),
# tf.TensorSpec(shape=(None, ...), dtype=tf.int32))
# and train_dataset is already batched and shuffled.
print("Starting training with tf.data...")
history = model.fit(
    train_dataset,  # no y argument needed
    epochs=20,
    validation_data=val_dataset
    # batch_size is omitted here as the dataset handles batching
    # shuffle is omitted here as the dataset handles shuffling
)
print("Training finished.")
# Access history similarly
print("Validation loss per epoch:", history.history['val_loss'])
Notice that when using a tf.data.Dataset, you typically don't provide y, batch_size, or shuffle to fit, as these concerns are managed within the dataset pipeline itself.
The model.fit() method returns a History object. This object has a history attribute, which is a dictionary containing the recorded loss and metric values for each epoch. The keys are the names of the metrics (e.g., 'loss', 'accuracy', 'val_loss', 'val_accuracy'), and the values are lists containing the metric value at the end of each epoch.
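For instance, you can inspect which metrics were recorded and pull out the final epoch's values. This assumes history is the object returned by one of the model.fit() calls above; the exact keys depend on the metrics you compiled into the model:

# Assuming 'history' is the object returned by model.fit()
print(history.history.keys())  # e.g., dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])
final_val_loss = history.history['val_loss'][-1]
print(f"Final validation loss: {final_val_loss:.4f}")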
This history is extremely useful for analyzing the training process, such as plotting learning curves to diagnose issues like overfitting or underfitting.
import matplotlib.pyplot as plt
# Assuming 'history' is the object returned by model.fit()
train_loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(1, len(train_loss) + 1)
plt.figure(figsize=(8, 5))
plt.plot(epochs_range, train_loss, label='Training Loss', color='#1c7ed6') # blue
plt.plot(epochs_range, val_loss, label='Validation Loss', color='#f76707') # orange
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()
A typical result is a plot where training loss decreases steadily while validation loss decreases initially but then starts to increase, a classic sign of overfitting.
Using model.fit() is the standard way to train Keras models. By understanding its parameters and how to provide data, you can effectively manage the training loop for a wide variety of machine learning tasks. Remember to monitor your validation metrics closely via the History object or through callbacks to build effective models.