联邦学习(FL)代表了传统集中式机器学习的一个巨大转变。它不将大量用户数据收集到中央存储库中进行模型训练,而是允许在分布式设备上或在独立的组织内部直接协作构建模型,同时将原始数据保留在本地。这种方式自然地提升了数据隐私性,因为敏感信息从未离开过客户端的掌控。支持这种分布式学习方法的核心原则在此回顾。
联邦学习工作流程
联邦学习最常用且最根本的方法是联邦平均(FedAvg)。它遵循一个由中央服务器协调的迭代过程,通常涉及以下步骤:
- 初始化: 服务器开始时有一个初始全局模型,通常是随机初始化或在公共数据集上预训练过的模型。
- 客户端选择: 在每个通信轮次 t 中,服务器选择可用客户端(例如,移动设备、医院)的一个子集来参与训练。选择策略可能不同,通常涉及随机抽样。
- 模型分发: 服务器将当前的全局模型参数 wt 传输给选定的客户端。
- 本地训练: 每个选定的客户端 k 使用其本地数据 Dk 更新收到的模型。这通常涉及在其本地目标函数 Fk(w) 上运行多步梯度下降(或其变体),该函数从其私有数据导出。设 wt+1k 为客户端 k 上更新后的本地模型参数。
- 模型更新传输: 客户端将其计算出的更新发送回服务器。这可以是完整的更新模型参数 wt+1k,或者更常见的是差值 Δk=wt+1k−wt,或者梯度 ∇Fk(wt)。这种传输可能是潜在的通信瓶颈和隐私风险区域。
- 聚合: 服务器聚合来自参与客户端的更新,以计算新的全局模型 wt+1。在联邦平均中,这通常是根据每个客户端用于训练的数据量进行的加权平均:
wt+1=k∈St∑nnkwt+1k
其中 St 是第 t 轮中选定客户端的集合,nk=∣Dk∣ 是客户端 k 上的数据点数量,而 n=∑k∈Stnk 是所有选定客户端的数据点总数。或者,如果发送更新 Δk:
wt+1=wt+k∈St∑nnkΔk
- 迭代: 该过程从步骤2开始重复,直到达到预设的通信轮次或满足收敛条件。
这种循环过程使得全局模型能够从分布式数据集中包含的集体知识中学习,而无需将数据集中化。
一张图表,说明了涉及服务器和代表性客户端的标准同步联邦学习周期。
实体:服务器与客户端
联邦学习生态系统主要包含两种类型的实体:
- 客户端: 这些是持有本地数据的设备或组织(例如,智能手机、笔记本电脑、医院、银行)。它们拥有计算资源,可以根据服务器的指令执行本地模型训练。它们的参与可能是间歇性的,其资源(CPU、网络带宽、数据量)可能差异很大,从而导致系统异构性。
- 服务器: 这个中心实体协调学习过程。它初始化模型、选择客户端、分发模型、聚合更新并维护全局状态。虽然服务器不访问原始客户端数据,但其作用对协调和收敛非常重要。服务器本身可能是潜在的单点故障或攻击目标。
优化目标
正如章节引言中提到的,首要目标通常是最小化全局目标函数 F(w),它代表了所有 N 个客户端的聚合性能:
wminF(w)其中F(w)=k=1∑NpkFk(w)
这里,Fk(w)=L(w;Dk) 是客户端 k 在其本地数据集 Dk 上计算的本地损失函数,而 pk 是分配给客户端 k 的权重,通常为 pk=nk/∑j=1Nnj,其中 nk=∣Dk∣。这种表述表明,我们的目标是获得一个在所有客户端数据分布上平均表现良好的模型。
与其他学习模式的对比
区分联邦学习与其他方法很重要:
- 集中式学习: 要求将所有数据汇集到一个位置,由于隐私、通信或法规限制,联邦学习明确避免了这一点。
- 经典分布式学习: 通常假设数据分布在集群中的节点上(例如,使用参数服务器),但通常是在单个可信域内,节点能力更同质,且数据分区通常是IID(独立同分布)的。联邦学习特别针对具有更高异构性、不可信环境(可能)和非IID数据的场景。
"对联邦学习基本原则和标准联邦平均工作流程的这份回顾,为后续内容做了铺垫。尽管直接明了,但这个基础模型在几个简化假设下运行。联邦环境带来了数据和系统异构性、隐私漏洞、通信成本以及潜在对抗行为方面的重要挑战,这些挑战推动了我们将在本课程中研究的高级技术。"