统计和系统异质性是许多联邦学习场景中固有的特点。当本地数据分布或设备能力差异很大时,单个全局模型,即使是使用改进的聚合方法训练的模型,也可能难以对每个客户端实现最佳表现。将联邦学习问题以多任务学习 (MTL) 的视角来看待,提供了一种专门应对此类情况的有效替代方法。
在 MTL 方法中,目标不是学习单个函数,而是同时学习多个相关函数。应用于联邦学习,我们可以将为每个客户端(或客户端组)学习个性化模型视为一个不同但相关的“任务”。每个客户端 k 都力求基于其本地数据 Dk 来最小化其自身的本地损失函数 Fk(wk)。MTL 方法不强制所有客户端都认同单个全局模型 w,而是寻求为 K 个客户端找到一组个性化模型 {w1,w2,...,wK}。
主要观点是,虽然这些任务(客户端模型)是个性化的,但它们并非完全独立。它们通常共享从集体数据中学习到的基本结构或特征。MTL 的公式运用了这种关联性,通常通过正则化来实现,从而允许模型相互借鉴优势。这避免了单个模型过度拟合其本地数据,特别是当客户端数据有限时,同时仍允许模型的特殊化。
将联邦学习视为多任务学习问题
将联邦学习视为 MTL 问题的一种常见方式是,通过一个正则化项 Ω 来增强本地客户端损失之和,该正则化项强制执行客户端模型之间的关系:
w1,...,wKmink=1∑KpkFk(wk)+λΩ(w1,...,wK)
此处,pk 代表客户端 k 的贡献权重(通常与其数据集大小成比例),Fk(wk) 是客户端 k 使用其个性化模型 wk 时的本地损失,而 λ 是一个超参数,控制正则化项 Ω 的强度。
正则化项 Ω 是任务关联性在此处编码的地方。存在多种形式,例如:
- 平均模型正则化: 惩罚每个客户端模型 wk 与中心模型(例如,平均模型 wˉ=∑kpkwk)之间的距离。这鼓励模型保持接近平均表示,同时允许存在偏差。
Ω(w1,...,wK)=k=1∑K∥wk−wˉ∥2
- 成对正则化: 惩罚模型对之间的差异,可能通过某种客户端相似度度量加权。
Ω(w1,...,wK)=j,k∑Sjk∥wj−wk∥2
其中 Sjk 可以编码有关客户端相似性的先验信息(例如,基于位置、设备类型或推断的数据分布)。
- 基于图的正则化: 如果客户端关系可以表示为图(例如,社交网络、地理邻近度),正则化可以强制图结构上的平滑性。
- 共享表示: 假设每个模型 wk 包含共享参数和任务专用参数。正则化可能仅作用于共享部分,或鼓励任务专用组件的相似性。
单个全局模型的标准联邦学习与客户端模型共享知识但保留本地组件的联邦多任务学习之间的比较。
示例算法:MOCHA
一种明确将联邦学习视为多任务学习的知名算法是 MOCHA(多任务联邦学习)。MOCHA 旨在解决上面所示的正则化目标函数。它通过运用其结构并采用联邦环境下的随机对偶坐标上升等技术来处理这个可能复杂的高维优化问题。
本质上,MOCHA 涉及客户端和服务器层面的迭代更新。客户端根据其数据和当前的正则化约束执行本地更新,而服务器则协调这些更新并管理与任务关系相关的信息(通常编码为对偶变量)。虽然优化的细节很精细,但核心思路是在最小化本地损失和满足连接模型的全局正则化约束之间找到平衡。
MTL 方法在联邦学习中的优势
- 直接个性化: 从设计上讲,MTL 方法为每个客户端学习个性化模型 wk,直接处理统计异质性。
- 有原则的知识共享: 正则化项提供了一种结构化的方式,供客户端共享知识并从集体数据中受益,避免了对纯本地模式的过度拟合。
- 灵活性: 不同的正则化项允许编码各种假设,关于客户端任务如何关联。
- 改进的本地性能: 对于异构数据集,MTL 公式通常会在客户端专用模型上产生更好的平均性能,与单个全局模型相比。
与其他个性化技术的关系
MTL 方法与本章讨论的其他个性化方法相关:
- 聚类联邦学习: 可以将聚类联邦学习视为 MTL 的一个具体实例,其中假定集群内的客户端共享相同的任务(模型),正则化强制执行这种分组。MTL 可以提供细致到单个客户端级别的个性化。
- 元学习(例如 Per-FedAvg): 元学习旨在找到一个好的模型初始化,可以快速适应每个客户端的数据,通过少量本地梯度步骤。MTL 直接优化最终的个性化模型 {wk},尽管在某些条件下目标函数在数学上可能相关。侧重点不同:元学习侧重于快速适应,而 MTL 侧重于相关任务的结构。
实施方面的考虑
虽然功能强大,MTL 框架引入了自身的考虑因素:
- 优化复杂性: 解决所有 wk 的联合优化问题可能比标准 FedAvg 的计算要求更高。MOCHA 等算法旨在以分布式方式处理此问题,但它们可能涉及更复杂的通信协议。
- 通信开销: 根据所使用的具体算法和正则化,通信可能不仅仅涉及模型权重或梯度(例如,对偶变量)。
- 超参数调整: 选择适当的正则化形式 Ω 和调整平衡参数 λ 对于良好的性能很重要,并需要仔细验证。
总之,将联邦学习视为一个多任务学习问题提供了一种有理论依据且有效的方法来处理客户端异质性并实现个性化。通过明确建模客户端任务之间的关系,MTL 算法可以学习专门的模型,这些模型运用共享知识,通常会带来在各种联邦网络中更好的性能。