Now that we've discussed the theoretical underpinnings of statistical heterogeneity and various approaches to mitigate its effects, let's put this knowledge into practice. Simulating federated learning scenarios, especially those involving Non-IID data, is an essential step in algorithm development and evaluation. It allows us to understand how different strategies perform under controlled, yet realistic, conditions before considering complex real-world deployments.
This hands-on exercise will guide you through simulating a Non-IID data distribution and comparing the performance of standard Federated Averaging (FedAvg) against a mitigation technique, specifically FedProx, which we covered in Chapter 2.
Most federated learning simulations involve three core components: a central server, a set of clients, and a dataset partitioned across these clients. You can use frameworks like TensorFlow Federated (TFF), PySyft, or Flower (as detailed in Chapter 6) to facilitate this setup. For this exercise, we'll focus on the conceptual steps, which are applicable across different frameworks.
We typically start with a standard dataset suitable for classification, such as MNIST or CIFAR-10. The key step is how we distribute this data among the simulated clients to create statistical heterogeneity.
Creating realistic Non-IID data distributions is critical for meaningful simulations. A common and effective method is to use a Dirichlet distribution to allocate class labels among clients.
Imagine we have $N$ clients and $C$ data classes. We can model the distribution of classes on client $k$ using a probability vector $p_k = (p_{k,1}, \dots, p_{k,C})$, where $p_{k,j}$ is the proportion of samples belonging to class $j$ on client $k$. To generate heterogeneous distributions, we can draw each $p_k$ from a Dirichlet distribution with concentration parameter $\alpha$:

$$p_k \sim \mathrm{Dir}(\alpha)$$
Conceptual implementation steps:

1. Choose the number of clients $N$ and the concentration parameter $\alpha$ (smaller $\alpha$ yields more skewed label distributions).
2. For each client $k$, draw the class-proportion vector $p_k \sim \mathrm{Dir}(\alpha)$.
3. Assign the samples of each class to clients according to these proportions, so that each sample belongs to exactly one client.
Other ways to introduce heterogeneity include varying the quantity of data per client (quantity skew) or partitioning based on underlying data features, but label distribution skew using Dirichlet is a widely adopted standard for benchmarking.
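The Dirichlet partitioning described above can be sketched in a few lines of NumPy. Note that this common implementation draws, for each class, the proportions allocated across clients from $\mathrm{Dir}(\alpha)$, which induces the per-client label skew described above; the function name and signature are illustrative, not taken from any particular framework.

```python
import numpy as np

def dirichlet_partition(labels, n_clients, alpha, seed=0):
    """Partition sample indices across clients using a Dirichlet prior.
    Smaller alpha -> more heterogeneous (skewed) label distributions."""
    rng = np.random.default_rng(seed)
    n_classes = int(labels.max()) + 1
    client_indices = [[] for _ in range(n_clients)]
    for c in range(n_classes):
        idx_c = np.where(labels == c)[0]
        rng.shuffle(idx_c)
        # Draw this class's allocation proportions across the clients.
        proportions = rng.dirichlet(alpha * np.ones(n_clients))
        # Split the class's indices according to those proportions.
        cuts = (np.cumsum(proportions)[:-1] * len(idx_c)).astype(int)
        for k, part in enumerate(np.split(idx_c, cuts)):
            client_indices[k].extend(part.tolist())
    return [np.array(idx) for idx in client_indices]
```

Each sample lands on exactly one client, so the union of the returned index arrays recovers the full dataset.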
First, run a standard FedAvg simulation using the Non-IID data partition created above. Train a suitable model (e.g., a simple CNN for CIFAR-10) over several communication rounds.
Key steps in each FedAvg round:

1. The server broadcasts the current global model $w^t$ to a sampled subset of clients.
2. Each selected client runs several epochs of local SGD on its own data.
3. Clients send their updated model parameters back to the server.
4. The server forms the new global model as an average of the client models, weighted by local dataset size $n_k$.
Monitor the global model's accuracy on a held-out, representative test set over the communication rounds. You will likely observe slower convergence and potentially lower final accuracy compared to simulations run on IID data, especially with small $\alpha$ values. This performance degradation highlights the challenge posed by Non-IID data.
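The server-side aggregation step of FedAvg can be sketched as a size-weighted average of client parameters. The helper below is a framework-agnostic illustration (the name and signature are ours, not from TFF, PySyft, or Flower); it treats each model as a list of NumPy parameter arrays.

```python
import numpy as np

def fedavg_aggregate(client_weights, client_sizes):
    """FedAvg server step: average client parameters, weighting each
    client by its local dataset size.

    client_weights: list (one entry per client) of lists of np.ndarray.
    client_sizes:   list of local sample counts n_k.
    """
    total = sum(client_sizes)
    n_layers = len(client_weights[0])
    new_global = []
    for layer in range(n_layers):
        # Weighted sum over clients for this parameter tensor.
        averaged = sum(w[layer] * (n / total)
                       for w, n in zip(client_weights, client_sizes))
        new_global.append(averaged)
    return new_global
```

For example, averaging two one-parameter clients with values 2.0 and 4.0 and sizes 1 and 3 gives (1/4)·2.0 + (3/4)·4.0 = 3.5.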
Now, let's implement FedProx to address the client drift caused by heterogeneity. Recall from Chapter 2 that FedProx modifies the local client objective function by adding a proximal term. Instead of just minimizing the local empirical loss $F_k(w) = \frac{1}{n_k} \sum_{i \in D_k} \ell(w; x_i, y_i)$, the client optimizes:

$$\min_w H_k(w) = F_k(w) + \frac{\mu}{2} \lVert w - w^t \rVert^2$$

Here, $w^t$ is the global model received from the server at the start of round $t$, and $\mu \ge 0$ is a hyperparameter controlling the strength of the proximal term. This term limits how far the local model $w$ can stray from the initial global model $w^t$ during local training.
Conceptual implementation changes:

1. Keep a copy of the received global model $w^t$ on each client for the duration of the round.
2. Add the proximal penalty $\frac{\mu}{2} \lVert w - w^t \rVert^2$ to the local loss, or equivalently add $\mu (w - w^t)$ to the local gradient.
3. Leave the server-side aggregation unchanged; FedProx only modifies the client update.
Run the simulation again using the same Non-IID data partition but with the FedProx client update rule. Use a non-zero value for $\mu$ (e.g., $\mu = 0.1$ or $\mu = 1.0$ as starting points).
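As a sketch, the FedProx client update amounts to adding $\mu (w - w^t)$ to the local gradient at every step. The function below is a hypothetical, framework-agnostic illustration: `grad_fn` stands in for whatever routine computes the gradient of the local loss $F_k$ on a batch.

```python
import numpy as np

def fedprox_local_update(w_global, grad_fn, mu, lr, n_steps):
    """Local SGD with a proximal term. The gradient of
    H_k(w) = F_k(w) + (mu/2) * ||w - w_global||^2
    is grad F_k(w) + mu * (w - w_global).

    w_global: list of np.ndarray (the received global model w^t).
    grad_fn:  callable returning grad F_k(w) as a list of np.ndarray.
    """
    w = [layer.copy() for layer in w_global]
    for _ in range(n_steps):
        grads = grad_fn(w)  # gradients of the local loss F_k at w
        for j in range(len(w)):
            prox_grad = grads[j] + mu * (w[j] - w_global[j])
            w[j] -= lr * prox_grad
    return w
```

With $\mu = 0$ this reduces to plain local SGD; larger $\mu$ pulls the local solution toward $w^t$. For instance, for a scalar loss $F(w) = \frac{1}{2}(w - 5)^2$ with $w^t = 0$ and $\mu = 1$, the iterate converges to the proximal optimum $w = 2.5$ rather than the local optimum $w = 5$.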
After running both simulations (FedAvg and FedProx) on the same Non-IID data setup, compare their performance. Plot the global model accuracy on the test set against the communication round number for both algorithms.
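A minimal plotting sketch for this comparison, assuming matplotlib is available and that you have collected per-round test accuracies for both runs (the function name and labels are illustrative):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the script runs without a display
import matplotlib.pyplot as plt

def plot_accuracy_curves(rounds, curves, outfile="fedavg_vs_fedprox.png"):
    """Plot test accuracy vs. communication round for several algorithms.
    `curves` maps a label (e.g. 'FedAvg') to a list of accuracies."""
    fig, ax = plt.subplots()
    for label, acc in curves.items():
        ax.plot(rounds, acc, label=label)
    ax.set_xlabel("Communication round")
    ax.set_ylabel("Test accuracy")
    ax.legend()
    fig.savefig(outfile)
    return fig
```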
Comparison of test accuracy over communication rounds for FedAvg and FedProx on a synthetically generated Non-IID dataset (e.g., CIFAR-10 with Dirichlet partitioning, $\alpha = 0.3$). FedProx often shows improved stability and convergence compared to FedAvg under heterogeneity.
The simulation results, visualized in the chart above, typically demonstrate that FedProx can lead to more stable convergence and potentially higher final accuracy compared to FedAvg when dealing with significant statistical heterogeneity. The proximal term effectively restricts clients from diverging too much during local training, mitigating the negative impact of Non-IID data on the global model aggregation.
This exercise provides a foundation for experimenting with heterogeneity:

- Vary the Dirichlet parameter $\alpha$ to control the degree of label skew.
- Tune the proximal strength $\mu$ and observe the trade-off between convergence stability and local fitting.
- Adjust the number of local epochs per round, which amplifies or reduces client drift.
- Repeat the comparison on other datasets or model architectures.
By systematically simulating these scenarios, you gain valuable insights into the strengths and weaknesses of different federated learning strategies, paving the way for designing more effective and reliable systems for real-world applications.
© 2025 ApX Machine Learning