趋近智
训练大型神经网络 (neural network)常需的批处理大小会超出单个加速器(GPU 或 TPU)的可用内存。在多个设备间分发计算(例如使用 pmap)有助于缓解此问题,但有时即使是每个设备所需的批处理大小,为了稳定训练或获得最佳性能,也可能超出设备内存。梯度累积为应对此问题提供了一个直接方法。
其主要思路是通过顺序处理多个较小的批次来模拟大批次,累加它们的梯度,然后利用汇总的梯度信息执行一次优化器步骤。这有效地将用于权重 (weight)更新的批处理大小与在任意时刻必须适应内存的批处理大小分离开来。
假设您希望使用有效批处理大小 进行训练,但您的加速器只能处理更小的批处理大小,我们称之为微批处理大小 ,其中 。梯度累积通过执行以下步骤实现此目标:
此过程计算梯度:
这个平均梯度近似于使用完整有效批处理大小 计算出的梯度。主要优点是在梯度计算阶段,加速器内存中一次只需要保留一个微批次。
在 JAX 中实现梯度累积通常涉及修改训练步骤函数。我们不再一次性计算梯度并应用更新,而是分离这些步骤并引入一个循环。
让我们考虑一个典型的 JAX 训练步骤函数,它接收模型状态(参数 (parameter)、优化器状态)和一批数据,计算损失和梯度,然后返回更新后的状态和指标。
import jax
import jax.numpy as jnp
import optax # 示例优化器库
# 假设 'model' 和 'loss_fn' 已在其他地方定义
# 'params' 是模型参数
# 'opt_state' 是优化器状态
@jax.jit
def train_step(params, opt_state, batch):
"""执行不带梯度累积的单个训练步骤。"""
def compute_loss(p):
logits = model.apply({'params': p}, batch['image'])
loss = loss_fn(logits, batch['label'])
return loss
grad_fn = jax.value_and_grad(compute_loss)
loss, grads = grad_fn(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
metrics = {'loss': loss}
return params, opt_state, metrics
为了纳入梯度累积,我们需要管理累积的梯度并循环处理微批次。
# 假设 accumulation_steps = N (大于 1 的整数)
@jax.jit
def micro_batch_step(params, micro_batch):
"""计算单个微批次的梯度。"""
def compute_loss(p):
logits = model.apply({'params': p}, micro_batch['image'])
loss = loss_fn(logits, micro_batch['label'])
return loss, logits # 也返回 logits 用于可能的指标计算
# 使用 value_and_grad 获取损失和梯度
(loss, _), grads = jax.value_and_grad(compute_loss, has_aux=True)(params)
# 注意:如果跨步骤求平均,此处返回损失可能比较复杂。
# 通常,指标是根据完整的有效批次计算的。
return grads, loss
# 此函数现在将处理累积循环和优化器更新
# 它通常不会完全 JIT 编译,因为循环处理数据加载。
# 但是,micro_batch_step 是 JIT 编译的。
def accumulated_train_step(params, opt_state, data_iterator, accumulation_steps):
"""执行带梯度累积的训练步骤。"""
# 1. 初始化累积梯度(形状与参数相同,均为零)
accumulated_grads = jax.tree_util.tree_map(jnp.zeros_like, params)
total_loss = 0.0
# 2. 微批次循环
for _ in range(accumulation_steps):
micro_batch = next(data_iterator) # 获取下一个微批次
# 计算此微批次的梯度(使用 JIT 编译的函数)
grads, loss = micro_batch_step(params, micro_batch)
# 累积梯度
accumulated_grads = jax.tree_util.tree_map(lambda acc, g: acc + g, accumulated_grads, grads)
total_loss += loss
# 3. 参数更新
# 平均梯度
averaged_grads = jax.tree_util.tree_map(lambda g: g / accumulation_steps, accumulated_grads)
average_loss = total_loss / accumulation_steps
# 应用优化器更新
# 此部分如果需要,通常可以单独进行 JIT 编译
@jax.jit
def apply_update(p, o_state, avg_grads):
updates, new_o_state = optimizer.update(avg_grads, o_state, p)
new_p = optax.apply_updates(p, updates)
return new_p, new_o_state
params, opt_state = apply_update(params, opt_state, averaged_grads)
metrics = {'loss': average_loss} # 报告微批次的平均损失
return params, opt_state, metrics
实际上,将其组织在更大的训练循环中涉及创建一个生成微批次的数据迭代器,并调用 accumulated_train_step。
使用 jax.lax.scan 的更函数式的方法可以将累积循环封装在一个 JIT 编译的函数中,但这需要仔细的状态管理和适当的数据加载结构。为了清晰起见,上面显示的显式循环通常更直接地阐明了此原理。
梯度累积过程包括初始化梯度,循环处理微批次以计算并累加梯度,对结果求平均,应用优化器更新,然后重置以进行下一个周期。
pmap 的交互梯度累积与使用 pmap 进行数据并行处理能很好地结合。使用 pmap 时,每个设备处理其自己的微批次一部分。梯度累积循环在每个设备上独立进行。
lax.pmean 这样的集体操作来对所有设备上的累积梯度进行平均。这确保了所有设备都基于来自整个有效批次()的梯度信息进行更新计算。这表示每个设备仍然只需要为其单个微批次的份额 占用内存,而用于梯度更新计算的有效批处理大小是 。
pmap 和可能与 Flax 或 Haiku 等框架的集成,需要仔细处理状态和数据流。梯度累积是当面临内存限制时,用于扩展模型规模的重要技术。它常与其他方法结合使用,如梯度检查点和混合精度,以训练先进模型。
这部分内容有帮助吗?
© 2026 ApX Machine LearningAI伦理与透明度•