趋近智
FedProx 和 SCAFFOLD 等高级聚合算法为联邦学习提供了改进方案。实现 FedProx 的方法将得到演示,说明它如何解决 FedAvg 的一些局限性,特别是在非独立同分布 (non-IID) 的情况下。在模拟环境中,其性能将与标准的 FedAvg 基线进行比较。
我们假设你已具备基本的联邦学习模拟环境。这通常包括:
我们的基线将是标准的 FedAvg 实现,其中客户端使用 SGD 在本地训练 个周期,并将更新后的模型权重 (weight)发送回服务器进行平均。
回顾“FedProx:处理统计异质性”一节,FedProx 修改了客户端的本地目标函数。客户端不是简单地最小化客户端 数据上的本地损失 ,而是最小化:
这里, 表示在第 轮开始时从服务器接收到的全局模型权重 (weight), 是一个非负超参数 (parameter) (hyperparameter),用于控制近端项的强度。此项将本地解 拉向全局模型 ,从而减轻因本地数据分布差异造成的客户端漂移。
实际操作中,实现这一点需要修改客户端的本地训练循环。具体来说,在计算 SGD(或任何优化器)的梯度时,你需要加上近端项的梯度。近端项对 的梯度就是 。
以下是一个 Python 代码片段,展示了客户端本地训练步骤中的修改(假设使用 PyTorch):
# 假设 'model' 是客户端的本地模型实例
# 'global_model_weights' 存储从服务器接收到的权重 w^t
# 'optimizer' 是标准 SGD 优化器
# 'criterion' 是损失函数(例如,CrossEntropyLoss)
# 'local_data_loader' 提供本地数据批次
# 'mu' 是 FedProx 超参数
# 在本地训练前存储初始全局权重
initial_global_weights = [p.clone().detach() for p in model.parameters()]
# 标准本地训练循环
for epoch in range(num_local_epochs):
for batch_idx, (data, target) in enumerate(local_data_loader):
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
# --- FedProx 修改开始 ---
proximal_term = 0.0
# 遍历模型参数和初始全局权重
for local_param, global_param in zip(model.parameters(), initial_global_weights):
# 确保包含需要梯度的参数
if local_param.requires_grad:
# 计算 L2 范数平方差
proximal_term += torch.sum((local_param - global_param.to(local_param.device))**2)
loss += (mu / 2.0) * proximal_term
# --- FedProx 修改结束 ---
loss.backward()
optimizer.step()
# 本地训练后,'model' 包含更新后的权重
# 将发送回服务器(或增量,取决于实现方式)
实现要点:
FedProx 中的服务器端聚合与 FedAvg 保持一致:它收集更新后的本地模型(或模型增量),并根据每个客户端上的数据点数量计算加权平均值。
为观察 FedProx 的影响,请设置一个具有明显统计异质性的模拟。例如,将 MNIST 数据集分配给 100 个客户端,使每个客户端仅拥有两个数字类别的数据。使用 FedAvg () 和 FedProx(例如, 或 )训练一个简单 CNN,进行固定数量的通信轮次。
在每个通信轮次后,跟踪全局模型在独立、平衡测试集上的准确率。你可能会看到与下述图示类似的结果:
FedAvg 和 FedProx 在不同 值下,全局模型在平衡 MNIST 测试集上的准确率随通信轮次的变化,模拟了非独立同分布数据(每个客户端仅拥有两个类别的数据)。
该图表显示了常见的成果:
这项实践练习表明了在异质环境中采用 FedProx 等高级聚合算法的实际优势。
后续研究:
“通过实现和实验这些算法,你将获得对其行为和相关权衡的实际理解,为你在各种应用中构建更有效的联邦学习系统做好准备。”
这部分内容有帮助吗?
© 2026 ApX Machine LearningAI伦理与透明度•