如前所述,梯度检查点是一种有用的方法,可以在训练期间减少模型的内存占用。它通过避免存储网络中指定部分的中间激活来实现这一点。相反,这些激活在反向传播计算梯度时需要时,会重新计算。这以增加计算成本来换取内存使用量的减少,通常能训练比在加速器上直接能容纳的更大的模型。JAX 提供了 jax.checkpoint 作为实现梯度检查点功能的主要工具,它也可以写作 jax.remat。梯度检查点的工作原理考虑一个由两个子函数 $f_2 \circ f_1$ 组成的函数 $f$,即 $z = f(x) = f_2(f_1(x))$。令 $y = f_1(x)$。标准自动微分: 在前向传播过程中,$y$ 和 $z$ 都会被计算。$y$ 通常会保留在内存中,因为它在反向传播过程中是计算 $f_2$ 梯度所需。对 $f_1$ 进行检查点设置: 在前向传播过程中,$y = f_1(x)$ 被计算,可能立即被 $f_2$ 使用来计算 $z$,但之后 $y$ 可以被丢弃。在反向传播过程中,当梯度计算到达需要 $y$ 的点时,会使用保存的输入 $x$ 重新执行 $f_1(x)$,以便即时生成 $y$。这种重新计算避免了在整个前向和反向传播过程中存储像 $y$ 这样可能较大的中间张量。使用 jax.checkpointjax.checkpoint 函数是对你想要应用检查点功能的函数的一个包装。其基本用法是传入要进行检查点设置的函数:checkpointed_f1 = jax.checkpoint(f1) # 现在在你的模型中使用 checkpointed_f1 而不是 f1 y = checkpointed_f1(x) z = f2(y)当计算整个计算(涉及 z)的梯度时,JAX 的自动微分系统知道 checkpointed_f1 在反向传播时需要重新计算。示例:为计算块设置检查点我们来定义一个简单的操作序列,它可能代表一个大型神经网络中的一个块。我们将中间维度设大以模拟内存压力。import jax import jax.numpy as jnp import time # 定义一个计算块 def compute_block(x, W1, W2): """一个可能包含大型中间激活的块。""" y = jnp.dot(x, W1) y = jax.nn.gelu(y) # GELU 激活函数 # 'y' 是我们可能希望避免存储的中间激活 z = jnp.dot(y, W2) return z # 使用此块定义一个模拟损失函数 def loss_fn(x, W1, W2, targets): z = compute_block(x, W1, W2) # 简单的均方误差损失 loss = jnp.mean((z - targets)**2) return loss # 没有检查点的梯度函数 grad_fn_standard = jax.jit(jax.value_and_grad(loss_fn, argnums=(1, 2))) # --- 现在,定义检查点版本 --- # 对 compute_block 应用检查点 compute_block_checkpointed = jax.checkpoint(compute_block) # 使用检查点块定义损失函数 def loss_fn_checkpointed(x, W1, W2, targets): # 在这里使用检查点版本 z = compute_block_checkpointed(x, W1, W2) loss = jnp.mean((z - targets)**2) return loss # 包含检查点的梯度函数 grad_fn_checkpointed = jax.jit(jax.value_and_grad(loss_fn_checkpointed, argnums=(1, 2))) # --- 设置数据 --- key = jax.random.PRNGKey(42) batch_size = 64 input_dim = 512 hidden_dim = 8192 # 大型隐藏维度 output_dim = 512 key, x_key, w1_key, w2_key, t_key = jax.random.split(key, 5) x = jax.random.normal(x_key, (batch_size, input_dim)) W1 = jax.random.normal(w1_key, (input_dim, hidden_dim)) * 0.02 W2 = jax.random.normal(w2_key, (hidden_dim, output_dim)) * 0.02 targets = jax.random.normal(t_key, (batch_size, output_dim)) # --- 运行和比较 --- print("运行标准版本(编译 + 执行)...") start_time = time.time() loss_std, (dW1_std, dW2_std) = grad_fn_standard(x, W1, W2, targets) # 确保计算完成才停止计时 loss_std.block_until_ready() dW1_std.block_until_ready() dW2_std.block_until_ready() end_time = time.time() time_std = end_time - start_time print(f"标准损失: {loss_std:.4f}") print(f"标准时间: {time_std:.4f} 秒") print("\n运行检查点版本(编译 + 执行)...") start_time = time.time() loss_ckpt, (dW1_ckpt, dW2_ckpt) = grad_fn_checkpointed(x, W1, W2, targets) # 确保计算完成 loss_ckpt.block_until_ready() dW1_ckpt.block_until_ready() dW2_ckpt.block_until_ready() end_time = time.time() time_ckpt = end_time - start_time print(f"检查点损失: {loss_ckpt:.4f}") print(f"检查点时间: {time_ckpt:.4f} 秒") # 验证梯度是否接近(应该几乎相同) print("\n比较梯度...") print(f"W1 最大绝对差: {jnp.max(jnp.abs(dW1_std - dW1_ckpt)):.2e}") print(f"W2 最大绝对差: {jnp.max(jnp.abs(dW2_std - dW2_ckpt)):.2e}")结果分析内存: 虽然在这个简单脚本中我们无法直接测量峰值内存使用量,但核心区别在于 grad_fn_checkpointed 版本在正向传播期间不需要存储可能非常大的激活 y(大小为 batch_size * hidden_dim)以供反向传播后续使用。它在 W2 和 y 的反向梯度计算期间,使用 compute_block 重新计算了 y。如果 hidden_dim 很大,这种节省可能很可观。计算时间: 你可能会观察到检查点版本的执行时间更长。这是预期的权衡。反向传播现在包含了重新运行 compute_block 前向计算的成本。确切的时间差异很大程度上取决于前向计算与反向计算的相对成本以及所使用的硬件。梯度精度: 计算出的梯度在数值上应该非常接近。微小差异可能由于浮点算术变化而出现,特别是在使用混合精度时,但数学结果是相同的。下图说明了反向传播中的差异:digraph G { rankdir=LR; node [shape=box, style=rounded, fontname="sans-serif", color="#adb5bd", margin=0.1]; edge [fontname="sans-serif", color="#495057", fontsize=10]; subgraph cluster_forward { label = "前向传播(两个版本)"; style=dashed; color="#ced4da"; bgcolor="#f8f9fa"; F_X [label="输入 (X)", shape=ellipse, color="#74c0fc"]; F_L1 [label="点积(X, W1)\n+ GELU"]; F_Y [label="激活 (Y)", shape=ellipse, color="#ffc078"]; F_L2 [label="点积(Y, W2)"]; F_Z [label="输出 (Z)", shape=ellipse, color="#74c0fc"]; F_X -> F_L1; F_L1 -> F_Y; F_Y -> F_L2; F_L2 -> F_Z; } subgraph cluster_backward_std { label = "反向传播(标准)"; style=dashed; color="#ced4da"; bgcolor="#f8f9fa"; B_dZ_std [label="梯度 (dZ)", shape=ellipse, color="#f06595"]; B_dL2_std [label="W2 梯度\n(需要 Y)"]; B_dY_std [label="梯度 (dY)", shape=ellipse, color="#f06595"]; B_dL1_std [label="W1 梯度\n(需要 X)"]; B_dX_std [label="梯度 (dX)", shape=ellipse, color="#f06595"]; B_dZ_std -> B_dL2_std; B_dL2_std -> B_dY_std; B_dY_std -> B_dL1_std; B_dL1_std -> B_dX_std; // 内存读取依赖 F_Y -> B_dL2_std [style=dotted, arrowhead=odot, constraint=false, color="#ff922b", label=" 读取 Y\n (存储内存)"]; } subgraph cluster_backward_ckpt { label = "反向传播(检查点)"; style=dashed; color="#ced4da"; bgcolor="#f8f9fa"; B_dZ_ckpt [label="梯度 (dZ)", shape=ellipse, color="#f06595"]; B_dL2_ckpt [label="W2 梯度\n(需要 Y - 重新计算)"]; B_dY_ckpt [label="梯度 (dY)", shape=ellipse, color="#f06595"]; B_dL1_ckpt [label="W1 梯度\n(需要 X)"]; B_dX_ckpt [label="梯度 (dX)", shape=ellipse, color="#f06595"]; B_Recompute_L1 [label="重新计算 Y:\n点积(X, W1)+GELU", shape=box, style=filled, fillcolor="#b2f2bb"]; B_dZ_ckpt -> B_dL2_ckpt; F_X -> B_Recompute_L1 [style=dotted, arrowhead=odot, constraint=false, color="#1c7ed6", label=" 读取 X"]; B_Recompute_L1 -> B_dL2_ckpt [style=dotted, arrowhead=tee, label=" 使用重新计算的 Y"]; B_dL2_ckpt -> B_dY_ckpt; B_dY_ckpt -> B_dL1_ckpt; B_dL1_ckpt -> B_dX_ckpt; } }该图显示,标准反向传播读取存储的激活 'Y',而检查点反向传播则在计算第二层(点积(Y, W2))梯度之前,使用存储的输入 'X' 重新计算 'Y'。何时使用检查点梯度检查点在以下情况下最有效:中间激活较大: 产生大输出的层(例如,归约前的宽线性层,Transformer 中的自注意力机制)是好的选择。重新计算相对内存节省而言成本较低: 如果检查点部分的向前传播计算成本不高,相比于不存储其输出所节省的内存,这种权衡是有利的。受限于内存: 它主要是一种克服内存限制的工具,允许在固定内存预算内使用更大的模型或更大的批量大小。你可以有选择地将 jax.checkpoint 应用于模型中的特定层或块,这通常需要一些尝试来找到针对你的特定架构和硬件的内存节省与计算开销之间的最佳平衡。像 Flax 这样的框架提供了方便的包装器(例如 flax.linen.remat)来对特定模块应用检查点。这个实践练习展示了 jax.checkpoint 如何提供一种直接管理内存-计算权衡的方法,这是一种在 JAX 中有效训练大型模型的重要方法。