趋近智
如前所述,梯度检查点是一种有用的方法,可以在训练期间减少模型的内存占用。它通过避免存储网络中指定部分的中间激活来实现这一点。相反,这些激活在反向传播计算梯度时需要时,会重新计算。这以增加计算成本来换取内存使用量的减少,通常能训练比在加速器上直接能容纳的更大的模型。
JAX 提供了 jax.checkpoint 作为实现梯度检查点功能的主要工具,它也可以写作 jax.remat。
考虑一个由两个子函数 f2∘f1 组成的函数 f,即 z=f(x)=f2(f1(x))。令 y=f1(x)。
这种重新计算避免了在整个前向和反向传播过程中存储像 y 这样可能较大的中间张量。
jax.checkpointjax.checkpoint 函数是对你想要应用检查点功能的函数的一个包装。其基本用法是传入要进行检查点设置的函数:
checkpointed_f1 = jax.checkpoint(f1)
# 现在在你的模型中使用 checkpointed_f1 而不是 f1
y = checkpointed_f1(x)
z = f2(y)
当计算整个计算(涉及 z)的梯度时,JAX 的自动微分系统知道 checkpointed_f1 在反向传播时需要重新计算。
我们来定义一个简单的操作序列,它可能代表一个大型神经网络中的一个块。我们将中间维度设大以模拟内存压力。
import jax
import jax.numpy as jnp
import time
# 定义一个计算块
def compute_block(x, W1, W2):
"""一个可能包含大型中间激活的块。"""
y = jnp.dot(x, W1)
y = jax.nn.gelu(y) # GELU 激活函数
# 'y' 是我们可能希望避免存储的中间激活
z = jnp.dot(y, W2)
return z
# 使用此块定义一个模拟损失函数
def loss_fn(x, W1, W2, targets):
z = compute_block(x, W1, W2)
# 简单的均方误差损失
loss = jnp.mean((z - targets)**2)
return loss
# 没有检查点的梯度函数
grad_fn_standard = jax.jit(jax.value_and_grad(loss_fn, argnums=(1, 2)))
# --- 现在,定义检查点版本 ---
# 对 compute_block 应用检查点
compute_block_checkpointed = jax.checkpoint(compute_block)
# 使用检查点块定义损失函数
def loss_fn_checkpointed(x, W1, W2, targets):
# 在这里使用检查点版本
z = compute_block_checkpointed(x, W1, W2)
loss = jnp.mean((z - targets)**2)
return loss
# 包含检查点的梯度函数
grad_fn_checkpointed = jax.jit(jax.value_and_grad(loss_fn_checkpointed, argnums=(1, 2)))
# --- 设置数据 ---
key = jax.random.PRNGKey(42)
batch_size = 64
input_dim = 512
hidden_dim = 8192 # 大型隐藏维度
output_dim = 512
key, x_key, w1_key, w2_key, t_key = jax.random.split(key, 5)
x = jax.random.normal(x_key, (batch_size, input_dim))
W1 = jax.random.normal(w1_key, (input_dim, hidden_dim)) * 0.02
W2 = jax.random.normal(w2_key, (hidden_dim, output_dim)) * 0.02
targets = jax.random.normal(t_key, (batch_size, output_dim))
# --- 运行和比较 ---
print("运行标准版本(编译 + 执行)...")
start_time = time.time()
loss_std, (dW1_std, dW2_std) = grad_fn_standard(x, W1, W2, targets)
# 确保计算完成才停止计时
loss_std.block_until_ready()
dW1_std.block_until_ready()
dW2_std.block_until_ready()
end_time = time.time()
time_std = end_time - start_time
print(f"标准损失: {loss_std:.4f}")
print(f"标准时间: {time_std:.4f} 秒")
print("\n运行检查点版本(编译 + 执行)...")
start_time = time.time()
loss_ckpt, (dW1_ckpt, dW2_ckpt) = grad_fn_checkpointed(x, W1, W2, targets)
# 确保计算完成
loss_ckpt.block_until_ready()
dW1_ckpt.block_until_ready()
dW2_ckpt.block_until_ready()
end_time = time.time()
time_ckpt = end_time - start_time
print(f"检查点损失: {loss_ckpt:.4f}")
print(f"检查点时间: {time_ckpt:.4f} 秒")
# 验证梯度是否接近(应该几乎相同)
print("\n比较梯度...")
print(f"W1 最大绝对差: {jnp.max(jnp.abs(dW1_std - dW1_ckpt)):.2e}")
print(f"W2 最大绝对差: {jnp.max(jnp.abs(dW2_std - dW2_ckpt)):.2e}")
grad_fn_checkpointed 版本在正向传播期间不需要存储可能非常大的激活 y(大小为 batch_size * hidden_dim)以供反向传播后续使用。它在 W2 和 y 的反向梯度计算期间,使用 compute_block 重新计算了 y。如果 hidden_dim 很大,这种节省可能很可观。compute_block 前向计算的成本。确切的时间差异很大程度上取决于前向计算与反向计算的相对成本以及所使用的硬件。下图说明了反向传播中的差异:
该图显示,标准反向传播读取存储的激活 'Y',而检查点反向传播则在计算第二层(点积(Y, W2))梯度之前,使用存储的输入 'X' 重新计算 'Y'。
梯度检查点在以下情况下最有效:
你可以有选择地将 jax.checkpoint 应用于模型中的特定层或块,这通常需要一些尝试来找到针对你的特定架构和硬件的内存节省与计算开销之间的最佳平衡。像 Flax 这样的框架提供了方便的包装器(例如 flax.linen.remat)来对特定模块应用检查点。
这个实践练习展示了 jax.checkpoint 如何提供一种直接管理内存-计算权衡的方法,这是一种在 JAX 中有效训练大型模型的重要方法。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造