While FedAvg provides a simple baseline for aggregating client updates, its performance often degrades significantly in the presence of statistical heterogeneity (Non-IID data). Clients training on vastly different local datasets can pull the global model in conflicting directions, leading to slow convergence, oscillations, or even divergence. This phenomenon, often called "client drift," stems from the discrepancy between the local optimization objective each client pursues and the global objective.
SCAFFOLD (Stochastic Controlled Averaging for Federated Learning) offers an elegant solution to mitigate this client drift by introducing control variates. The core idea is to estimate the update direction each client would have taken if it had access to the global data distribution and then correct the actual local updates based on this estimate. This reduces the variance between client updates, aligning them more closely with the global objective and thereby accelerating convergence.
The Mechanics of SCAFFOLD
SCAFFOLD modifies the standard federated learning process by maintaining additional state information on both the server and the clients:
- Server Control Variate ($c$): Represents an estimate of the average gradient direction across all data in the federation.
- Client Control Variates ($c_k$): Each client $k$ maintains its own control variate $c_k$, representing an estimate of the gradient direction based on its local data.
The difference $c - c_k$ effectively captures the "drift" for client $k$. SCAFFOLD uses these control variates to adjust the local training process and the aggregation step.
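To make the bookkeeping concrete, here is a minimal sketch of the extra state involved, assuming model parameters are represented as flat NumPy vectors; the variable names and sizes are illustrative only and not tied to any particular framework:

```python
import numpy as np

NUM_PARAMS = 1_000            # hypothetical model size
NUM_CLIENTS = 10              # hypothetical federation size

# Server-side state: the global model w and the server control variate c.
w = np.zeros(NUM_PARAMS)      # global model weights
c = np.zeros(NUM_PARAMS)      # server control variate, same shape as the model

# Client-side state: each client k keeps its own control variate c_k,
# also the same shape as the model and typically initialized to zero.
client_controls = {k: np.zeros(NUM_PARAMS) for k in range(NUM_CLIENTS)}
```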
Algorithm Overview
Let $w^t$ be the global model weights at communication round $t$.
- Server Broadcast: The server sends the current global model $w^t$ and its control variate $c^t$ to the selected clients.
- Client Computation: Each participating client $k$ performs the following steps (a code sketch of one full round appears after this list):
  - Initializes its local model $w_{k,0}^t = w^t$.
  - Receives the server control variate $c^t$.
  - Performs $E$ local steps of stochastic gradient descent (SGD). For each local step $\tau = 0, \dots, E-1$, using a local minibatch $b$:
    - Compute the local gradient: $g_k(w_{k,\tau}^t) = \nabla F_k(w_{k,\tau}^t; b)$, where $F_k$ is the local loss function for client $k$.
    - Update the local model using the corrected gradient:
      $$w_{k,\tau+1}^t = w_{k,\tau}^t - \eta_l \left( g_k(w_{k,\tau}^t) - c_k^t + c^t \right)$$
      Here, $\eta_l$ is the local learning rate. Notice how the local gradient $g_k$ is adjusted by the difference between the server control variate $c^t$ and the client's control variate $c_k^t$. This adjustment corrects for the client's local drift.
  - Compute the total model update direction: $\Delta w_k^t = w_{k,E}^t - w^t$.
  - Update the client control variate $c_k$. A common choice (Option II in the original SCAFFOLD paper) is:
    $$c_k^{t+1} = c_k^t - c^t + \frac{1}{E \eta_l} \left( w^t - w_{k,E}^t \right)$$
    This update reflects the average local gradient direction observed during the $E$ steps.
  - Compute the change in the client control variate: $\Delta c_k^t = c_k^{t+1} - c_k^t$.
  - Send $\Delta w_k^t$ and $\Delta c_k^t$ back to the server.
- Server Aggregation: The server receives updates $(\Delta w_k^t, \Delta c_k^t)$ from a set $S_t$ of participating clients.
  - Aggregate model updates (similar to FedAvg):
    $$\Delta w^t = \frac{1}{|S_t|} \sum_{k \in S_t} \Delta w_k^t$$
  - Update the global model:
    $$w^{t+1} = w^t + \eta_g \Delta w^t$$
    ($\eta_g$ is the server learning rate, often set to 1.)
  - Aggregate control variate updates:
    $$\Delta c^t = \frac{1}{|S_t|} \sum_{k \in S_t} \Delta c_k^t$$
  - Update the server control variate:
    $$c^{t+1} = c^t + \Delta c^t$$
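The following is a minimal NumPy sketch of one such round under the update rules above. It assumes flat parameter vectors, a caller-supplied gradient oracle `grad_fn(w, batch)` for each client's local loss $F_k$, and that every sampled client reports back; the function names and toy setup are illustrative and not taken from any particular federated learning library:

```python
import numpy as np

def client_update(w_global, c_global, c_k, grad_fn, batches, eta_l):
    """Run E local SGD steps with SCAFFOLD's drift correction.

    grad_fn(w, batch) is assumed to return a stochastic gradient of the
    client's local loss F_k at w. Returns (delta_w, delta_c, new_c_k).
    """
    E = len(batches)
    w = w_global.copy()
    for batch in batches:                          # tau = 0, ..., E-1
        g = grad_fn(w, batch)                      # local gradient g_k
        w -= eta_l * (g - c_k + c_global)          # corrected SGD step
    delta_w = w - w_global                         # total model update direction
    # Client control variate update (the rule shown in the text above):
    new_c_k = c_k - c_global + (w_global - w) / (E * eta_l)
    delta_c = new_c_k - c_k
    return delta_w, delta_c, new_c_k

def server_round(w, c, client_results, eta_g=1.0):
    """Aggregate (delta_w, delta_c) pairs from the participating clients."""
    delta_w = np.mean([dw for dw, _ in client_results], axis=0)
    delta_c = np.mean([dc for _, dc in client_results], axis=0)
    return w + eta_g * delta_w, c + delta_c
```

A toy driver with two clients whose quadratic losses $F_k(w) = \tfrac{1}{2}\lVert w - m_k \rVert^2$ have deliberately different minima $m_k$ shows how the pieces fit together:

```python
d, eta_l = 5, 0.1
w, c = np.full(d, 2.0), np.zeros(d)                # start away from the optimum
client_c = [np.zeros(d), np.zeros(d)]
minima = [np.ones(d), -np.ones(d)]                 # heterogeneous local optima

for t in range(50):
    results = []
    for k in range(2):
        grad_fn = lambda w_, _batch, m=minima[k]: w_ - m   # exact gradient of F_k
        dw, dc, client_c[k] = client_update(w, c, client_c[k], grad_fn,
                                            batches=[None] * 5, eta_l=eta_l)
        results.append((dw, dc))
    w, c = server_round(w, c, results)

# After the loop, w is close to the global optimum (the mean of the two local
# minima, i.e. the zero vector) rather than either client's individual optimum.
```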
Why Does Variance Reduction Work?
The key intuition is that the term $g_k(w_{k,\tau}^t) - c_k^t + c^t$ used in the client update is a better estimate of the global gradient $\nabla F(w_{k,\tau}^t)$ than the raw local gradient $g_k(w_{k,\tau}^t)$. By subtracting the estimated local direction ($c_k^t$) and adding the estimated global direction ($c^t$), SCAFFOLD effectively steers the client updates towards the global minimum, reducing the variance caused by disparate local data distributions.
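To see this informally, suppose the control variates track their ideal values, $c_k^t \approx \nabla F_k(w)$ and $c^t \approx \nabla F(w)$ (an idealized assumption; in practice the control variates are somewhat stale estimates). Then, in expectation over the minibatch $b$,
$$\mathbb{E}_b\!\left[\, g_k(w) - c_k^t + c^t \,\right] \approx \nabla F_k(w) - \nabla F_k(w) + \nabla F(w) = \nabla F(w),$$
so the client-specific component of the gradient cancels and every client takes steps along approximately the same global direction, regardless of how skewed its local data is.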
This variance reduction leads to more stable and typically faster convergence compared to FedAvg, especially when dealing with significant statistical heterogeneity (Non-IID data).
Figure: Convergence behavior of FedAvg, FedProx, and SCAFFOLD under simulated non-IID conditions. SCAFFOLD often achieves faster convergence and potentially higher final accuracy due to its variance reduction mechanism.
Implementation Considerations
Implementing SCAFFOLD introduces some overhead compared to FedAvg:
- State: Both the server and clients need to store their respective control variates ($c$ and $c_k$). These are typically the same size as the model parameters.
- Communication: Clients need to send both the model update ($\Delta w_k$) and the control variate update ($\Delta c_k$) to the server in each round. The server needs to broadcast the global control variate $c$ in addition to the global model $w$. This effectively doubles the communication cost per round compared to FedAvg, although the potential reduction in the number of rounds required for convergence can often compensate for this.
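As a rough, purely illustrative estimate (assuming uncompressed 32-bit floats): a model with 10 million parameters occupies about 40 MB, so a FedAvg client uploads roughly 40 MB per round, while a SCAFFOLD client uploads roughly 80 MB ($\Delta w_k$ plus $\Delta c_k$). If the variance reduction lets training reach the target accuracy in half as many rounds, the total traffic ends up about the same.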
Despite the increased state and communication requirements, SCAFFOLD's ability to effectively handle client drift makes it a valuable tool for improving the performance and reliability of federated learning systems operating on heterogeneous data. Its theoretical underpinnings provide stronger convergence guarantees than FedAvg under non-IID settings.