As discussed in the chapter introduction and the previous section on Federated Averaging's limitations, the standard FedAvg algorithm can struggle when client data distributions are statistically heterogeneous (Non-IID). When clients perform multiple local updates on their distinct datasets before aggregation, their local models can diverge significantly from each other and from an optimal global solution. This phenomenon is often referred to as "client drift". FedAvg's simple averaging step does not explicitly counteract this drift, potentially leading to slower convergence, oscillations, or even divergence in severe cases.
FedProx (federated optimization with a proximal term) is an algorithm specifically designed to address this challenge. It modifies the local optimization problem solved by each client, aiming to limit how far each local model can move away from the current global model during local training.
The Proximal Term
The central idea of FedProx is to add a proximal term to the standard local objective function. Instead of simply minimizing their local loss $F_k(w)$ based on their local data $D_k$, each client $k$ now solves a modified objective during local training in communication round $t$:

$$\min_{w_k} \; L_{\text{prox}}(w_k; w^t) = F_k(w_k) + \frac{\mu}{2}\,\|w_k - w^t\|^2$$

Here:
- $w_k$ represents the model parameters being optimized locally by client $k$.
- $F_k(w_k)$ is the standard local loss function for client $k$ (e.g., cross-entropy loss on its local data).
- $w^t$ is the global model received from the server at the beginning of round $t$.
- $\|\cdot\|^2$ denotes the squared Euclidean ($L_2$) norm.
- $\mu \ge 0$ is a hyperparameter controlling the strength of the proximal term.
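To make the objective concrete, here is a minimal NumPy sketch of the modified local objective. The flattened-parameter representation and the `local_loss` callable are illustrative assumptions, not part of any particular federated learning framework.

```python
import numpy as np

def fedprox_local_objective(w_k, w_t, local_loss, mu):
    """FedProx local objective: F_k(w_k) + (mu / 2) * ||w_k - w_t||^2.

    w_k, w_t   : flattened model parameters as 1-D arrays of equal length
    local_loss : callable returning the client's loss F_k at w_k (assumed placeholder)
    mu         : proximal strength; mu = 0 recovers the plain FedAvg objective
    """
    proximal_term = 0.5 * mu * np.sum((w_k - w_t) ** 2)
    return local_loss(w_k) + proximal_term
```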
How it Works: Constraining Local Updates
The term $\frac{\mu}{2}\,\|w_k - w^t\|^2$ acts as a regularization penalty. It penalizes large deviations between the client's updated model $w_k$ and the initial global model $w^t$ it started the round with.
Think of it intuitively: while the client tries to minimize its local loss $F_k(w_k)$ based on its specific data, the proximal term pulls the solution $w_k$ back towards the starting point $w^t$. This discourages the local model from overfitting to the local data distribution and drifting too far away from the global consensus represented by $w^t$.
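Taking the gradient of the modified objective makes this pull-back explicit; differentiating the squared-norm penalty in the definition above gives

$$\nabla_{w_k} L_{\text{prox}}(w_k; w^t) = \nabla F_k(w_k) + \mu\,(w_k - w^t),$$

so every local step includes a restoring term that points back toward $w^t$, with strength proportional to how far the local model has drifted.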
The hyperparameter $\mu$ balances the trade-off:
- If $\mu = 0$, FedProx reduces exactly to FedAvg. The local optimization only considers the local loss $F_k(w_k)$.
- If $\mu > 0$, the local updates are constrained. A larger $\mu$ imposes a stronger penalty on deviating from $w^t$, leading to local models $w_k$ that stay closer to the initial global model. This increases stability against heterogeneity, but an excessively large $\mu$ can slow down fitting to the local data.
Finding an appropriate value for $\mu$ often requires empirical tuning, similar to other regularization hyperparameters in machine learning.
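In practice this usually means a small sweep over candidate values. The sketch below is one way to organize such a sweep; `run_federated_training` and `evaluate` are hypothetical placeholders for whatever training loop and validation metric you already have, and the candidate grid is illustrative.

```python
# Coarse logarithmic grid; mu = 0.0 doubles as a FedAvg baseline.
CANDIDATE_MUS = [0.0, 0.001, 0.01, 0.1, 1.0]

def tune_mu(run_federated_training, evaluate):
    """Return the mu with the best validation score (assumes higher is better, e.g., accuracy)."""
    results = {}
    for mu in CANDIDATE_MUS:
        global_model = run_federated_training(mu=mu)  # hypothetical helper
        results[mu] = evaluate(global_model)          # hypothetical helper
    best_mu = max(results, key=results.get)
    return best_mu, results
```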
FedProx Algorithm Flow
The overall flow of FedProx is similar to FedAvg, with the main difference occurring during the client update step:
- Server Initialization: Initialize global model $w^0$.
- Communication Rounds ($t = 0, 1, 2, \dots$):
  - Server Broadcast: Server sends the current global model $w^t$ to a selected subset of clients $S_t$.
  - Client Local Computation: Each selected client $k \in S_t$:
    - Receives $w^t$.
    - Performs local updates (e.g., using SGD) for a certain number of steps ($E$ epochs) to approximately solve the modified local objective $\min_{w_k} F_k(w_k) + \frac{\mu}{2}\,\|w_k - w^t\|^2$. Let the resulting local model be $w_k^{t+1}$.
    - Sends the updated model $w_k^{t+1}$ back to the server. Note: FedProx does not require clients to solve the local problem exactly.
  - Server Aggregation: Server aggregates the received local models to form the new global model $w^{t+1}$. Typically, weighted averaging based on local dataset sizes $n_k$ is used, similar to FedAvg:

$$w^{t+1} = \sum_{k \in S_t} \frac{n_k}{n}\, w_k^{t+1}, \qquad \text{where } n = \sum_{k \in S_t} n_k$$
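The following is a compact NumPy sketch of one such round under these steps. The client representation (a gradient callable plus a dataset size), the fixed learning rate, and the step count are simplifying assumptions for illustration, not a production implementation.

```python
import numpy as np

def client_update(w_t, grad_fn, mu, lr=0.1, local_steps=10):
    """Approximately solve min_w F_k(w) + (mu / 2) * ||w - w_t||^2 with plain gradient steps.

    grad_fn(w) is assumed to return a (stochastic) gradient of the client's loss F_k at w.
    """
    w = w_t.copy()
    for _ in range(local_steps):
        # Gradient of the proximal objective: grad F_k(w) + mu * (w - w_t)
        w = w - lr * (grad_fn(w) + mu * (w - w_t))
    return w

def fedprox_round(w_t, selected_clients, mu):
    """One communication round: broadcast w_t, collect local models, aggregate by data size.

    selected_clients: list of (grad_fn, n_k) pairs for the sampled subset S_t (assumed structure).
    """
    local_models, sizes = [], []
    for grad_fn, n_k in selected_clients:
        local_models.append(client_update(w_t, grad_fn, mu))
        sizes.append(n_k)
    n = sum(sizes)
    # Weighted average of local models, matching the aggregation formula above
    return sum((n_k / n) * w_k for w_k, n_k in zip(local_models, sizes))
```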
Benefits and Considerations
Advantages:
- Improved Stability on Non-IID Data: By mitigating client drift, FedProx often leads to more stable and reliable convergence compared to FedAvg when faced with statistical heterogeneity.
- Theoretical Convergence Guarantees: FedProx comes with convergence guarantees even under heterogeneous data distributions and varying local computation effort across clients.
- Tolerance to System Heterogeneity: Although primarily targeting statistical heterogeneity, the proximal term also makes the algorithm more tolerant to variations in the amount of computation performed locally by different clients (e.g., varying numbers of local epochs $E_k$). Clients performing fewer updates naturally stay closer to $w^t$, which is exactly what the proximal objective encourages.
Considerations:
- Hyperparameter Tuning: The performance of FedProx depends on the choice of $\mu$, which needs to be tuned for the specific problem and dataset.
- Implementation: Implementing FedProx requires modifying the client-side optimization loop to include the proximal term. In most deep learning frameworks this is straightforward: add an $L_2$ penalty toward the round's initial global model $w^t$ to the local loss before backpropagation, as sketched below.
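Here is a minimal PyTorch-style sketch of that modification, assuming a client `dataloader`, a `loss_fn`, and a copy of the global model received at the start of the round; all names are illustrative.

```python
import torch

def fedprox_local_training(model, global_model, dataloader, loss_fn, mu, lr=0.01, epochs=1):
    """Client-side training loop with the FedProx proximal penalty added to the loss."""
    # Snapshot of the global weights w^t received at the start of the round
    global_params = [p.detach().clone() for p in global_model.parameters()]
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    for _ in range(epochs):
        for x, y in dataloader:
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            # Proximal term: sum over parameter tensors of ||p - p_global||^2, scaled by mu / 2
            prox = sum((p - g).pow(2).sum() for p, g in zip(model.parameters(), global_params))
            (loss + 0.5 * mu * prox).backward()
            optimizer.step()
    return model
```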
The following chart illustrates conceptually how FedProx might lead to more stable convergence compared to FedAvg in a Non-IID setting.
FedAvg (red line) may exhibit more oscillatory behavior or slower convergence on heterogeneous data, while FedProx (blue line) with an appropriate $\mu$ can achieve smoother and potentially higher final accuracy by mitigating client drift.
FedProx represents a significant step beyond FedAvg by explicitly incorporating a mechanism to handle statistical heterogeneity. While not a complete solution for all heterogeneity challenges, its relative simplicity and effectiveness make it a valuable tool in the federated learning practitioner's arsenal. In the following sections, we will examine other advanced aggregation techniques like SCAFFOLD and FedNova, which use different mechanisms to address heterogeneity and improve convergence.