趋近智
训练大型神经网络常遇到内存限制,尤其是在GPU和TPU等加速器上。在用于计算梯度的标准反向传播过程中,前向传播的中间激活必须被存储。对于非常深或宽的模型,存储这些激活所需的内存可能超出可用设备内存,从而完全停止训练过程。
梯度检查点,也称激活检查点或再物化,是一种专门为缓解此内存瓶颈而设计的技术。它的主要设想简单而巧妙:并非存储前向传播的所有中间激活,而是有策略地仅保留一部分。然后,在反向传播过程中,每当计算梯度需要一个未存储的激活时,我们从最近的已存储激活开始,即时重新计算它。
这带来了直接的权衡:
将深度网络想象成一系列层或计算块。
jax.checkpoint(或 jax.remat)JAX 提供了一个便捷的转换,jax.checkpoint(它是名称更具描述性的 jax.remat 的别名,remat 是 re-materialization 的缩写),用于实现梯度检查点。您可以将其作为装饰器应用于函数,或封装计算的特定部分。
import jax
import jax.numpy as jnp
# 定义一个可能计算量很大的计算块
def compute_intensive_block(x, params):
# 代表多层或复杂操作
x = jnp.dot(x, params['w1']) + params['b1']
x = jax.nn.relu(x)
x = jnp.dot(x, params['w2']) + params['b2']
return x
# 对此块应用检查点
checkpointed_block = jax.checkpoint(compute_intensive_block)
# 在更大模型上下文中的示例用法(简化版)
def model(x, all_params):
# ... 初始层 ...
intermediate_output = x # 前面层的输出
# 应用检查点块
# compute_intensive_block *内部*的激活将不会被存储
# (除非它们是该块的最终输出)
x = checkpointed_block(intermediate_output, all_params['block_params'])
# ... 后续层 ...
final_output = x # 示例最终层
return final_output
# 然后你可以像往常一样对 'model' 函数进行求导
grad_fn = jax.grad(lambda p, data: jnp.sum(model(data, p)))
# 虚拟数据和参数
key = jax.random.PRNGKey(0)
dummy_x = jnp.ones((1, 128))
dummy_params = {
'block_params': {
'w1': jax.random.normal(key, (128, 512)),
'b1': jnp.zeros(512),
'w2': jax.random.normal(key, (512, 128)),
'b2': jnp.zeros(128)
}
# ... 其他参数 ...
}
# 计算梯度 - 检查点在 grad_fn 内部处于活动状态
gradients = grad_fn(dummy_params, dummy_x)
print("梯度计算成功。")
当 jax.grad 应用于包含 jax.checkpoint 的函数时,JAX 的自动微分机制会理解检查点函数内的中间结果在反向传播期间不可用,需要重新计算。它智能地管理这个再物化过程。
有效应用 jax.checkpoint 涉及一些策略性决定:
jit 的配合: jax.checkpoint 与 jax.jit 顺畅集成。JAX 将高效地编译原始函数和重新计算逻辑。梯度检查点允许您用计算时间换取更低的内存使用。这对于训练那些否则无法适应可用硬件的模型通常是必不可少的。
使用梯度检查点时的示意性权衡。实际百分比因模型架构和检查点策略而异,但内存使用通常会大幅减少,而计算时间则适度增加。
您应在以下情况考虑使用 jax.checkpoint:
尽管梯度检查点增加了计算开销,但它是大规模训练工具箱中的一项有力技术,使得在内存受限下原本无法训练的先进模型变得可能。它能与分布式训练(如 pmap)和混合精度等其他技术有效结合,进一步拓展模型规模的上限。
这部分内容有帮助吗?
jax.checkpoint(别名为jax.remat)转换的官方文档,提供了在JAX中进行内存优化的实际使用细节和示例。© 2026 ApX Machine Learning用心打造