趋近智
训练大型机器学习模型时,经常会遇到加速器内存和计算时间的限制。混合精度训练是一种广泛使用的技术,它通过巧妙地在部分计算中使用低精度浮点数(如16位浮点数),大幅缓解了这些限制。这种方法几乎能将激活和梯度的内存消耗减半,并在兼容硬件上显著加速训练,同时不大会损害模型精度。
混合精度训练的原理是,大部分计算,特别是前向和后向传播中耗时的矩阵乘法和卷积,使用16位浮点数执行。同时,为了确保数值稳定性,模型权重的原始副本以及可能的优化器状态等一些必要组件,会保持在标准32位精度(float32)。
JAX 通过 NumPy 支持两种主要的16位浮点格式:
jnp.float16 (半精度): 这种格式符合IEEE 754的16位浮点数标准。它使用1位表示符号,5位表示指数,10位表示分数(尾数)。
float16优化的硬件(例如NVIDIA Tensor Cores)上提供显著加速。与float32相比,内存使用量减半。float32相比,动态范围非常有限。这意味着它容易受到数值下溢(梯度变为零)和上溢(梯度变为无穷或NaN)的影响。需要仔细实现,通常涉及损失缩放。jnp.bfloat16 (Brain浮点): 由Google Brain开发,这种格式使用1位表示符号,8位表示指数(与float32相同),7位表示分数。
float32相同的动态范围,使得它极少出现上溢和下溢问题。通常比float16更易于使用,因为它通常不需要损失缩放。提供类似的内存节省,并能在兼容硬件(尤其是TPU和较新的GPU)上提供加速。float16(分数位数更少)。虽然对于深度学习通常足够,但这可能会影响对高度敏感模型的收敛性。以下是一个比较:
| 特性 | float32 (单精度) |
float16 (半精度) |
bfloat16 (Brain) |
|---|---|---|---|
| 总位数 | 32 | 16 | 16 |
| 指数位数 | 8 | 5 | 8 |
| 分数位数 | 23 | 10 | 7 |
| 动态范围 | 宽 | 窄 | 宽 (与float32相似) |
| 精度 | 高 | 中 | 低 |
| 需要损失缩放? | 否 | 通常需要 | 通常不需要 |
考虑到其更宽的动态范围和更简单的易用性,当硬件支持时(在TPU和较新的NVIDIA GPU如Ampere及后续型号上很常见),bfloat16通常是混合精度的首选。如果只有float16得到高效支持,则需要结合损失缩放进行仔细实现。
标准策略包括将模型参数的原始副本保持为float32,同时使用float16或bfloat16执行大多数计算。基于JAX构建的高级神经网络库,如Flax或Haiku,通常提供方便的抽象来管理这一点。
例如,Flax 允许您为参数(param_dtype)和计算(dtype)指定不同的数据类型。bfloat16混合精度的常见设置包括:
param_dtype=jnp.float32 初始化参数。apply 方法,将输入转换为 bfloat16,并为中间计算指定 dtype=jnp.bfloat16。import jax
import jax.numpy as jnp
import flax.linen as nn
# 假设模型使用Flax定义
class SimpleDense(nn.Module):
features: int
param_dtype: jnp.dtype = jnp.float32 # 主权重使用float32
dtype: jnp.dtype = jnp.bfloat16 # 计算使用bfloat16
@nn.compact
def __call__(self, x):
# 输入x预期为bfloat16或将被转换
x = x.astype(self.dtype)
# kernel将是float32,但矩阵乘法会提升为计算数据类型(bfloat16)
# 结果将是bfloat16
y = nn.Dense(features=self.features,
param_dtype=self.param_dtype,
dtype=self.dtype, # 如果需要,显式设置计算数据类型
name='dense_layer')(x)
# 后续操作在bfloat16中执行
return nn.relu(y)
# --- 初始化 ---
key = jax.random.PRNGKey(0)
input_shape = (1, 10)
dummy_input = jnp.zeros(input_shape, dtype=jnp.bfloat16) # 输入数据类型
model = SimpleDense(features=5)
# 初始化float32参数
params = model.init(key, dummy_input)['params']
# --- 前向传播 ---
# 在调用apply之前,输入应转换为计算数据类型
output = model.apply({'params': params}, dummy_input)
print(f"Input dtype: {dummy_input.dtype}")
print(f"Parameter dtype (example): {jax.tree.leaves(params)[0].dtype}")
print(f"Output dtype: {output.dtype}")
# 预期输出:
# Input dtype: bfloat16
# Parameter dtype (example): float32
# Output dtype: bfloat16
一个例子,说明了Flax模块如何处理
bfloat16混合精度的不同参数和计算数据类型。
该库处理了转换输入和确保计算使用指定dtype的细节,同时参数保持在float32。计算出的梯度通常将是计算dtype(在本例中为bfloat16)。优化器随后使用这些可能较低精度的梯度来更新float32主参数。
如果不使用高级库,您需要手动管理类型转换:
# 不使用库的示例
def predict(params_f32, inputs_bf16):
# 假设params_f32是float32参数的pytree
# 假设inputs_bf16已经是bfloat16
activations = inputs_bf16
for W_f32, b_f32 in params_f32: # 遍历层
# 将权重转换为bfloat16进行计算
W_bf16 = W_f32.astype(jnp.bfloat16)
b_bf16 = b_f32.astype(jnp.bfloat16)
# 在bfloat16中执行计算
outputs = jnp.dot(activations, W_bf16) + b_bf16
activations = jax.nn.relu(outputs)
return activations # 输出是bfloat16
# 在梯度计算和更新期间:
# 1. 计算梯度(可能为bfloat16)
# 2. 优化器使用这些梯度更新float32主参数
这需要更细致的处理,以确保在整个模型和训练循环中正确管理类型。
float16的损失缩放使用jnp.float16时,其有限的动态范围经常导致幅度较小的梯度变为零(下溢)。为了防止这种情况发生:
float16的可表示范围。float32主权重之前,将计算出的梯度除以相同的缩放因子 S。
梯度=S∇θ缩放后的损失float32主权重。缩放因子 S 可以静态选择或动态调整。动态损失缩放涉及从一个较大的 S 开始,如果在训练期间检测到溢出(NaN或Inf梯度),则减小它,同时,如果梯度在一段时间内保持稳定,则可能增加它。这增加了复杂性,但有助于找到最佳缩放。Flax和Optax等库通常提供用于管理损失缩放的实用工具。
bfloat16通常可靠。float16需要仔细实现和损失缩放。确保关键操作,例如归一化层中的方差计算或最终损失计算,必要时仍保持在float32中。float32训练相当的精度,但可能会出现微小差异。验证最终模型的性能是一个好习惯。float16)是常见的调试步骤。混合精度训练是构建大规模模型的重要工具。通过明智地结合float32以保持稳定性,并使用bfloat16或float16以节省内存和提高速度,您可以使用JAX及其生态系统更高效地训练更大、能力更强的模型。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造