While loading data directly from NumPy arrays or Python generators works well for smaller datasets, it can become inefficient when dealing with very large datasets that may not fit into memory or involve slow data loading from numerous small files. Reading data involves overhead like file opening, closing, and seeking, which can dominate processing time if not managed carefully.
TensorFlow provides a specific file format optimized for its data pipelines: TFRecord. A TFRecord file stores your data as a sequence of binary records. Serializing data into this format and reading it back sequentially can significantly improve I/O performance, especially when working with large datasets stored on network file systems or spinning disks. It's TensorFlow's recommended format for training data.
At its core, a TFRecord file (.tfrecord or .tfrec) contains a sequence of variable-length binary strings. Each string represents one data record (e.g., one image and its label). TensorFlow uses Protocol Buffers (protobufs) internally to structure these records. Specifically, each record is typically a serialized tf.train.Example message.
A tf.train.Example is a flexible container designed to hold key-value pairs, where the keys are strings (feature names) and the values are tf.train.Feature messages. Conceptually, its structure looks like this:
tf.train.Example {
  features: tf.train.Features {
    feature: map<string, tf.train.Feature> {
      "image_raw": tf.train.Feature { bytes_list: tf.train.BytesList {...} },
      "label": tf.train.Feature { int64_list: tf.train.Int64List {...} },
      "height": tf.train.Feature { int64_list: tf.train.Int64List {...} },
      "width": tf.train.Feature { int64_list: tf.train.Int64List {...} },
      "caption": tf.train.Feature { bytes_list: tf.train.BytesList {...} }
    }
  }
}
The tf.train.Feature message itself can hold one of three list types:

- tf.train.BytesList: for binary strings (like serialized image data) or UTF-8 strings.
- tf.train.FloatList: for floating-point values (float32, float64).
- tf.train.Int64List: for integer values (bool, enum, int32, uint32, int64, uint64).

This structure allows you to store diverse data types within a single record format. For instance, an image might be stored as raw bytes (BytesList), its label as an integer (Int64List), and perhaps bounding box coordinates as floats (FloatList).
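As a quick illustration, here is a minimal sketch of constructing one Feature of each list type by hand; the specific values (a bounding box, a class label, a filename) are hypothetical and only serve to show the containers in isolation:

import tensorflow as tf

# Hypothetical values, one per list type
bbox_feature = tf.train.Feature(
    float_list=tf.train.FloatList(value=[0.1, 0.2, 0.8, 0.9]))   # floats
label_feature = tf.train.Feature(
    int64_list=tf.train.Int64List(value=[3]))                    # integers
filename_feature = tf.train.Feature(
    bytes_list=tf.train.BytesList(value=[b"cat_001.jpg"]))       # bytes / strings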
To create a TFRecord file, you need to convert your raw data into tf.train.Example protocol buffers and then write these serialized protos to a file using tf.io.TFRecordWriter.
Let's illustrate with an example where we have image data (as byte strings) and corresponding integer labels.
First, we need helper functions to easily create tf.train.Feature messages from standard Python/NumPy types:
import tensorflow as tf
import numpy as np

# Helper functions to create tf.train.Feature messages
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):  # If it's a Tensor
        value = value.numpy()  # Get its value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    """Returns a float_list from a float / double."""
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

# Example: Assume you have image bytes and a label
# Replace with your actual data loading
image_bytes = tf.io.encode_jpeg(
    np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)).numpy()
label = 5  # Example class label

# Create a dictionary mapping feature names to Feature protos
feature_dict = {
    'image_raw': _bytes_feature(image_bytes),
    'label': _int64_feature(label),
}

# Create a Features message using the dictionary
features = tf.train.Features(feature=feature_dict)

# Create an Example message using the Features message
example_proto = tf.train.Example(features=features)

# Serialize the Example message to a binary string
serialized_example = example_proto.SerializeToString()

# Now, write this (and potentially many others) to a TFRecord file
tfrecord_filename = 'data.tfrecord'

# Use tf.io.TFRecordWriter to write the serialized example
with tf.io.TFRecordWriter(tfrecord_filename) as writer:
    # Write multiple examples in a loop for a real dataset
    writer.write(serialized_example)
    # writer.write(another_serialized_example)
    # ...
In a typical workflow, you would iterate through your dataset (e.g., image file paths and labels), load each sample, convert it to a tf.train.Example using the structure defined above, serialize it, and write it to the TFRecordWriter. You might also split your data into multiple TFRecord files (shards) for better parallel reading later.
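Here is a minimal sketch of such a sharded writing loop. The names image_paths, labels, and load_image_bytes are hypothetical placeholders for your own data and loading logic, and the round-robin shard assignment is just one reasonable choice:

# Sketch of writing sharded TFRecord files; image_paths, labels, and
# load_image_bytes are hypothetical stand-ins for your own data pipeline.
NUM_SHARDS = 4

for shard_id in range(NUM_SHARDS):
    shard_path = f'data_{shard_id:03d}.tfrecord'
    with tf.io.TFRecordWriter(shard_path) as writer:
        # Assign every NUM_SHARDS-th sample to this shard
        for path, label in zip(image_paths[shard_id::NUM_SHARDS],
                               labels[shard_id::NUM_SHARDS]):
            image_bytes = load_image_bytes(path)  # hypothetical loader returning JPEG bytes
            example = tf.train.Example(features=tf.train.Features(feature={
                'image_raw': _bytes_feature(image_bytes),
                'label': _int64_feature(int(label)),
            }))
            writer.write(example.SerializeToString())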
Reading data back involves using tf.data.TFRecordDataset and parsing the serialized tf.train.Example messages.
Create a TFRecordDataset: This dataset reads the raw, serialized protocol buffer strings from one or more TFRecord files.
# Can provide a list of filenames for sharded data
raw_dataset = tf.data.TFRecordDataset([tfrecord_filename])

# raw_dataset now yields serialized tf.train.Example strings
for raw_record in raw_dataset.take(1):
    print(repr(raw_record))
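If you want a more readable look at what was written, for example to confirm feature names before defining the parsing schema, one option is to parse a raw record back into a tf.train.Example proto in plain Python. This is just a debugging aid, not part of the training pipeline:

# Optional debugging aid: decode one raw record into a tf.train.Example proto
for raw_record in raw_dataset.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    print(list(example.features.feature.keys()))  # e.g. ['image_raw', 'label']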
Define a Parsing Function: Since the dataset yields serialized strings, you need to parse them back into tensors. This requires defining the structure of the data you expect to find in each tf.train.Example. You create a feature_description dictionary mapping feature names to parsing instructions (tf.io.FixedLenFeature for fixed-size features or tf.io.VarLenFeature for variable-length features).
# Define the features structure you used during writing
feature_description = {
    'image_raw': tf.io.FixedLenFeature([], tf.string),  # 0-D string tensor
    'label': tf.io.FixedLenFeature([], tf.int64),       # 0-D int64 tensor
}

# Function to parse a single tf.train.Example proto
def _parse_function(example_proto):
    # Parse the input `tf.train.Example` proto using the dictionary above.
    return tf.io.parse_single_example(example_proto, feature_description)
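If your records also contained a variable-length field, such as the caption bytes shown in the earlier structure sketch, you could declare it with tf.io.VarLenFeature. The snippet below is a hypothetical extension; it assumes such a feature was actually written, which the file created above does not include:

# Hypothetical schema with a variable-length feature; it is parsed as a SparseTensor
var_len_description = {
    'image_raw': tf.io.FixedLenFeature([], tf.string),
    'label': tf.io.FixedLenFeature([], tf.int64),
    'caption': tf.io.VarLenFeature(tf.string),
}

def _parse_with_caption(example_proto):
    parsed = tf.io.parse_single_example(example_proto, var_len_description)
    # Convert the sparse caption field to a dense tensor if that is more convenient
    parsed['caption'] = tf.sparse.to_dense(parsed['caption'], default_value=b'')
    return parsed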
Map the Parsing Function: Apply the parsing function to each element in the raw dataset using dataset.map().
parsed_dataset = raw_dataset.map(_parse_function)

# parsed_dataset now yields dictionaries of tensors
for features in parsed_dataset.take(1):
    print(f"Label: {features['label'].numpy()}")
    # Image is still raw bytes, needs decoding
    image_tensor = tf.io.decode_jpeg(features['image_raw'])
    print(f"Image shape: {image_tensor.shape}")
Further Processing (Decoding, Preprocessing): Often, data stored in TFRecords (like image bytes) needs further decoding or preprocessing. You can add these steps within the map function or chain additional map calls.
def _decode_and_preprocess(features):
    # Decode the JPEG image
    image = tf.io.decode_jpeg(features['image_raw'], channels=3)
    # Resize or apply other preprocessing
    image = tf.image.resize(image, [128, 128])
    image = tf.cast(image, tf.float32) / 255.0  # Normalize
    label = features['label']
    return image, label

# Apply decoding and preprocessing
processed_dataset = parsed_dataset.map(_decode_and_preprocess)

# Now the dataset yields (image_tensor, label_tensor) tuples
for image, label in processed_dataset.take(1):
    print(f"Processed Image shape: {image.shape}, Label: {label.numpy()}")
Integrating TFRecords into a tf.data Pipeline

The real power comes when you combine TFRecord reading with other tf.data transformations like shuffling, batching, and prefetching:
# Assuming you have multiple TFRecord files (shards)
filenames = tf.data.Dataset.list_files("data_*.tfrecord")

# Interleave reading from multiple files for better shuffling
dataset = filenames.interleave(tf.data.TFRecordDataset,
                               cycle_length=tf.data.AUTOTUNE,
                               num_parallel_calls=tf.data.AUTOTUNE)

# Shuffle, parse, decode, preprocess, batch, prefetch
BUFFER_SIZE = 10000
BATCH_SIZE = 32

dataset = dataset.shuffle(BUFFER_SIZE)
dataset = dataset.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.map(_decode_and_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

print(dataset)

# This dataset is now ready to be passed to model.fit()
# model.fit(dataset, epochs=10)
Using tf.data.AUTOTUNE allows TensorFlow to dynamically tune the level of parallelism for map operations and the prefetch buffer size based on available system resources, simplifying performance optimization.
TFRecord is particularly beneficial when:

- your dataset is too large to fit in memory,
- your data consists of many small files, where per-file open, close, and seek overhead would otherwise dominate, or
- data is read from network file systems or spinning disks, where sequential reads are much faster than scattered random access.
While creating TFRecord files adds an initial data preparation step, the potential performance gains during training, especially for large datasets and long training runs, often justify the effort. For smaller datasets that comfortably fit in memory, loading directly from NumPy arrays might be simpler and sufficient.