随机梯度下降 (SGD) 是处理大规模机器学习任务的常用方法。但它依赖于从小批量数据计算出的噪声梯度,可能导致收敛缓慢或不稳定。每一步仅采样一小部分数据引入的方差,通常需要细致调整学习率和更新策略。
随机平均梯度 (SAG) 算法提供了一种有吸引力的替代方案,直接旨在降低这种方差。SAG 由 Schmidt、Le Roux 和 Bach 于 2012 年提出,它试图将 SGD 的低迭代成本与全批量梯度下降常有的更快收敛特性结合起来,特别是对于强凸问题。
主要思想:平均过去的梯度
SAG 背后的理念简单而巧妙:与其只使用当前小批量(或最简单情况下的单个数据点)的梯度,不如借鉴先前迭代中计算出的梯度信息?SAG 会为训练集中的每个独立数据点,记住最近计算出的梯度。
设目标函数是 N 个数据点的平均损失:
L(θ)=N1∑i=1NLi(θ)
其中 Li(θ) 是与第 i 个数据点关联的损失,θ 表示模型参数。
全批量梯度下降使用梯度 ∇L(θ)=N1∑i=1N∇Li(θ)。SGD 最简单形式(批量大小为 1)在迭代 k 时选择一个随机索引 ik,并使用 ∇Lik(θk) 进行更新。
SAG 的工作方式不同。它维护一个包含 N 个梯度向量的表格,g1,g2,…,gN。每个 gi 存储了当数据点 i 在过去某个迭代中被选中时,为其计算出的最近梯度 ∇Li(θ)。
SAG 更新规则
在每次迭代 k+1 时:
- 采样: 从 {1,2,…,N} 中随机选择一个索引 ik。
- 计算当前梯度: 使用当前参数 θk 计算所选数据点的梯度:vik=∇Lik(θk)。
- 更新平均梯度估计: SAG 的核心是使用存储梯度的平均值作为更新方向。这个平均值可以高效地更新。设 Gk=N1∑j=1Ngj 是存储到迭代 k 的梯度平均值。新的平均值 Gk+1 通过在求和中用新计算的梯度 vik 有效地替换旧的存储梯度 gik 来计算:
Gk+1=Gk+N1(vik−gik)
这避免了在每一步重新求和所有 N 个存储的梯度。
- 更新参数: 使用这个噪声较少的平均梯度估计更新参数:
θk+1=θk−ηGk+1
其中 η 是学习率。
- 更新记忆: 将新计算的梯度 vik 存储到记忆表格中,替换掉 gik 的先前值:
gik←vik
最初,存储的梯度 gi 可以初始化为零向量或根据初始参数猜测计算。
单次 SAG 迭代的流程。采样一个数据点 ik,计算其当前梯度,这个新梯度同时更新运行中的平均梯度估计 (Gk+1) 和存储的梯度记忆 (gik)。
收敛性质
SAG 的主要优点之一是其收敛速度。对于强凸且光滑的目标函数,SAG 实现了线性收敛速度。这相比于标准 SGD 在这类问题上通常的次线性速度是一个显著提升。简单来说,SAG 在达到一定精度所需的迭代次数方面收敛更快,接近全批量梯度下降的理想特性。与每迭代成本随 N 线性增加的全批量梯度下降不同,SAG 的收敛速度不会随着数据集大小 N 的增加而显著下降。
计算成本与内存
- 计算: SAG 的每次迭代成本主要由单个梯度 ∇Lik(θk) 的计算以及向量加法/减法决定。这与 SGD(批量大小为 1)的每次迭代成本相当。它比全批量梯度下降便宜得多,后者每迭代需要计算 N 个梯度。
- 内存: SAG 的主要缺点是其内存需求。它需要存储 N 个梯度向量,其中 N 是训练集中的数据点总数。如果参数维度 d 很大且 N 数量庞大(数十亿样本),存储 N×d 个浮点数可能变得不切实际甚至不可能。
SAG 与 SGD 对比
- 方差: SAG 通过平均梯度显著降低了相比 SGD 的方差。这带来更稳定的更新和更快的收敛。
- 内存: SGD 的内存开销最小(仅存储参数和当前小批量梯度),而 SAG 需要存储 N 个梯度。
- 收敛速度(迭代次数): 对于强凸问题,SAG 通常比 SGD(次线性)收敛快得多(线性)。
- 实现复杂度: 由于需要管理梯度记忆表,SAG 的实现略微复杂。
SAG 代表了弥合 SGD 可扩展性与批量方法快速收敛之间差距的重要一步。然而,其内存需求限制了它在 N 极大情景下的适用性。这一限制推动了后续方差削减技术(如 SVRG,我们接下来会讨论)的发展,旨在保留快速收敛优势的同时减轻内存负担。