The tf.data API constructs efficient input pipelines, transforming raw data sources into shuffled, batched, and prefetched streams ready for consumption. These pipelines plug directly into Keras: the model.fit(), model.evaluate(), and model.predict() methods all accept tf.data.Dataset objects, so no conversion step is needed between the input pipeline and the training loop.
When you pass a tf.data.Dataset object to model.fit(), Keras automatically iterates over the dataset to retrieve batches of data for each training step. This eliminates the need for manual batch iteration loops and integrates cleanly with Keras features like callbacks.
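To make the "no manual loop" point concrete, the sketch below writes out by hand roughly what one epoch of model.fit() does with a dataset, then runs the same epoch through fit(). It is a simplified illustration under assumptions: the random data, the tiny model, and the names batch_losses and dataset are hypothetical, and the real fit() also handles callbacks, metric aggregation, and progress reporting that the manual loop omits.

```python
import numpy as np
import tensorflow as tf

# Hypothetical random data: 64 samples, 4 features, binary integer labels
X = np.random.rand(64, 4).astype(np.float32)
y = np.random.randint(0, 2, size=64).astype(np.int32)
dataset = tf.data.Dataset.from_tensor_slices((X, y)).batch(16)  # 4 batches

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# Manual version: iterate over the batches and run one training step per batch
batch_losses = []
for features_batch, labels_batch in dataset:
    logs = model.train_on_batch(features_batch, labels_batch, return_dict=True)
    batch_losses.append(float(logs["loss"]))

# The same epoch driven by Keras: iteration happens inside fit()
history = model.fit(dataset, epochs=1, verbose=0)
```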
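For training with model.fit(), your dataset should typically yield tuples of the form (inputs, targets); a three-element form, (inputs, targets, sample_weights), is also accepted when you need per-sample weighting. Keras expects each element yielded by the dataset to represent one batch of data, so apply .batch() before passing the dataset in.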
inputs: This can be a single tensor (for models with one input) or a tuple/dictionary of tensors (for models with multiple inputs). The structure must match the model's input signature.
targets: Similarly, this can be a single tensor or a tuple/dictionary of tensors, corresponding to the model's output(s) and the loss function(s) being used.
If your dataset yields batches like (feature_batch, label_batch), Keras will correctly map feature_batch to the model's inputs and label_batch to the expected outputs for calculating the loss.
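For a multi-input model, the inputs part of each element can be a dictionary whose keys match the names of the model's Input layers. The sketch below assumes a hypothetical two-input functional model with inputs named "numeric" and "extra" and uses random data; it only illustrates how the dataset structure lines up with the model's input signature.

```python
import numpy as np
import tensorflow as tf

# Hypothetical two-input data: 100 samples with 3 numeric and 2 extra features
numeric = np.random.rand(100, 3).astype(np.float32)
extra = np.random.rand(100, 2).astype(np.float32)
labels = np.random.randint(0, 2, size=100).astype(np.int32)

# Each element is (inputs_dict, targets); the dict keys match the Input names below
ds = tf.data.Dataset.from_tensor_slices(
    ({"numeric": numeric, "extra": extra}, labels)
).batch(16)

numeric_in = tf.keras.Input(shape=(3,), name="numeric")
extra_in = tf.keras.Input(shape=(2,), name="extra")
x = tf.keras.layers.Concatenate()([numeric_in, extra_in])
x = tf.keras.layers.Dense(8, activation="relu")(x)
out = tf.keras.layers.Dense(2, activation="softmax")(x)

# Declaring the model inputs as a dict makes the name-based mapping explicit
model = tf.keras.Model(inputs={"numeric": numeric_in, "extra": extra_in}, outputs=out)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

history = model.fit(ds, epochs=1, verbose=0)
```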
Consider a dataset train_dataset created using methods like tf.data.Dataset.from_tensor_slices((features, labels)), followed by .shuffle(), .batch(), and .prefetch(). You can directly pass this dataset to model.fit():
```python
# Assume 'model' is a compiled Keras model
# Assume 'train_dataset' yields (features_batch, labels_batch) tuples
# Assume 'val_dataset' yields (features_batch, labels_batch) tuples
history = model.fit(train_dataset, epochs=10, validation_data=val_dataset)
```
Keras handles the iteration, feeding batches to the training process automatically. The same applies to model.evaluate():
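```python
# Assumes the model was compiled with metrics=['accuracy'], so evaluate()
# returns [loss, accuracy]
loss, accuracy = model.evaluate(val_dataset)
print(f"Validation Loss: {loss}, Validation Accuracy: {accuracy}")
```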
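Unpacking the result into loss, accuracy only works when the model was compiled with exactly one metric. A sketch of a more robust alternative, assuming a tiny hypothetical model and random validation data, is to ask evaluate() for a dictionary keyed by metric name with return_dict=True:

```python
import numpy as np
import tensorflow as tf

# Hypothetical validation data
X = np.random.rand(64, 4).astype(np.float32)
y = np.random.randint(0, 2, size=64).astype(np.int32)
val_dataset = tf.data.Dataset.from_tensor_slices((X, y)).batch(16)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# return_dict=True yields {"loss": ..., "accuracy": ...} instead of a list,
# so the code does not depend on the number or order of compiled metrics
results = model.evaluate(val_dataset, return_dict=True, verbose=0)
print(f"Validation Loss: {float(results['loss']):.4f}, "
      f"Validation Accuracy: {float(results['accuracy']):.4f}")
```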
For model.predict(), the dataset should yield only the input features. If the dataset yields (inputs, targets) tuples, Keras will simply ignore the targets part during prediction.
```python
# Assume 'test_dataset' yields batches of features only, or (features, ...) tuples
predictions = model.predict(test_dataset)
```
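The claim that predict() ignores the targets can be checked directly: feeding the same features with and without labels should give identical predictions. This sketch assumes a small, untrained hypothetical model and random data; predict() does not require the model to be compiled.

```python
import numpy as np
import tensorflow as tf

# Hypothetical features and labels
X = np.random.rand(50, 4).astype(np.float32)
y = np.random.randint(0, 2, size=50).astype(np.int32)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(2, activation="softmax"),
])

features_only = tf.data.Dataset.from_tensor_slices(X).batch(16)
with_targets = tf.data.Dataset.from_tensor_slices((X, y)).batch(16)

preds_a = model.predict(features_only, verbose=0)
preds_b = model.predict(with_targets, verbose=0)  # the y part is ignored
print(preds_a.shape)  # (50, 2)
```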
steps_per_epoch and steps Arguments
When using tf.data, you often work with datasets whose length might not be easily determined upfront, especially if you use transformations like repeat() (to loop infinitely over the data) or if the dataset is sourced from a generator.
Finite Datasets: If your dataset eventually ends (e.g., one created from NumPy arrays, or one read from TFRecord files without repeat()), Keras runs through the entire dataset once per epoch and starts the next epoch when it is exhausted. You don't need to specify the number of steps. This holds even when the cardinality (number of batches) is unknown ahead of time, as it is for TFRecord files or filtered datasets.
Infinite Datasets: If your dataset is infinite (e.g., it uses .repeat() without a count), it never runs out, so Keras cannot tell when one epoch ends. In this case, you must provide the steps_per_epoch argument to model.fit(). This integer tells Keras how many batches to draw from the dataset to constitute one training epoch.
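You can ask tf.data which case applies with Dataset.cardinality(), which returns the number of elements (batches, after .batch()) or one of the sentinel values tf.data.INFINITE_CARDINALITY and tf.data.UNKNOWN_CARDINALITY. The sketch below uses small hypothetical datasets built from tf.data.Dataset.range to produce each case.

```python
import tensorflow as tf

# 100 elements in batches of 32 -> 32, 32, 32, 4 -> 4 batches
finite = tf.data.Dataset.range(100).batch(32)
# repeat() with no count loops forever
infinite = finite.repeat()
# filter() may drop any number of elements, so tf.data cannot know the count
unknown = tf.data.Dataset.range(100).filter(lambda x: x % 2 == 0).batch(32)

finite_card = int(finite.cardinality())
infinite_card = int(infinite.cardinality())
unknown_card = int(unknown.cardinality())

print(finite_card)                                    # 4
print(infinite_card == tf.data.INFINITE_CARDINALITY)  # True
print(unknown_card == tf.data.UNKNOWN_CARDINALITY)    # True
```

Only the infinite case requires steps_per_epoch; the unknown but finite dataset still ends, so fit() simply runs until it is exhausted.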
```python
# Assumes model, train_dataset (already batched), val_dataset,
# num_training_samples and BATCH_SIZE are defined

# Repeat the batched dataset indefinitely
train_dataset_repeated = train_dataset.repeat()

# Define how many batches constitute one epoch
STEPS_PER_EPOCH = num_training_samples // BATCH_SIZE  # Example calculation

history = model.fit(train_dataset_repeated,
                    epochs=10,
                    steps_per_epoch=STEPS_PER_EPOCH,
                    validation_data=val_dataset)  # val_dataset is usually finite
```
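A self-contained version of the same pattern follows, as a sketch under assumptions: the sample counts, random data, and small model are hypothetical. It pairs the floor-division step count with batch(..., drop_remainder=True), so each pass over the data yields exactly num_training_samples // BATCH_SIZE full batches and every epoch lines up with one pass.

```python
import numpy as np
import tensorflow as tf

# Hypothetical sizes chosen for illustration
num_training_samples = 200
BATCH_SIZE = 32

X_train = np.random.rand(num_training_samples, 4).astype(np.float32)
y_train = np.random.randint(0, 2, size=num_training_samples).astype(np.int32)
X_val = np.random.rand(64, 4).astype(np.float32)
y_val = np.random.randint(0, 2, size=64).astype(np.int32)

# drop_remainder=True gives exactly 200 // 32 = 6 full batches per pass;
# repeat() then loops over them forever, reshuffling on each pass
train_dataset_repeated = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .shuffle(num_training_samples)
    .batch(BATCH_SIZE, drop_remainder=True)
    .repeat()
    .prefetch(tf.data.AUTOTUNE)
)
val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val)).batch(BATCH_SIZE)

STEPS_PER_EPOCH = num_training_samples // BATCH_SIZE  # 6, one full pass per epoch

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

history = model.fit(train_dataset_repeated,
                    epochs=3,
                    steps_per_epoch=STEPS_PER_EPOCH,
                    validation_data=val_dataset,  # finite: runs until exhausted
                    verbose=0)
```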
Similarly, model.evaluate() and model.predict() accept a steps argument. If you pass an infinite dataset to these methods, you must specify steps to indicate how many batches should be used for evaluation or prediction. If the dataset is finite (whether or not its cardinality is known) and steps is not provided, they run until the dataset is exhausted.
```python
# Assumes model, val_dataset, test_dataset, num_validation_samples,
# num_test_samples and BATCH_SIZE are defined

# Evaluate on a specific number of batches from the validation set
EVALUATION_STEPS = num_validation_samples // BATCH_SIZE  # Example calculation
loss, accuracy = model.evaluate(val_dataset, steps=EVALUATION_STEPS)

# Predict on a specific number of batches from the test set
PREDICTION_STEPS = num_test_samples // BATCH_SIZE  # Example calculation
predictions = model.predict(test_dataset, steps=PREDICTION_STEPS)
```
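The sketch below runs evaluate() and predict() on repeated (infinite) datasets with an explicit steps value. The data, batch size, and model are hypothetical; the point is that steps caps how many batches are drawn, so predict() returns steps × batch size rows when every drawn batch is full.

```python
import numpy as np
import tensorflow as tf

BATCH_SIZE = 8
# Hypothetical data: 40 samples -> 5 full batches of 8 per pass
X = np.random.rand(40, 4).astype(np.float32)
y = np.random.randint(0, 2, size=40).astype(np.int32)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

# Both datasets repeat forever, so steps is required to end each call
val_repeated = tf.data.Dataset.from_tensor_slices((X, y)).batch(BATCH_SIZE).repeat()
test_repeated = tf.data.Dataset.from_tensor_slices(X).batch(BATCH_SIZE).repeat()

loss, accuracy = model.evaluate(val_repeated, steps=5, verbose=0)
predictions = model.predict(test_repeated, steps=3, verbose=0)
print(predictions.shape)  # (24, 2): 3 batches of 8 rows
```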
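Choosing steps_per_epoch matters because it decides where epoch boundaries fall, and therefore how often validation and epoch-end callbacks (such as checkpointing or early stopping) run. A common practice is to set it so that the model sees roughly the equivalent of the entire training dataset once per epoch: steps_per_epoch = num_training_samples // batch_size.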
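The "roughly" matters when the sample count is not a multiple of the batch size. A small arithmetic sketch with hypothetical numbers shows the difference between floor division and rounding up with math.ceil:

```python
import math

num_training_samples = 1000
batch_size = 32

floor_steps = num_training_samples // batch_size           # 31 full batches
ceil_steps = math.ceil(num_training_samples / batch_size)  # 32, last batch has 8 samples

shortfall = num_training_samples - floor_steps * batch_size  # samples short of a full pass

print(floor_steps, ceil_steps)  # 31 32
print(shortfall)                # 8
```

Floor division keeps every step a full batch but each epoch covers slightly less than one pass; rounding up covers the whole pass when the dataset yields a final partial batch.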
Let's illustrate with a simple example using NumPy data.
```python
import tensorflow as tf
import numpy as np

# 1. Generate some dummy data
num_samples = 1000
input_dim = 10
num_classes = 2
batch_size = 32
X_train = np.random.rand(num_samples, input_dim).astype(np.float32)
y_train = np.random.randint(0, num_classes, size=num_samples).astype(np.int32)
X_val = np.random.rand(200, input_dim).astype(np.float32)
y_val = np.random.randint(0, num_classes, size=200).astype(np.int32)

# 2. Create tf.data Datasets
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=num_samples).batch(batch_size).prefetch(tf.data.AUTOTUNE)
val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val))
val_dataset = val_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# 3. Build a simple Keras model
model = tf.keras.Sequential([
    tf.keras.Input(shape=(input_dim,)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(num_classes, activation='softmax')  # One probability per class
])

# 4. Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',  # Integer labels, not one-hot
              metrics=['accuracy'])

# 5. Train the model using the datasets
print("Training the model using tf.data.Dataset...")
# No steps_per_epoch needed as the datasets are finite
history = model.fit(train_dataset, epochs=5, validation_data=val_dataset)
print("Training finished.")

# 6. Evaluate the model
print("\nEvaluating the model...")
loss, accuracy = model.evaluate(val_dataset)  # No steps needed here either
print(f"Validation Loss: {loss:.4f}, Validation Accuracy: {accuracy:.4f}")

# 7. Make predictions (using a features-only dataset built from the validation data)
pred_dataset = tf.data.Dataset.from_tensor_slices(X_val).batch(batch_size)
print("\nMaking predictions...")
predictions = model.predict(pred_dataset)
print(f"Predictions shape: {predictions.shape}")  # (200, 2): (num_val_samples, num_classes)
```
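Before calling fit(), it can help to confirm what Keras will actually receive. A minimal sketch, assuming the same hypothetical shapes as the example (1,000 samples, 10 features, batch size 32), inspects the dataset's element_spec and pulls a single batch with take(1):

```python
import numpy as np
import tensorflow as tf

X_train = np.random.rand(1000, 10).astype(np.float32)
y_train = np.random.randint(0, 2, size=1000).astype(np.int32)
train_dataset = (tf.data.Dataset.from_tensor_slices((X_train, y_train))
                 .shuffle(buffer_size=1000)
                 .batch(32)
                 .prefetch(tf.data.AUTOTUNE))

# element_spec describes the (inputs, targets) structure of each element;
# the batch dimension is None because the last batch (1000 - 31 * 32 = 8) is smaller
print(train_dataset.element_spec)

# Pull one batch to check the shapes Keras will see per training step
for features_batch, labels_batch in train_dataset.take(1):
    print(features_batch.shape, labels_batch.shape)  # (32, 10) (32,)
```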
This example demonstrates how tf.data.Dataset objects plug into the standard Keras workflow. batch() groups examples into the per-step batches Keras expects, shuffle() randomizes the training order on each epoch, and prefetch() prepares upcoming batches while the model trains on the current one, which helps keep a GPU or TPU busy instead of waiting on input. This integration is a fundamental aspect of building scalable machine learning workflows in TensorFlow.
© 2026 ApX Machine Learning