在回顾了联邦学习的基本工作流程及其固有的难题,尤其是数据和系统异构性之后,我们现在建立一个更严谨的数学框架。准确定义优化问题对于理解、分析和设计高级联邦算法非常重要。
其核心在于,联邦学习旨在训练一个单一的全局模型,使用分散在多个客户端上的数据,而无需将数据集中化。目标通常被表述为最小化一个全局损失函数,它表示对各个损失函数的汇总,这些损失函数在每个客户端的本地数据上计算得到。
全局目标函数
联邦优化中的标准目标是找到模型参数 w,使得全局目标函数 F(w) 最小化。该函数通常定义为每个客户端 k 的局部目标函数 Fk(w) 的加权平均值:
F(w)=k=1∑NpkFk(w)
让我们分解此方程的组成部分:
- w: 表示机器学习模型的参数(例如,神经网络的权重和偏置),我们旨在对其进行优化。这是所有客户端共享的全局模型。
- N: 参与联邦学习过程的客户端总数(或在给定轮次中选定的子集)。
- k: 标识特定客户端的索引,范围从 1 到 N。
- Fk(w): 客户端 k 的局部目标函数。此函数衡量当前全局参数 w 在客户端 k 的本地数据集 Dk 上的表现。它量化了局部经验风险。
- pk: 分配给客户端 k 的权重,决定了它对全局目标的影响。一种常见的选择是根据客户端持有的数据量按比例加权。如果 nk=∣Dk∣ 是客户端 k 上的数据样本数量,并且 n=∑k=1Nnk 是所有客户端上的总样本数量,那么典型的加权方式是:
pk=nnk
这确保了贡献更多数据的客户端对最终全局模型有按比例更大的影响。存在其他加权方案,例如均匀加权(pk=1/N),如果数据大小未知,或者无论数据量如何都希望每个客户端做出同等贡献,则可能更偏好这种方式。请注意,通常要求 ∑k=1Npk=1。
局部目标函数
局部目标函数 Fk(w) 通常是参数为 w 的模型在客户端 k 的本地数据 Dk 上的平均损失。对于一个监督学习任务,数据点为 (xj,yj),其中 xj 是输入特征向量,yj 是目标标签,Fk(w) 可以表示为:
Fk(w)=nk1j∈Dk∑ℓ(w;xj,yj)
这里,ℓ(w;xj,yj) 是针对特定任务选择的损失函数,例如分类的交叉熵损失或回归的均方误差。它衡量单个数据点的预测误差。
从挑战部分重新提及的一点是,客户端之间的数据分布 (Dk) 通常不是独立同分布(Non-IID)的。这种统计异构性意味着局部目标函数 Fk(w) 彼此之间可能存在明显差异。一个客户端数据的最佳参数在另一个客户端数据上可能表现不佳。
优化目标
联邦优化过程的最终目标是找到一组全局参数 w∗,使全局目标函数 F(w) 最小化:
w∗=argwminF(w)=argwmink=1∑NpkFk(w)
解决这个最小化问题带来了一些独特的困难,与传统集中式机器学习相比:
- 数据去中心化: 数据集 Dk 仍保留在本地客户端上,无法在中央服务器上汇集。需要直接访问完整数据集的标准优化算法不适用。
- 通信限制: 优化必须通过迭代通信进行,在中央服务器(或协调器)和客户端之间。通信轮次通常缓慢且开销大,成为一个主要的瓶颈。
- 异构性: 如前所述,统计异构性(非独立同分布数据)意味着 ∇Fk(w)(局部梯度)可能是 ∇F(w)(全局梯度)的一个糟糕近似。系统异构性(计算能力、网络速度、可用性不同)使同步更新进一步复杂化。
联邦优化算法,例如广泛使用的联邦平均(FedAvg)算法,专门设计用于在这些约束下找到近似解 w∗。它们通常包括在客户端进行多轮本地计算(例如,对局部目标 Fk(w) 执行多步随机梯度下降),随后在服务器端聚合更新(例如,模型参数或梯度)以更新全局模型 w。
这种数学表述为联邦学习提供了一个清晰的目标。理解此目标是理解后续章节中讨论的高级算法设计与分析的第一步,这些算法旨在更高效、更稳定、更隐私地解决此问题。