统计异质性意味着通过标准联邦方法学习到的单一全局模型,在各个客户端的特定数据分布上可能表现不佳。尽管像FedProx或SCAFFOLD这样的方法旨在异质性条件下提高全局模型的收敛性,但它们并未明确为每个客户端定制模型。个性化技术直接解决了这个问题,其目标是提供适应各个客户端需求的模型。元学习提供了一个有效框架,可在联邦学习环境中实现这种个性化。
元学习的中心思想,常被称为“学习如何学习”,是在各种学习任务上训练一个模型,使其只需少量数据和计算即可解决新的、以前未见过的学习任务。在联邦学习的背景下,每个客户端的本地数据集可以被视为一个独立的学习任务。因此,联邦元学习的目标是学习一个全局模型(或模型初始化),使其能够快速适应并高效应用于每个客户端的独特数据分布,且只需极少的本地微调。
这种方法与标准FedAvg相比,改变了核心目标。它不再是寻求一个单一模型 w 来最小化所有客户端的平均损失:
wminF(w)=k=1∑NpkFk(w)
pk 代表客户端 k 的权重,Fk(w) 是其本地损失。相反,元学习旨在找到一个模型初始化 w,使得在特定客户端 k 的数据上进行少量梯度步骤后,能够大大提升个性化模型 wk 的性能。
元学习在联邦学习中的工作方式
大多数联邦元学习方法在每个通信轮次中都包含两阶段过程:
- 本地适应: 每个参与客户端接收当前的全局元模型 w。客户端随后使用其本地数据执行一次或多次梯度下降步骤来适应此模型。这会生成针对客户端 k 特定任务优化的个性化模型 wk。
- 元优化: 客户端根据其本地适应过程向服务器返回信息。此信息用于更新全局元模型 w。主要不同之处在于此元更新的计算方式,其目标是提高 w 的适应能力,而不仅仅是其平均性能。
我们来看两个主要例子:
Reptile算法适应
Reptile是一种简单但有效的,易于适应联邦学习的一阶元学习算法。在其联邦形式中:
- 服务器广播: 服务器将当前全局模型 w 发送给选定的客户端。
- 客户端本地适应: 每个客户端 k 从 w 开始,使用其自身数据 Dk 执行 τ 步本地SGD(或另一个优化器),以获得适应后的模型 wk(τ)。
wk(0)=w对于 t=1 到 τ:wk(t)=wk(t−1)−α∇Fk(wk(t−1))
α 代表本地学习率。
- 更新通信: 每个客户端将其最终适应模型与初始模型之间的差异,即 Δk=wk(τ)−w,发送回服务器。(注意:变体可能直接发送 wk(τ))。
- 服务器元更新: 服务器聚合这些差异(或适应后的模型)以更新全局元模型。一种简单的方法是:
w←w+βK1k∈S∑(wk(τ)−w)
S 代表选定客户端的集合,K=∣S∣,β 是服务器学习率(通常设为 1)。
直观上,Reptile将全局模型 w 推向参数空间中的一个点,从该点开始,多个适应步骤(本地SGD)能够持续地在不同客户端(任务)上带来良好性能。
Reptile在联邦学习中一轮的流程图。客户端在贡献元模型更新之前执行多个本地步骤。
MAML在联邦学习中的变体
模型无关元学习 (MAML) 明确地优化适应后的性能。其中心思想是找到一个初始化 w,使得在任何客户端 k 的数据上进行一次梯度步骤后,能够大幅提升 Fk 的性能。由于MAML的二阶性质(需要Hessian-向量积)或需要在适应步骤之后传输评估的梯度,将其直接应用于联邦学习可能具有挑战性。
几种联邦学习算法从MAML中获得启发,例如个性化FedAvg (Per-FedAvg)。Per-FedAvg使用一阶方法近似MAML的更新目标。一种常见方法包括:
- 服务器广播: 服务器发送当前全局模型 w。
- 客户端内部更新(适应): 客户端 k 使用其数据 Dk 上的一步(或少量)梯度步骤计算适应后的模型 wk′:
wk′=w−α∇Fk(w)
- 客户端外部更新(元梯度): 客户端随后计算其损失相对于原始模型 w 的梯度,该梯度是使用适应后的模型 wk′ 评估的。这个梯度 ∇Fk(wk′) 本质上反映了改变初始 w 会如何在一次适应步骤之后影响损失。这个梯度被发送到服务器。
- 服务器元更新: 服务器聚合这些元梯度以更新全局模型 w:
w←w−βK1k∈S∑∇Fk(wk′)
这个过程旨在找到一个 w,其位置使得本地适应步骤(w→wk′)在所有客户端上都特别有效。
优点与考量
- 个性化提升: 元学习直接优化在本地适应后表现良好的模型,与简单地微调标准FedAvg模型相比,可能带来更好的个性化性能。
- 处理异质性: 它特别适合客户端任务大相径庭的高度异质(非独立同分布)环境。
- 适应效率: 旨在生成只需少量本地数据样本或梯度步骤即可有效实现个性化的模型。
然而,也存在一些权衡:
- 本地计算增加: 与固定本地训练轮次的标准FedAvg相比,客户端通常需要在每轮中执行更多计算(例如,Reptile中的多步SGD,或MAML变体中的内外部更新计算)。
- 通信: 通信成本取决于具体算法。Reptile可能通信模型差异,而MAML变体可能通信梯度,其大小可能与FedAvg梯度相似。
- 收敛性分析: 联邦元学习算法的收敛行为比标准联邦优化更复杂,难以分析。
联邦元学习代表了一种高级的个性化方法,直接应对了在多样化联邦网络中“一刀切”模型常不敷使用的挑战。通过学习一个可适应的初始化,它使得为每个客户端高效创建专用模型成为可能,在统计异质性面前大大提升了实用性。