基础模型上的元学习通常占用大量内存,主要原因有两点:在反向传播计算梯度时需要存储中间激活值,以及需要为可能数十亿个参数维护优化器状态。当处理元梯度,特别是MAML等方法所需的二阶导数时,这些需求会成倍增加,迅速超出可用硬件内存。幸运的是,有多种优化技术可以显著减轻这些内存压力,使大规模元学习变得可行。梯度检查点(激活值重计算)标准的反向传播要求在正向传播过程中计算的所有中间激活值都存储起来,以便在反向传播过程中计算梯度。对于像基础模型这样的深度网络,尤其是在元学习的展开计算图(跨越多个内循环步骤)中,这些激活值所需的内存会成为一个主要瓶颈。梯度检查点,也称为激活值重计算,提供了一个直接的权衡:减少内存消耗,但会增加计算时间。检查点技术不是存储所有激活值,而是策略性地选择某些激活值进行保存(即“检查点”),而丢弃其他激活值。在反向传播过程中,当计算梯度需要某个已丢弃的激活值时,模型中从上一个检查点到所需激活值之间的部分会即时重新计算。工作原理: 假设正向传播是一系列函数 $f_1, f_2, ..., f_n$。如果没有检查点,所有中间输出 $x_1 = f_1(x_0), x_2 = f_2(x_1), ..., x_n = f_n(x_{n-1})$ 都会被存储。使用检查点时,我们可能只存储 $x_0$ 和 $x_{n/2}$。为了计算 $f_{n/2+1}$ 处的梯度,我们从 $x_{n/2}$ 重新运行正向传播来计算 $x_{n/2+1}$,将其用于梯度计算,然后再次丢弃。digraph G {rankdir=LR;node[shape=box,style=rounded,fontname="sans-serif",fontsize=10,fillcolor="#dee2e6",color="#495057"];edge[color="#868e96"];a0[label="输入"];a1[label="层 1 激活值"];a2[label="层 2 激活值"];a3[label="层 3 激活值"];a4[label="输出"];b4[label="输出梯度"];b3[label="层 3 梯度"];b2[label="层 2 梯度"];b1[label="层 1 梯度"];b0[label="输入梯度"];a0->a1->a2->a3->a4;a4->b4[style=dashed];b4->b3[style=dashed];b3->b2[style=dashed];b2->b1[style=dashed];b1->b0[style=dashed];c0[label="输入",fillcolor="#96f2d7"];c1[label="层 1 激活值",fillcolor="#96f2d7"];c2[label="层 2 激活值",fillcolor="#96f2d7"];c3[label="层 3 激活值",fillcolor="#96f2d7"];c4[label="输出",fillcolor="#96f2d7"];d4[label="输出梯度"];d3[label="层 3 梯度"];d2[label="层 2 梯度"];d1[label="层 1 梯度"];d0[label="输入梯度"];c0->c1->c2->c3->c4;c4->d4[style=dashed];d4->d3[style=dashed,color="#fa5252"];d3->d2[style=dashed];d2->d1[style=dashed,color="#fa5252"];d1->d0[style=dashed];label="反向传播与检查点对比";fontsize=12;fontname="sans-serif";} 标准反向传播内存使用与梯度检查点的比较。检查点通过在反向传播过程中重新计算(红色虚线)来避免存储中间激活值(如c1、c3)。与元学习的关联: 梯度检查点可以应用于内循环更新(任务适应)和/或外循环元更新。它对于计算图包含内部优化过程的MAML及其变体尤其有效。通过对内循环的片段或基础模型内的层进行检查点,计算可能复杂的元梯度(如 $\nabla_{\theta} \mathcal{L}_{\text{meta}}$)时的峰值内存使用量可以显著减少。这种权衡是训练时间大约增加20-30%,具体取决于检查点策略和模型架构,但它通常使得原本因内存限制而无法实现的训练配置成为可能。混合精度训练另一种高效技术是混合精度训练。混合精度训练不是将所有计算和存储的所有值(权重、激活值、梯度)都使用标准32位浮点精度(FP32),而是对许多操作使用诸如16位浮点(FP16)或Brain浮点(BF16)等低精度格式。优点:内存减少: 存储FP16或BF16值所需的内存是FP32的一半。这适用于反向传播中存储的激活值、反向传播过程中计算的梯度,以及可能的模型权重本身。优化器状态也受益于此,我们将在后面讨论。计算加速: 现代硬件(如带有Tensor Cores的NVIDIA GPU)为FP16或BF16执行的矩阵乘法和卷积提供了大幅加速。挑战与解决办法: 与FP32相比,FP16的动态范围有限,这使得它在训练过程中,尤其是在大型模型中,容易出现数值下溢(梯度变为零)或上溢(梯度变为无穷大)。BF16提供了更宽的动态范围(与FP32相似),但精度低于FP16,通常为训练大型Transformer模型提供更好的平衡。对抗FP16中下溢/上溢的主要技术是损失缩放。在反向传播开始之前,损失值会乘以一个缩放因子。这会将梯度放大,使其进入FP16的可表示范围。在优化器更新权重之前,梯度会按相同的因子缩放回来。这个缩放因子可以是固定的(静态损失缩放),也可以在训练过程中动态调整(动态损失缩放),以找到一个既能避免上溢又能最小化下溢的理想值。与元学习的关联: 混合精度可以应用于整个元学习过程。内循环适应步骤和外循环元更新都可以采用低精度计算和存储。激活值: 将激活值存储为FP16/BF16会直接将其内存成本减半。梯度: 以较低精度计算和存储梯度可以减少内存。元梯度: 可能需要特别注意。虽然BF16通常可靠,但带有动态损失缩放的FP16对于确保元梯度(其具有复杂的依赖性和变化的量级)的准确计算是必要的。由于BF16在训练过程中具有更好的稳定性,因此在大型语言模型中通常优先使用它,尽管具有有效损失缩放的FP16也可以工作,并且在某些硬件上可能会提供略快的计算速度。高效优化器像Adam或AdamW这样的标准优化器,除了模型参数之外,还会维护额外的状态。例如,Adam会为每个参数存储一阶矩(动量)和二阶矩(方差)的估计值。对于一个参数数量为 $N$ 且以FP32存储的模型,优化器状态通常需要额外 $2 \times N \times 4$ 字节的内存,这实际上使得仅参数及其优化状态所需的内存增加了两倍。当 $N$ 达到数十亿时,这会成为巨大的内存负担,特别是对于作用于整个基础模型参数 $\theta$ 的元优化器而言。内存高效的优化器可以减轻这一负担:Adafactor: Adafactor 最初是为大型NLP模型提出的,它避免为每个参数存储完整的二阶矩估计。相反,它只为权重矩阵维护平方梯度的行和列和,从而有效地存储了二阶矩张量的因式分解表示。这使得大型层中二阶矩状态的内存占用从 $O(N)$ 减少到 $O(\sqrt{N})$,提供了大幅节省,同时没有显著损害性能。它通常也不存储一阶矩,而是依赖于动量衰减因子。8位优化器: bitsandbytes 等库实现了将优化器状态(动量和方差)量化为8位整数的优化器(例如8位Adam)。它们不是为每个状态条目存储32位浮点数,而是存储一个8位整数加上在执行参数更新前对状态进行反量化所需的量化统计数据(如块级缩放因子)。这可以将优化器状态内存减少约75%(例如,从FP32状态的每个参数8字节减少到约2字节/参数)。其他方法(例如Sophia、Lion): 研究持续推出目标是实现类似Adam性能但内存或计算开销更低的优化器,尽管它们在元学习大型模型中的广泛应用和稳定性仍在发展中。与元学习的关联: 虽然内循环优化器可能作用于较小的参数集(例如,如果在适应过程中使用LoRA等PEFT方法),但更新基础模型参数 $\theta$ 的元优化器作用于完整的参数集。将Adafactor或8位Adam应用于此外循环优化步骤可以大幅减少优化器状态所需的内存,从而释放重要资源。{"layout": {"title": "每参数相对内存使用构成", "xaxis": {"title": "训练设置"}, "yaxis": {"title": "相对内存单位", "range": [0, 6]}, "barmode": "stack", "legend": {"traceorder": "reversed"}, "font": {"family": "sans-serif"}}, "data": [{"type": "bar", "name": "优化器状态 (FP32)", "x": ["基线 (FP32)", "混合精度 (FP16)", "高效优化器 (8位)", "组合"], "y": [2, 1, 0.5, 0.5], "marker": {"color": "#ffc9c9"}}, {"type": "bar", "name": "梯度", "x": ["基线 (FP32)", "混合精度 (FP16)", "高效优化器 (8位)", "组合"], "y": [1, 0.5, 1, 0.5], "marker": {"color": "#a5d8ff"}}, {"type": "bar", "name": "激活值 (峰值)", "x": ["基线 (FP32)", "混合精度 (FP16)", "高效优化器 (8位)", "组合"], "y": [2, 1, 2, 0.5], "marker": {"color": "#96f2d7"}, "text": ["(完整)", "(一半)", "(完整,检查点)", "(一半,检查点)"]}, {"type": "bar", "name": "参数", "x": ["基线 (FP32)", "混合精度 (FP16)", "高效优化器 (8位)", "组合"], "y": [1, 1, 1, 1], "marker": {"color": "#ced4da"}}]}不同优化策略下每参数的相对内存贡献。“激活值”反映了峰值使用量,在“组合”策略中使用的梯度检查点,即使在“高效优化器”单独使用的情况下激活值以FP32存储,也能大幅减少内存。混合精度将激活值和梯度内存减半。8位优化器大幅减少优化器状态内存。结合使用这些技术可获得最大的节省。通过策略性地结合梯度检查点、混合精度训练和内存高效优化器,即使在大型基础模型上,管理元学习的内存需求也变得可行,从而使得这些高效适应技术得以大规模应用。