虽然FedAvg为客户端更新聚合提供了一个简单的基础方法,但其性能在存在统计异构性(非独立同分布数据)时常会大幅下降。在差异很大的本地数据集上训练的客户端可能会将全局模型推向冲突的方向,导致收敛缓慢、振荡甚至发散。这种现象常被称为“客户端漂移”,其产生是由于每个客户端所追求的本地优化目标与全局目标之间存在差异。
SCAFFOLD(联邦学习中的随机控制平均)通过引入控制变量,提供了一个巧妙的解决办法来减轻这种客户端漂移。主要思路是估算每个客户端在拥有全局数据分布的情况下本会采取的更新方向,然后根据此估算值修正实际的本地更新。这可以降低客户端更新之间的方差,使它们与全局目标更一致,从而加速收敛。
SCAFFOLD的机制
SCAFFOLD通过在服务器和客户端上维护额外的状态信息,修改了标准的联邦学习过程:
- 服务器控制变量 (c): 表示联邦中所有数据平均梯度方向的估算值。
- 客户端控制变量 (ck): 每个客户端 k 维护自己的控制变量 ck,表示基于其本地数据的梯度方向估算值。
差值 c−ck 有效地反映了客户端 k 的“漂移”。SCAFFOLD使用这些控制变量来调整本地训练过程和聚合步骤。
算法概述
设 wt 为通信轮次 t 时的全局模型权重。
-
服务器广播: 服务器将其当前的全局模型 wt 和控制变量 ct 发送给选定的客户端。
-
客户端计算: 每个参与的客户端 k 执行以下步骤:
- 初始化其本地模型 wkt=wt。
- 接收服务器控制变量 ct。
- 执行 E 步本地随机梯度下降(SGD)。对于每个本地步骤 τ=0,...,E−1,使用本地小批量 b:
- 计算本地梯度:gk(wk,τt)=∇Fk(wk,τt;b),其中 Fk 是客户端 k 的本地损失函数。
- 使用校正后的梯度更新本地模型:
wk,τ+1t=wk,τt−ηl(gk(wk,τt)−ckt+ct)
此处,ηl 是本地学习率。请注意本地梯度 gk 是如何通过服务器控制变量 ct 与客户端控制变量 ckt 之间的差值进行调整的。这种调整旨在纠正客户端的本地漂移。
- 计算总的模型更新方向:Δwkt=wk,Et−wt。
- 更新客户端控制变量 ck。一种常见的方法是:
ckt+1=ckt−ct+Eηl1(wt−wk,Et)
此更新反映了在 E 步中观察到的平均本地梯度方向。
- 计算客户端控制变量的变化:Δckt=ckt+1−ckt。
- 将 Δwkt 和 Δckt 发送回服务器。
-
服务器聚合: 服务器从参与客户端集合 St 接收更新 (Δwkt,Δckt)。
- 聚合模型更新(类似于FedAvg):
Δwt=∣St∣1k∈St∑Δwkt
- 更新全局模型:
wt+1=wt+ηgΔwt
(ηg 是服务器学习率,通常设为1)。
- 聚合控制变量更新:
Δct=∣St∣1k∈St∑Δckt
- 更新服务器控制变量:
ct+1=ct+Δct
方差降低为何有效?
主要原因是,客户端更新中使用的项 (gk(wk,τt)−ckt+ct) 是对全局梯度 ∇F(wk,τt) 的更好估算,优于原始的本地梯度 gk(wk,τt)。通过减去估算的本地方向 (ckt) 并加上估算的全局方向 (ct),SCAFFOLD 有效引导客户端更新趋向全局最小值,减少了由不同本地数据分布引起的方差。
与FedAvg相比,这种方差降低带来了更稳定且通常更快的收敛,特别是在处理明显的统计异构性(非独立同分布数据)时。
FedAvg、FedProx和SCAFFOLD在模拟非独立同分布条件下的收敛行为。SCAFFOLD由于其方差降低原理,通常实现更快的收敛和可能更高的最终精度。
实现考量
与FedAvg相比,实现SCAFFOLD会带来一些额外开销:
- 状态: 服务器和客户端都需要存储各自的控制变量(c 和 ck)。这些变量通常与模型参数的大小相同。
- 通信: 客户端在每个轮次中需要同时将模型更新(Δwk)和控制变量更新(Δck)发送给服务器。服务器除了广播全局模型 w 之外,还需要广播全局控制变量 c。这实际上使每轮通信成本比FedAvg增加了一倍,尽管收敛所需的轮次数量上可能的减少通常可以弥补这一点。
尽管状态和通信要求有所增加,但SCAFFOLD有效应对客户端漂移的特性,使其成为提升在异构数据上运行的联邦学习系统性能和可靠性的有用的工具。其理论依据在非独立同分布设置下,比FedAvg给出更强的收敛保证。