趋近智
使用标准32位浮点精度(FP32)训练大型Transformer模型会消耗大量计算资源并占用大量内存。混合精度训练提供了一个有效的办法,它在较低精度格式(如16位浮点数FP16或BF16)下执行某些操作,同时将主权重 (weight)等主要部分保留在FP32中。这种方式能显著加快计算速度并减少内存占用,通常对最终模型的精度影响很小或没有影响。
现代硬件加速器,特别是配备NVIDIA Tensor Cores等专用单元的GPU,在较低精度(FP16或BF16)下执行矩阵乘法操作时,相比FP32能提供很大的性能提升。以16位精度执行前向和后向传播的部分步骤,直接意味着更快的训练迭代。
此外,与FP32相比,使用16位格式可将存储激活值、梯度和可能的模型权重 (weight)所需的内存减半。这种内存节省使得以下成为可能:
主要思想是发挥较低精度在速度和内存上的优势,用于大部分计算,同时通过策略性地使用FP32来保持数值稳定性。虽然实现方式略有不同,并且通常由深度学习 (deep learning)框架自动处理,但典型过程包含多个组成部分:
现代框架通常使用动态损失缩放,其中缩放因子在训练期间自动调整。如果检测到溢出(梯度变为Inf或NaN),缩放因子会减小。如果梯度在一定步数内保持稳定,缩放因子可能会增加,以更好地使用FP16的动态范围。
使用两种常见的16位格式:
选择通常取决于硬件是否可用。如果两者都支持,BF16可能会提供稍微简单的训练设置,因为它在数值范围上具有鲁棒性,而FP16在其更高精度有利的场景下可能会略微更好,前提是使用了有效的损失缩放。
深度学习 (deep learning)框架提供方便的API,通过最少的代码改动来实现混合精度训练。
PyTorch: 使用torch.cuda.amp(自动混合精度)模块。它提供上下文 (context)管理器(autocast)和梯度缩放工具(GradScaler)。
# 示例草图 (PyTorch)
import torch
from torch.cuda.amp import GradScaler, autocast
scaler = GradScaler()
model = YourTransformerModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=...)
for inputs, targets in dataloader:
inputs, targets = inputs.cuda(), targets.cuda()
optimizer.zero_grad()
# 将上下文管理器中的操作转换为FP16/BF16
with autocast(dtype=torch.float16): # 或者如果支持/需要,使用torch.bfloat16
outputs = model(inputs)
loss = compute_loss(outputs, targets)
# 缩放损失。在缩放后的损失上调用backward()以创建缩放后的梯度。
scaler.scale(loss).backward()
# scaler.step()首先取消优化器分配参数的梯度缩放。
# 如果梯度不是inf/NaN,则调用optimizer.step()。
scaler.step(optimizer)
# 更新下一个迭代的缩放因子。
scaler.update()
TensorFlow: 使用tf.keras.mixed_precision API。您可以设置全局策略或按层应用它。当使用model.fit时,TensorFlow会自动处理损失缩放。
# 示例草图 (TensorFlow)
import tensorflow as tf
# 设置全局策略(例如,'mixed_float16')
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)
# 像往常一样构建模型
inputs = tf.keras.Input(...)
# ... 定义Transformer层 ...
outputs = tf.keras.layers.Dense(vocab_size, activation='softmax', dtype='float32')(x) # 输出层通常保留在FP32中
model = tf.keras.Model(inputs=inputs, outputs=outputs)
optimizer = tf.keras.optimizers.AdamW(...)
# 当使用混合策略时,model.fit会自动处理损失缩放
model.compile(optimizer=optimizer, loss='...', metrics=[...])
model.fit(dataset, epochs=...)
虽然混合精度训练非常有效,但建议监控训练稳定性,并偶尔将最终模型性能与基线FP32运行进行比较,尤其是在首次将其应用于新架构或任务时。某些数值操作,例如大范围约简或需要高精度的计算,有时可能会因框架启发式算法而从自动转换中排除,或可能需要手动配置以保留在FP32中。
示意性比较,显示混合精度训练可能带来的速度提升(例如,快1.8倍)和内存节省(例如,减少45%)。实际收益取决于模型、硬件和具体的实现方式。
混合精度训练已成为深度学习从业者工具箱中的一项标准技术,特别是对于Transformer等资源密集型模型。通过智能地结合较低精度计算和保持数值稳定性的机制,它使得训练迭代更快,并且在现有硬件限制下使用更大、能力更强的模型成为可能。
这部分内容有帮助吗?
torch.cuda.amp模块实现混合精度训练的实际示例和详细信息。© 2026 ApX Machine LearningAI伦理与透明度•