Now that we've explored the theoretical underpinnings of advanced aggregation algorithms like FedProx and SCAFFOLD, let's put them into practice. This section guides you through implementing FedProx, demonstrating how it addresses some of FedAvg's limitations, particularly in non-IID settings. We will compare its performance against the standard FedAvg baseline in a simulated environment.
We assume you have a basic federated learning simulation setup. This typically involves:

- a central server that holds the global model,
- a set of simulated clients, each with its own local data partition, and
- a communication loop that broadcasts the global weights, triggers local training, and collects the client updates.
Our baseline will be a standard FedAvg implementation where clients train locally for E epochs using SGD and send their updated model weights back to the server for averaging.
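This round structure can be sketched as follows. The names `run_round`, `local_train`, and `aggregate` are illustrative placeholders you would supply, not part of any library API:

```python
import copy

def run_round(global_model, clients, local_train, aggregate):
    """One FedAvg communication round (sketch). `clients` yields
    (num_samples, client_data) pairs; `local_train` runs E local epochs
    of SGD on a copy of the global model; `aggregate` computes the
    sample-count-weighted average of the returned state dicts."""
    states, sizes = [], []
    for num_samples, data in clients:
        # Broadcast: each client starts from a copy of the global model
        local_model = copy.deepcopy(global_model)
        local_train(local_model, data)
        states.append(local_model.state_dict())
        sizes.append(num_samples)
    # Aggregate the local updates back into the global model
    global_model.load_state_dict(aggregate(states, sizes))
    return global_model
```

The same skeleton serves for FedProx later; only the body of `local_train` changes.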
Recall from the "FedProx: Addressing Statistical Heterogeneity" section that FedProx modifies the client's local objective function. Instead of just minimizing the local loss F_k(w) on client k's data, the client minimizes:

min_w  F_k(w) + (μ/2) ||w − w^t||²

Here, w^t represents the global model weights received from the server at the start of round t, and μ ≥ 0 is a hyperparameter controlling the strength of the proximal term. This term pulls the local solution w closer to the global model w^t, mitigating client drift caused by diverging local data distributions.
In practice, implementing this involves modifying the client's local training loop. Specifically, when computing the gradient for SGD (or any optimizer), you need to add the gradient of the proximal term. The gradient of the proximal term with respect to w is simply μ(w − w^t).
Here's a conceptual Python snippet illustrating the modification within a client's local training step (assuming PyTorch):
```python
import torch

# Assume 'model' is the client's local model instance, initialized with the
# global weights w^t received from the server at the start of the round.
# 'optimizer' is a standard SGD optimizer, 'criterion' is the loss function
# (e.g., CrossEntropyLoss), 'local_data_loader' yields batches of local data,
# 'num_local_epochs' is the number of local epochs E, and 'mu' is the
# FedProx hyperparameter.

# Store the initial global weights w^t before local training begins
initial_global_weights = [p.clone().detach() for p in model.parameters()]

for epoch in range(num_local_epochs):
    for batch_idx, (data, target) in enumerate(local_data_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)

        # --- FedProx modification starts ---
        proximal_term = 0.0
        for local_param, global_param in zip(model.parameters(),
                                             initial_global_weights):
            if local_param.requires_grad:
                # Squared L2 distance between local and global weights
                proximal_term += torch.sum(
                    (local_param - global_param.to(local_param.device)) ** 2
                )
        loss = loss + (mu / 2.0) * proximal_term
        # --- FedProx modification ends ---

        loss.backward()
        optimizer.step()

# After local training, 'model' holds the updated weights to be sent back
# to the server (or the delta w - w^t, depending on the implementation)
```
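Equivalently, instead of adding the term to the loss and letting autograd differentiate it, you can apply the proximal gradient μ(w − w^t) directly to each parameter's `.grad` between `backward()` and `optimizer.step()`. A minimal sketch, where `add_proximal_grad` is an illustrative helper name:

```python
import torch

def add_proximal_grad(model, global_weights, mu):
    """Add the proximal term's gradient mu * (w - w^t) to each
    parameter's .grad. Call after loss.backward() and before
    optimizer.step()."""
    for local_param, global_param in zip(model.parameters(), global_weights):
        if local_param.requires_grad and local_param.grad is not None:
            local_param.grad.add_(
                mu * (local_param.detach() - global_param.to(local_param.device))
            )
```

This variant avoids building the proximal term into the autograd graph, which can save a little memory and compute when the model is large.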
Key Implementation Points:

- Capture w^t with `clone().detach()` *before* local training starts; otherwise the "global" reference would track the local updates.
- Adding the proximal term to the loss before `backward()` lets autograd compute its gradient μ(w − w^t) automatically, with no manual gradient code.
- Setting μ = 0 recovers standard FedAvg exactly, which makes it easy to toggle between the two algorithms in experiments.
The server-side aggregation in FedProx remains identical to FedAvg: it collects the updated local models (or model deltas) and computes a weighted average based on the number of data points on each client.
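This weighted average can be sketched as follows; `fedavg_aggregate` is an illustrative helper name, operating on PyTorch state dicts:

```python
import torch

def fedavg_aggregate(client_state_dicts, client_num_samples):
    """Weighted average of client models, with each client's weight
    proportional to its number of local data points."""
    total = sum(client_num_samples)
    averaged = {}
    for key in client_state_dicts[0]:
        # Sum each parameter tensor, scaled by the client's data fraction
        averaged[key] = sum(
            sd[key].float() * (n / total)
            for sd, n in zip(client_state_dicts, client_num_samples)
        )
    return averaged
```

The server then calls `global_model.load_state_dict(averaged)` to install the new global weights for the next round.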
To observe the impact of FedProx, set up a simulation with significant statistical heterogeneity. For example, partition the MNIST dataset among 100 clients such that each client only has data from two digit classes. Train a simple CNN using both FedAvg (μ=0) and FedProx (e.g., μ=0.01 or μ=0.1) for a fixed number of communication rounds.
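One common way to build such a partition is the shard-based scheme from the original FedAvg paper: sort indices by label, cut them into small contiguous shards, and deal two shards to each client. The helper name `partition_two_classes` below is illustrative; `labels` would be the MNIST label array:

```python
import numpy as np

def partition_two_classes(labels, num_clients=100, shards_per_client=2, seed=0):
    """Shard-based non-IID partition: sort indices by label, split them
    into num_clients * shards_per_client contiguous shards, and assign
    shards_per_client shards to each client, so each client sees at
    most that many distinct classes."""
    rng = np.random.default_rng(seed)
    order = np.argsort(labels, kind="stable")      # indices sorted by label
    num_shards = num_clients * shards_per_client
    shards = np.array_split(order, num_shards)
    shard_ids = rng.permutation(num_shards)        # deal shards randomly
    return [
        np.concatenate([shards[s] for s in shard_ids[i::num_clients]])
        for i in range(num_clients)
    ]
```

Each returned array holds the dataset indices for one client's local partition.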
Track the global model's accuracy on a held-out, balanced test set after each communication round. You might observe results similar to those depicted below:
Figure: Global model accuracy on a balanced MNIST test set over communication rounds for FedAvg and FedProx with different μ values, simulated under a non-IID data distribution (each client has data from only two classes).
The chart illustrates a typical outcome: under this non-IID partition, FedAvg's accuracy tends to climb slowly and oscillate as client updates pull the global model in conflicting directions, while the FedProx curves are usually smoother and converge to higher accuracy. The best μ depends on the degree of heterogeneity; too large a value over-constrains local updates and slows learning.
This practical exercise demonstrates the tangible benefits of using an advanced aggregation algorithm like FedProx in heterogeneous environments.
Further Exploration:

- Sweep μ over several orders of magnitude (e.g., 0.001 to 1.0) and observe the trade-off between stability and convergence speed.
- Increase the number of local epochs E; FedAvg's client drift typically worsens, which should widen the gap in FedProx's favor.
- Implement SCAFFOLD, discussed earlier in this chapter, and compare its control-variate correction against FedProx's proximal term.
By implementing and experimenting with these algorithms, you gain practical insight into their behavior and the trade-offs involved, preparing you to build more effective federated learning systems for real-world applications.
© 2025 ApX Machine Learning