趋近智
从零开始实现联邦学习系统需要管理复杂的分布式通信、客户端调度、状态同步和安全聚合协议。这是一项重大的工程任务,会分散人们对核心机器学习工作的精力。幸好,一些开源框架已经出现,它们将这些复杂性抽象化,为构建和模拟联邦学习系统提供了可重用组件和标准化工作流程。这些框架使得研究人员和工程师能更侧重于设计新颖的算法、评估隐私机制以及部署联邦学习应用。
三个有名的框架是:TensorFlow Federated (TFF)、PySyft 和 Flower。每个框架都提供不同的理念和抽象集合,以适应不同的应用场景和开发偏好。了解它们的主要思想和功能对于为您的联邦学习项目选择合适的工具非常重要。
TensorFlow Federated (TFF) 由谷歌开发,旨在支持联邦学习的开放研究和实验。它与 TensorFlow 紧密集成,让您可以在联邦环境中利用熟悉的 TensorFlow/Keras 模型和 API。TFF 提供两个主要的 API 层:
tff.learning):这是一个更高级的 API,为常见的联邦学习任务(特别是模型训练和评估)提供预构建组件。它提供像 tff.learning.algorithms.build_weighted_fed_avg 这样的接口,实现了联邦平均等标准算法。这一层简化了常见联邦学习场景的实现,并减少了样板代码。tff.program, tff.computation):这是一个更低层的 API,提供用于表示联邦计算的基础构建块。它允许精细控制计算执行的位置(服务器或客户端)以及数据如何通信和聚合。使用 FC API,您可以实现自定义的联邦算法。计算表示为 tff.Computation 对象,通常使用像 @tff.federated_computation 这样的 Python 装饰器来构建。TFF 的主要优势在于它与 TensorFlow 生态系统的紧密集成,以及其用于表示新颖联邦计算的强大低层 API,使其非常适合研究用途。它包含模拟功能,让您可以有效地对异构客户端群体和网络条件进行建模。
TFF 结构(FL API):
# 假设 `client_data` 是一个 tf.data.Dataset 列表
# 假设 `model_fn` 返回一个未编译的 Keras 模型
# 1. 定义迭代过程(例如,联邦平均 FedAvg)
trainer = tff.learning.algorithms.build_weighted_fed_avg(
model_fn,
client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
server_optimizer_fn=lambda: tf.keras.optimizers.Adam(learning_rate=0.01)
)
# 2. 初始化服务器状态
state = trainer.initialize()
# 3. 运行联邦回合
for round_num in range(NUM_ROUNDS):
# 为本回合采样客户端数据
sampled_data = [client_data[i] for i in sample_clients(round_num)]
# 执行一轮联邦训练
result = trainer.next(state, sampled_data)
state = result.state
# 处理指标等
print(f"回合 {round_num}, 指标: {result.metrics}")
TFF 主要侧重于同步联邦计算,并且大量用于模拟,尽管部署 TFF 计算需要额外的基础设施。
PySyft 由 OpenMined 社区开发,主要侧重于支持安全和隐私 AI。虽然它支持联邦学习,但其范围扩展到其他隐私增强技术,如差分隐私 (DP)、安全多方计算 (SMC) 和同态加密 (HE),这些技术通常直接集成到联邦学习工作流程中。PySyft 旨在实现框架无关性,但其目前最成熟的支持是针对 PyTorch。
PySyft 采用以面向对象的方式,围绕以下思想展开:
VirtualWorker 用于模拟。PointerTensor 对象充当驻留在远程工作节点上的数据的引用。对指针张量的操作会被转发到相应的工作节点执行。AdditiveSharingTensor(用于 SMC)或应用 DP 的机制等对象被集成到张量系统中。Plan 对象封装了一系列操作(如训练步骤),可以发送到工作节点并在其数据上远程执行。Protocol 对象协调多个工作节点之间的复杂交互,常用于安全聚合。PySyft 的优势在于其隐私优先的设计以及将各种密码学技术直接集成到框架中。它提供构建块来构建复杂的隐私保护联邦学习系统。
PySyft 结构:
import torch
import syft as sy
# 1. 挂接 PyTorch 并创建工作节点
hook = sy.TorchHook(torch)
server = sy.VirtualWorker(hook, id="server")
client1 = sy.VirtualWorker(hook, id="client1")
client2 = sy.VirtualWorker(hook, id="client2")
# 2. 创建数据并发送到客户端(使用 PointerTensors)
data1 = torch.tensor([1, 2, 3]).send(client1)
data2 = torch.tensor([4, 5, 6]).send(client2)
# data1, data2 现在是 PointerTensors
# 3. 定义模型并发送到客户端
model = torch.nn.Linear(3, 1)
model_ptr1 = model.copy().send(client1)
model_ptr2 = model.copy().send(client2)
# 4. 定义一个计划(例如,训练步骤)
# @sy.func2plan() # 将函数转换为计划的装饰器
# def train_step(data, model): ... 返回损失, 更新后的模型
# 5. 构建计划,发送到客户端,并执行
# plan = train_step.build(...)
# plan.send(client1)
# loss1, updated_model1_ptr = plan(data1, model_ptr1)
# 6. (如果需要)安全地检索更新后的模型并聚合
# updated_model1 = updated_model1_ptr.get() ...
# 在服务器上聚合模型
PySyft 具有很高的灵活性,但由于它侧重于底层隐私机制和分布式计算抽象,因此学习曲线可能更陡峭。
Flower 是一个较新的框架,其主要设计目标是实现框架无关性和易于集成。它旨在通过允许开发人员对现有机器学习代码(用 PyTorch、TensorFlow、scikit-learn、JAX 等编写)进行最少修改,从而使联邦学习更易于使用。
Flower 采用清晰的客户端-服务器架构:
Strategy 对象定义。Flower 提供预构建的策略(例如,FedAvg、FedAdam、QFedAvg),但也允许定义自定义策略以实现新颖的聚合方法、客户端选择逻辑或其他联邦协调模式。flwr.client.Client 或 flwr.client.NumPyClient 类,将现有数据加载、模型训练和评估逻辑封装在特定方法(get_parameters、fit、evaluate)中。这种分离使得客户端可以在核心训练循环中运行标准的机器学习代码,而无需 Flower 特定的数据结构或模型类型。服务器与客户端通信,发送全局模型参数或指令,客户端则响应更新或评估结果。
Flower 结构:
import flwr as fl
import tensorflow as tf # 或者 PyTorch, scikit-learn 等
# --- 客户端代码 (client.py) ---
class MyFlowerClient(fl.client.NumPyClient):
def __init__(self, model, x_train, y_train, x_val, y_val):
self.model = model
self.x_train, self.y_train = x_train, y_train
self.x_val, self.y_val = x_val, y_val
def get_parameters(self, config):
# 返回模型权重,作为 NumPy ndarray 列表
return self.model.get_weights()
def fit(self, parameters, config):
# 设置来自服务器的模型权重,在本地训练
self.model.set_weights(parameters)
# 使用标准的 TF/PyTorch 训练循环
self.model.fit(self.x_train, self.y_train, epochs=1, batch_size=32)
# 返回更新后的权重、样本数量和指标
return self.model.get_weights(), len(self.x_train), {}
def evaluate(self, parameters, config):
# 设置模型权重,在本地验证集上评估
self.model.set_weights(parameters)
loss, accuracy = self.model.evaluate(self.x_val, self.y_val)
# 返回损失、样本数量和指标
return loss, len(self.x_val), {"accuracy": accuracy}
# 启动 Flower 客户端
# fl.client.start_numpy_client(server_address="[::]:8080", client=MyFlowerClient(...))
# --- 服务器代码 (server.py) ---
# 定义一个策略(例如,联邦平均 FedAvg)
strategy = fl.server.strategy.FedAvg(
fraction_fit=1.0, # 训练时采样所有可用客户端的 100%
min_fit_clients=2, # 训练所需的最小客户端数量
min_available_clients=2, # 等待至少 2 个客户端连接
)
# 启动 Flower 服务器
fl.server.start_server(
server_address="0.0.0.0:8080",
config=fl.server.ServerConfig(num_rounds=3),
strategy=strategy
)
Flower 的优势包括其易用性、集成各种机器学习框架的灵活性,以及其支持模拟和向部署过渡(包括通过 SDK 在移动/物联网设备上部署)的设计。它的 Strategy API 提供了一种简洁的方式来定制联邦方面,而无需显著改变客户端的机器学习代码。
最佳框架取决于您的具体需求:
| 特性 | TensorFlow Federated (TFF) | PySyft | Flower |
|---|---|---|---|
| 主要侧重 | 研究、模拟 | 隐私(DP、SMC、HE)、研究 | 集成、部署、研究 |
| 机器学习后端 | TensorFlow(主要) | PyTorch(主要)、TF(部分) | 无关(TF、PyTorch、JAX 等) |
| 集成便捷性 | 中等(需要 TFF 结构) | 中等(需要 Syft 对象) | 高(适应现有代码) |
| 隐私功能 | 良好(DP 集成) | 优秀(侧重 DP、SMC、HE) | 良好(通过 Strategy API) |
| 灵活性 | 高(FC API) | 高(协议、计划) | 高(Strategy API) |
| 部署 | 侧重模拟 | 侧重模拟/研究 | 侧重模拟和部署 |
TFF、PySyft 和 Flower 主要特点的比较。
这些框架抽象了分布式系统工程中涉及的许多底层细节,让您能够专注于联邦学习的独有方面:算法设计、隐私保护、异构性处理和系统评估。对于任何在联邦学习领域认真工作的人来说,熟悉这些工具中的至少一个正变得越来越重要。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造