Federated Learning (FL) represents a significant shift from traditional centralized machine learning. Instead of collecting vast amounts of user data into a central repository for model training, FL enables collaborative model building directly on distributed devices or within separate organizational silos, all while keeping the raw data localized. This approach intrinsically enhances data privacy, as sensitive information never leaves the client's control. Before we examine advanced techniques, let's solidify our understanding of the core principles that underpin this distributed learning paradigm.
The Federated Learning Workflow
The most common and foundational approach to FL is Federated Averaging (FedAvg). It follows an iterative process coordinated by a central server, typically involving these steps:
- Initialization: The server starts with an initial global model, often randomly initialized or pre-trained on a public dataset.
- Client Selection: In each communication round $t$, the server selects a subset of available clients (e.g., mobile devices, hospitals) to participate in training. Selection strategies can vary, often involving random sampling.
- Model Distribution: The server transmits the current global model parameters, $w_t$, to the selected clients.
- Local Training: Each selected client $k$ updates the received model using its local data $D_k$. This typically involves running multiple steps of gradient descent (or a variant) on its local objective function $F_k(w)$, derived from its private data. Let $w_{t+1}^k$ denote the updated local model parameters on client $k$.
- Model Update Transmission: Clients send their computed updates back to the server. This could be the full updated model parameters $w_{t+1}^k$, or more commonly, the difference $\Delta_k = w_{t+1}^k - w_t$, or the gradients $\nabla F_k(w_t)$. This transmission is a potential communication bottleneck and privacy risk area.
- Aggregation: The server aggregates the received updates from the participating clients to compute a new global model $w_{t+1}$. In FedAvg, this is typically a weighted average based on the amount of data each client used for training:
$$w_{t+1} = \sum_{k \in S_t} \frac{n_k}{n} \, w_{t+1}^k$$
where $S_t$ is the set of selected clients in round $t$, $n_k = |D_k|$ is the number of data points on client $k$, and $n = \sum_{k \in S_t} n_k$ is the total number of data points across the selected clients. Alternatively, if updates $\Delta_k$ are sent:
$$w_{t+1} = w_t + \sum_{k \in S_t} \frac{n_k}{n} \, \Delta_k$$
- Iteration: The process repeats from Step 2 for a predefined number of communication rounds or until a convergence criterion is met.
This cyclical process allows the global model to learn from the collective knowledge embedded within the distributed datasets without centralizing them.
Figure: The standard synchronous Federated Learning cycle involving the server and a representative client.
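To make the workflow concrete, here is a minimal, self-contained sketch of the FedAvg loop in Python with NumPy. It is illustrative only: the model is a flat parameter vector, the local objective is a toy least-squares loss, and the helper names (client_update, fedavg_round) are our own rather than part of any FL framework.

```python
import numpy as np

def client_update(global_weights, local_data, lr=0.1, local_epochs=5):
    """Step 4 (local training): a few full-batch gradient steps on the client's data.
    A least-squares loss stands in for the local objective F_k."""
    w = global_weights.copy()
    X, y = local_data
    for _ in range(local_epochs):
        grad = X.T @ (X @ w - y) / len(y)   # gradient of (1/2n_k) * ||Xw - y||^2
        w -= lr * grad
    return w

def fedavg_round(global_weights, selected_clients):
    """Steps 3-6: distribute the model, train locally, aggregate by weighted average."""
    updates, sizes = [], []
    for data in selected_clients:
        updates.append(client_update(global_weights, data))
        sizes.append(len(data[1]))          # n_k = number of local examples
    n = sum(sizes)
    # w_{t+1} = sum_{k in S_t} (n_k / n) * w_{t+1}^k
    return sum((n_k / n) * w_k for n_k, w_k in zip(sizes, updates))

# Toy simulation: 10 clients with synthetic, differently sized local datasets.
rng = np.random.default_rng(0)
dim, true_w = 5, np.arange(5, dtype=float)
clients = []
for _ in range(10):
    X = rng.normal(size=(rng.integers(20, 100), dim))
    clients.append((X, X @ true_w + 0.1 * rng.normal(size=len(X))))

w = np.zeros(dim)                                                    # Step 1: initialization
for t in range(50):                                                  # Step 7: iterate over rounds
    participants = rng.choice(len(clients), size=4, replace=False)   # Step 2: client selection
    w = fedavg_round(w, [clients[i] for i in participants])          # Steps 3-6
print("Approximate recovered weights:", np.round(w, 2))
```

Weighting each client's contribution by $n_k$ in fedavg_round mirrors the aggregation formula above, so clients with more data pull the global model proportionally harder.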
Key Entities: Server and Clients
The FL ecosystem primarily consists of two types of entities:
- Clients: These are the devices or organizations (e.g., smartphones, laptops, hospitals, banks) holding the local data. They possess computational resources to perform local model training based on instructions from the server. Their participation might be intermittent, and their resources (CPU, network bandwidth, data size) can vary significantly, leading to systems heterogeneity.
- Server: This central entity orchestrates the learning process. It initializes the model, selects clients, distributes the model, aggregates updates, and maintains the global state. While the server does not access raw client data, its role is critical for coordination and convergence. The server itself can be a potential point of failure or attack.
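One way to read this division of labor is as a pair of minimal interfaces. The class and method names below are hypothetical, sketching the contract between the two roles rather than any particular FL framework's API:

```python
from abc import ABC, abstractmethod
from typing import Sequence
import numpy as np

class Client(ABC):
    """Holds private local data; exposes only model updates, never raw data."""

    @abstractmethod
    def num_examples(self) -> int:
        """n_k, used by the server to weight this client's update."""

    @abstractmethod
    def local_update(self, global_weights: np.ndarray) -> np.ndarray:
        """Train on local data starting from the global model; return updated weights."""

class Server(ABC):
    """Orchestrates rounds and maintains the global state; sees only parameters."""

    @abstractmethod
    def select_clients(self, available: Sequence[Client], round_t: int) -> Sequence[Client]:
        """Choose the participants S_t for round t (e.g., uniform random sampling)."""

    @abstractmethod
    def aggregate(self, updates: Sequence[np.ndarray], sizes: Sequence[int]) -> np.ndarray:
        """Combine client updates into the next global model (e.g., FedAvg weighting)."""
```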
The Optimization Goal
As mentioned in the chapter introduction, the overarching goal is typically to minimize a global objective function $F(w)$ which represents the aggregate performance across all $N$ clients:
$$\min_{w} F(w) \quad \text{where} \quad F(w) = \sum_{k=1}^{N} p_k F_k(w)$$
Here, $F_k(w) = L(w; D_k)$ is the local loss function for client $k$ computed over its local dataset $D_k$, and $p_k$ is a weight assigned to client $k$, usually $p_k = n_k / \sum_{j=1}^{N} n_j$, where $n_k = |D_k|$. This formulation highlights that we aim for a model that performs well on average across the data distribution represented by all clients.
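To tie the formula to code, the small sketch below (again using an illustrative least-squares local loss and hypothetical helper names) evaluates $F(w)$ as the data-size-weighted sum of the local losses $F_k(w)$:

```python
import numpy as np

def local_loss(w, data):
    """F_k(w): mean squared error on client k's local dataset (illustrative choice of L)."""
    X, y = data
    return 0.5 * np.mean((X @ w - y) ** 2)

def global_objective(w, clients):
    """F(w) = sum_k p_k * F_k(w), with p_k = n_k / sum_j n_j."""
    sizes = np.array([len(y) for _, y in clients], dtype=float)
    p = sizes / sizes.sum()
    return sum(p_k * local_loss(w, data) for p_k, data in zip(p, clients))
```

Note that no single party can evaluate $F(w)$ directly in a real deployment, since the local losses depend on data that never leaves the clients; the server only ever observes the aggregated updates.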
Contrasting with Other Learning Paradigms
It's important to distinguish FL from other approaches:
- Centralized Learning: Requires pooling all data in one location, which FL explicitly avoids due to privacy, communication, or regulatory constraints.
- Classical Distributed Learning: Often assumes data is distributed across nodes in a cluster (e.g., using parameter servers), but typically within a single trusted domain, with more homogeneous node capabilities and often IID (Independent and Identically Distributed) data partitions. FL specifically targets scenarios with higher heterogeneity, potentially untrusted participants, and non-IID data.
This recap of the basic FL principles and the standard FedAvg workflow sets the stage. While conceptually straightforward, this foundational model operates under several simplifying assumptions. Real-world federated environments introduce significant challenges related to data and systems heterogeneity, privacy vulnerabilities, communication costs, and potential adversarial behavior, motivating the advanced techniques we will explore throughout this course.