趋近智
tf.distribute.Strategy 概述虽然 model.fit() 为大多数标准训练场景提供了便捷的抽象,但常常会出现需要对训练过程进行更精细控制的情况。实现自定义训练循环可以赋予您这种控制能力,允许进行非标准梯度更新、复杂的指标计算,或与Keras回调不易处理的外部系统集成。这里将详细说明如何使用TensorFlow的核心组件构建这些循环。
当您需要时,可以选择自定义训练循环:
自定义训练循环主要协调这些组件之间的交互:
tf.keras.Model 实例(通过Sequential、函数式API或子类化创建)。tf.data.Dataset,提供输入特征和目标标签的批次数据。tf.keras.losses.SparseCategoricalCrossentropy),根据模型预测和真实标签计算损失值。tf.keras.optimizers.Optimizer 的一个实例(例如 tf.keras.optimizers.Adam),负责将梯度应用于模型的可训练变量。tf.GradientTape: 自动微分的引擎。它会记录在其上下文内执行的操作,允许您计算目标(通常是损失)相对于源变量(通常是模型的可训练变量)的梯度。tf.keras.metrics.Metric 实例(例如 tf.keras.metrics.Accuracy),用于在训练和评估期间跟踪性能。其基本结构涉及嵌套循环:一个用于迭代周期(epochs)的外部循环,以及一个用于每个迭代周期内批次(batches)的内部循环。
以下是内部(批次)循环中典型步骤的分解:
x_batch,y_batch)。tf.GradientTape 上下文。TensorFlow将监控在此块内访问的涉及可训练 tf.Variable 对象的运算。
with tf.GradientTape() as tape:
# 此处记录的操作
Dropout 或 BatchNormalization 等在训练和推断期间行为不同的层,请确保传递 training=True。
y_pred = model(x_batch, training=True)
loss_value = loss_fn(y_batch, y_pred)
# 如果损失函数涉及层添加的正则化项,
# 您可能需要添加 model.losses
loss_value += sum(model.losses)
grads = tape.gradient(loss_value, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
train_accuracy_metric.update_state(y_batch, y_pred)
train_loss_metric.update_state(loss_value)
外部(迭代周期)循环处理诸如遍历数据集以完成一次完整传递、在每个迭代周期开始时重置指标、记录结果以及可能运行一个验证循环等任务。
一个自定义训练循环的典型流程图,显示了迭代周期和批次迭代以及梯度计算和应用。
让我们用一个简单的例子来说明。假设您有一个编译好的Keras模型(model)、一个优化器(optimizer)、一个损失函数(loss_fn)、训练数据(train_dataset)以及指标(train_loss_metric,train_accuracy_metric)。
import tensorflow as tf
# 假设 model, optimizer, loss_fn, train_dataset 已定义
# 假设 train_loss_metric = tf.keras.metrics.Mean(name='train_loss')
# 假设 train_accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
epochs = 5
# 定义用于提升性能的训练步骤函数
@tf.function
def train_step(x_batch, y_batch):
with tf.GradientTape() as tape:
predictions = model(x_batch, training=True)
loss = loss_fn(y_batch, predictions)
# 如果存在正则化损失,请添加
loss += sum(model.losses)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
# 更新指标
train_loss_metric.update_state(loss)
train_accuracy_metric.update_state(y_batch, predictions)
# 训练循环
for epoch in range(epochs):
print(f"\nStart of epoch {epoch+1}")
# 在每个迭代周期开始时重置指标
train_loss_metric.reset_state()
train_accuracy_metric.reset_state()
# 遍历数据集的批次
for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
train_step(x_batch_train, y_batch_train)
# 每N个批次记录一次(可选)
if step % 100 == 0:
print(f"Step {step}: Loss: {train_loss_metric.result():.4f}, Accuracy: {train_accuracy_metric.result():.4f}")
# 在每个迭代周期结束时显示指标
train_loss = train_loss_metric.result()
train_acc = train_accuracy_metric.result()
print(f"Epoch {epoch+1}: Training Loss: {train_loss:.4f}, Training Accuracy: {train_acc:.4f}")
# 可选:在此处运行一个验证循环,结构类似
# 但无需 GradientTape 和梯度应用。请记住调用
# model(x_val_batch, training=False)。
tf.function 提升性能请注意示例中应用于 train_step 函数的 @tf.function 装饰器。这对于性能很重要。TensorFlow会分析装饰函数内的Python代码,并生成一个优化的计算图。train_step 的后续调用会直接执行此图,从而在大多数操作中绕过较慢的Python解释器。
使用 @tf.function 时,请注意以下几点:
tf.Variable。tf.cond,tf.while_loop),而不是Python的 if/for/while,以确保它们是图的一部分。许多层,尤其是 tf.keras.layers.BatchNormalization 和 tf.keras.layers.Dropout,在训练和推断期间行为不同。批归一化在训练期间更新其移动均值和方差,但在推断期间使用它们进行归一化。Dropout在训练期间随机将激活设为零,但在推断期间不活跃。
在调用模型时正确传入 training 参数非常重要:
model(inputs, training=True):在训练步骤期间的 GradientTape 上下文内。model(inputs, training=False):在执行验证或训练后进行预测时。忘记这一点可能导致结果不正确或模型无法正常训练。
model.fit()在 model.fit() 和自定义循环之间进行选择涉及便利性与控制权之间的权衡:
model.fit():
掌握自定义训练循环有助于您实现几乎任何训练算法,突破标准工作流,从而精确地定制TensorFlow以满足您的高级建模要求。
这部分内容有帮助吗?
tf.GradientTape用于自动微分的原理和用法,这是在任何自定义训练循环中计算梯度的基础机制。@tf.function优化TensorFlow代码以提升性能的见解,这是实现自定义训练步骤时的一个重要考量。© 2026 ApX Machine Learning用心打造