趋近智
虽然直接从 NumPy 数组或 Python 生成器加载数据对于较小的数据集很有效,但当处理非常大的数据集时,如果它们无法完全放入内存或涉及从大量小文件缓慢加载数据,这种方式可能会变得效率低下。读取数据涉及文件打开、关闭和查找等额外开销,如果管理不当,这些开销可能会占据处理时间的大部分。
TensorFlow 提供了一种专门为其数据管道优化的文件格式:TFRecord。TFRecord 文件将你的数据存储为一系列二进制记录。将数据序列化为这种格式并顺序读回,可以大幅提高 I/O 性能,尤其是在处理存储在网络文件系统或机械硬盘上的大型数据集时。它是 TensorFlow 推荐的训练数据格式。
TFRecord 文件(.tfrecord 或 .tfrec)本质上包含一系列变长二进制字符串。每个字符串代表一个数据记录(例如,一张图片及其标签)。TensorFlow 内部使用 Protocol Buffers (protobufs) 来组织这些记录。具体来说,每个记录通常是一个序列化的 tf.train.Example 消息。
tf.train.Example 是一个灵活的容器,设计用于存放键值对,键是字符串(特征名称),值是 tf.train.Feature 消息。
tf.train.Example {
features: tf.train.Features {
feature: map<string, tf.train.Feature> {
"image_raw": tf.train.Feature { bytes_list: tf.train.BytesList {...} },
"label": tf.train.Feature { int64_list: tf.train.Int64List {...} },
"height": tf.train.Feature { int64_list: tf.train.Int64List {...} },
"width": tf.train.Feature { int64_list: tf.train.Int64List {...} },
"caption": tf.train.Feature { bytes_list: tf.train.BytesList {...} }
}
}
}
tf.train.Feature 消息本身可以包含以下三种列表类型之一:
tf.train.BytesList: 用于二进制字符串(如序列化图像数据)或 UTF-8 字符串。tf.train.FloatList: 用于浮点数值(float32, float64)。tf.train.Int64List: 用于整数值(布尔型、枚举型、int32、uint32、int64、uint64)。这种结构允许你在单一记录格式中存储不同类型的数据。例如,一张图片可以存储为原始字节 (BytesList),其标签存储为整数 (Int64List),而边界框坐标则可能存储为浮点数 (FloatList)。
要创建一个 TFRecord 文件,你需要将原始数据转换为 tf.train.Example 协议缓冲区,然后使用 tf.io.TFRecordWriter 将这些序列化的协议写入文件。
我们通过一个例子来说明,假设我们有图像数据(作为字节字符串)和相应的整数标签。
首先,我们需要辅助函数来轻松地从标准 Python/NumPy 类型创建 tf.train.Feature 消息:
import tensorflow as tf
import numpy as np
# Helper functions to create tf.train.Feature
def _bytes_feature(value):
"""从字符串/字节返回一个 bytes_list。"""
if isinstance(value, type(tf.constant(0))): # 如果它是一个张量
value = value.numpy() # 获取其值
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
"""从浮点数/双精度数返回一个 float_list。"""
if not isinstance(value, list):
value = [value]
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def _int64_feature(value):
"""从布尔型/枚举型/整型/无符号整型返回一个 int64_list。"""
if not isinstance(value, list):
value = [value]
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
# 示例:假设你有图像字节和标签
# 替换为你的实际数据加载逻辑
image_bytes = tf.io.encode_jpeg(np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)).numpy()
label = 5 # 示例类别标签
# 创建一个将特征名称映射到 Feature 协议的字典
feature_dict = {
'image_raw': _bytes_feature(image_bytes),
'label': _int64_feature(label),
}
# 使用字典创建一个 Features 消息
features = tf.train.Features(feature=feature_dict)
# 使用 Features 消息创建一个 Example 消息
example_proto = tf.train.Example(features=features)
# 将 Example 消息序列化为二进制字符串
serialized_example = example_proto.SerializeToString()
# 现在,将此(以及可能许多其他)写入 TFRecord 文件
tfrecord_filename = 'data.tfrecord'
# 使用 tf.io.TFRecordWriter 写入序列化的示例
with tf.io.TFRecordWriter(tfrecord_filename) as writer:
# 在循环中写入多个示例以处理真实数据集
writer.write(serialized_example)
# writer.write(another_serialized_example)
# ...
在典型的工作流程中,你会迭代你的数据集(例如,图像文件路径和标签),加载每个样本,使用上面定义的结构将其转换为 tf.train.Example,然后将其序列化并写入 TFRecordWriter。你还可以将数据分割成多个 TFRecord 文件(分片),以便日后更好地并行读取。
读回数据涉及使用 tf.data.TFRecordDataset 并解析序列化的 tf.train.Example 消息。
创建 TFRecordDataset:此数据集从一个或多个 TFRecord 文件中读取原始的、序列化的协议缓冲区字符串。
# 可以提供一个文件名列表,用于分片数据
raw_dataset = tf.data.TFRecordDataset([tfrecord_filename])
# raw_dataset 现在生成序列化的 tf.train.Example 字符串
for raw_record in raw_dataset.take(1):
print(repr(raw_record))
定义解析函数:由于数据集生成序列化的字符串,你需要将其解析回张量。这需要定义你期望在每个 tf.train.Example 中找到的数据结构。你创建一个 feature_description 字典,将特征名称映射到解析指令(tf.io.FixedLenFeature 用于固定大小的特征,或 tf.io.VarLenFeature 用于变长特征)。
# 定义你在写入时使用的特征结构
feature_description = {
'image_raw': tf.io.FixedLenFeature([], tf.string), # 0维字符串张量
'label': tf.io.FixedLenFeature([], tf.int64), # 0维 int64 张量
}
# 用于解析单个 tf.train.Example 协议的函数
def _parse_function(example_proto):
# 使用上述字典解析输入的 `tf.train.Example` 协议。
return tf.io.parse_single_example(example_proto, feature_description)
映射解析函数:使用 dataset.map() 将解析函数应用于原始数据集中的每个元素。
parsed_dataset = raw_dataset.map(_parse_function)
# parsed_dataset 现在生成张量字典
for features in parsed_dataset.take(1):
print(f"Label: {features['label'].numpy()}")
# 图像仍是原始字节,需要解码
image_tensor = tf.io.decode_jpeg(features['image_raw'])
print(f"Image shape: {image_tensor.shape}")
进一步处理(解码、预处理):通常,存储在 TFRecord 中的数据(如图像字节)需要进一步解码或预处理。你可以在 map 函数内部添加这些步骤,或者链接额外的 map 调用。
def _decode_and_preprocess(features):
# 解码 JPEG 图像
image = tf.io.decode_jpeg(features['image_raw'], channels=3)
# 调整大小或应用其他预处理
image = tf.image.resize(image, [128, 128])
image = tf.cast(image, tf.float32) / 255.0 # 归一化
label = features['label']
return image, label
# 应用解码和预处理
processed_dataset = parsed_dataset.map(_decode_and_preprocess)
# 现在数据集生成 (image_tensor, label_tensor) 元组
for image, label in processed_dataset.take(1):
print(f"Processed Image shape: {image.shape}, Label: {label.numpy()}")
tf.data 管道集成真正的优势在于你将 TFRecord 读取与其他 tf.data 转换(如洗牌、批处理和预取)结合起来时:
# 假设你有多个 TFRecord 文件(分片)
filenames = tf.data.Dataset.list_files("data_*.tfrecord")
# 交错读取多个文件以更好地洗牌
dataset = filenames.interleave(tf.data.TFRecordDataset,
cycle_length=tf.data.AUTOTUNE,
num_parallel_calls=tf.data.AUTOTUNE)
# 洗牌、解析、解码、预处理、批处理、预取
BUFFER_SIZE = 10000
BATCH_SIZE = 32
dataset = dataset.shuffle(BUFFER_SIZE)
dataset = dataset.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.map(_decode_and_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
print(dataset)
# 此数据集现在已准备好传递给 model.fit()
# model.fit(dataset, epochs=10)
使用 tf.data.AUTOTUNE 允许 TensorFlow 根据可用系统资源动态调整 map 操作的并行度以及预取缓冲区大小,从而简化性能优化。
TFRecord 在以下情况下特别有益:
虽然创建 TFRecord 文件会增加一个初始数据准备步骤,但训练期间潜在的性能提升,特别是对于大型数据集和长时间训练运行,通常值得付出。对于可以轻松放入内存的较小数据集,直接从 NumPy 数组加载可能更简单和足够。
这部分内容有帮助吗?
tf.train.Example协议缓冲区进行高效数据存储的官方指南。tf.data API,包括TFRecordDataset、数据集转换、性能优化和AUTOTUNE等概念,用于构建稳健的数据输入管道。© 2026 ApX Machine Learning用心打造