System heterogeneity presents a significant hurdle in federated learning, particularly when dealing with a diverse population of client devices. As introduced earlier in this chapter, clients often possess vastly different computational capabilities (CPU power, available memory), energy budgets (especially for mobile devices), and network bandwidths. A large, complex model trained collaboratively might perform well in aggregate, but it could be entirely unusable on resource-constrained clients. Deploying a single, monolithic global model is often impractical.
This section examines techniques specifically designed to tailor model complexity to individual client capabilities within the federated learning process: model pruning and adaptation. These methods allow participation from a wider range of devices and can lead to more efficient and personalized federated systems.
Federated Model Pruning
Model pruning aims to reduce the size and computational cost of deep learning models by removing redundant or less important parameters. In a federated setting, pruning can be applied strategically to accommodate device limitations.
What is Pruning?
At its core, pruning identifies and eliminates model parameters (weights, neurons, filters, or even layers) that contribute least to the model's performance. This results in smaller model footprints, reduced memory usage, lower inference latency, and potentially decreased energy consumption. The challenge is to remove parameters without significantly degrading accuracy.
There are two primary categories of pruning:
- Unstructured Pruning: This involves removing individual weights based on certain criteria, typically their magnitude (small-magnitude weights are considered less important). The resulting weight matrices become sparse, containing many zero values (a minimal code sketch follows this list).
- Pros: Can achieve high compression ratios with minimal accuracy loss.
- Cons: Sparse models often require specialized libraries (e.g., for sparse matrix multiplication) or hardware support to realize actual speedups. Standard hardware might not benefit much from unstructured sparsity. Communication of sparse updates can also be complex.
- Structured Pruning: This method removes entire structural elements of the model, such as filters in convolutional layers, channels, or even entire layers.
- Pros: Results in smaller, dense models that can be executed efficiently on standard hardware without specialized libraries. Easier to implement and deploy.
- Cons: Can be less fine-grained than unstructured pruning, potentially leading to a larger drop in accuracy for the same level of parameter reduction. Requires careful selection of structures to remove.
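To make the unstructured case concrete, here is a minimal sketch of magnitude-based weight pruning in PyTorch: a binary mask zeroes the smallest-magnitude weights of a single linear layer. The layer size, sparsity target, and helper name are illustrative assumptions rather than part of any particular FL framework.

```python
import torch
import torch.nn as nn

def magnitude_prune_(layer: nn.Linear, sparsity: float = 0.5) -> torch.Tensor:
    """Zero out the smallest-magnitude weights of `layer` in place; return the 0/1 mask."""
    with torch.no_grad():
        flat = layer.weight.abs().flatten()
        k = int(sparsity * flat.numel())                 # number of weights to prune
        if k == 0:
            return torch.ones_like(layer.weight)
        threshold = torch.kthvalue(flat, k).values       # k-th smallest magnitude
        mask = (layer.weight.abs() > threshold).float()  # keep weights above threshold
        layer.weight.mul_(mask)                          # unstructured: scattered zeros
    return mask

layer = nn.Linear(128, 64)
mask = magnitude_prune_(layer, sparsity=0.8)
print(f"nonzero weights kept: {int(mask.sum())} / {mask.numel()}")
```

In practice the mask is typically reapplied after each local optimizer step so that pruned weights stay at zero during training, and, as noted above, realizing actual speedups from this kind of sparsity still depends on sparse kernels or hardware support.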
Applying Pruning in Federated Learning
Integrating pruning into the FL workflow requires deciding when and where the pruning occurs:
- Server-Side Pruning: The central server prunes the aggregated global model before broadcasting it to clients for the next round. This is simpler to manage but applies a uniform pruning level to all clients, failing to adapt to individual device differences. It might prune too aggressively for capable devices or not enough for constrained ones.
- Client-Side Pruning: Each client receives the (potentially dense) global model and prunes it locally based on its own resource constraints before starting local training. Alternatively, clients might prune their calculated model updates before sending them back to the server (a minimal sketch of this follows the list). This allows for fine-grained adaptation but introduces complexity:
- How does the server aggregate updates from models with different structures (different sparsity masks or pruned architectures)? Averaging only the remaining, shared parameters is one approach, but it can be lossy.
- How is the pruning mask determined and communicated? Clients might need to report their capabilities or desired pruning level.
- Federated Pruning Algorithms: More sophisticated methods involve coordination. For example, clients might propose parameters to prune based on local data, and the server aggregates these proposals to make a global pruning decision. Some techniques try to learn a shared pruning mask across clients iteratively.
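As a deliberately simple instance of client-side pruning of updates, the sketch below keeps only the largest-magnitude entries of each update tensor before upload and shows how the server could rebuild dense tensors from them. The function names, dictionary format, and keep ratio are illustrative assumptions, not a prescribed protocol.

```python
import torch

def sparsify_update(delta: dict, keep_ratio: float = 0.1) -> dict:
    """Keep only the largest-magnitude fraction of each tensor in a model update."""
    sparse = {}
    for name, tensor in delta.items():
        flat = tensor.flatten()
        k = max(1, int(keep_ratio * flat.numel()))
        _, idx = torch.topk(flat.abs(), k)             # positions of the largest entries
        sparse[name] = (idx, flat[idx], tensor.shape)  # enough to rebuild a dense tensor
    return sparse

def densify_update(sparse: dict) -> dict:
    """Server-side: scatter (indices, values) back into zero tensors before aggregation."""
    dense = {}
    for name, (idx, values, shape) in sparse.items():
        flat = torch.zeros(shape).flatten()
        flat[idx] = values
        dense[name] = flat.reshape(shape)
    return dense
```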
Example: Magnitude-Based Filter Pruning (Structured)
A common structured pruning technique for CNNs is filter pruning. Filters with lower importance scores are removed. A simple importance score for a filter $F_i$ could be the $\ell_1$ norm of its weights (implemented in the code sketch after the workflow below):

$$\text{Importance}(F_i) = \sum_{w \in F_i} |w|$$
In an FL context:
- Server sends the current global model $W_g$.
- Client k computes filter importances based on $W_g$ or after some local training.
- Client k determines its target pruning ratio based on device constraints.
- Client k creates a pruned model $W_k'$ by removing the lowest-importance filters.
- Client k trains $W_k'$ locally.
- Client k sends the update $\Delta_k' = W_k' - W_g'$ (where $W_g'$ is the pruned version of $W_g$ using the same mask as $W_k'$) back to the server.
- Server needs a strategy to aggregate these potentially differently structured updates. A common approach is to average updates only for the filters/weights that were not pruned by any participating client in that round, or to use more advanced aggregation that accounts for the heterogeneous structures.
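The sketch below illustrates the client-side pruning step of this workflow: it removes the lowest-$\ell_1$-norm filters from a convolutional layer and adjusts the following layer to match. The helper name, the assumption of a plain Conv2d-to-Conv2d connection (groups=1, no intervening batch norm), and the returned `keep` indices standing in for the pruning mask are illustrative choices, not a prescribed algorithm.

```python
import torch
import torch.nn as nn

def prune_conv_filters(conv: nn.Conv2d, next_conv: nn.Conv2d, prune_ratio: float):
    """Drop the lowest-L1-norm filters of `conv` and the matching input channels
    of `next_conv`, returning new dense layers plus the indices that were kept."""
    # Importance(F_i) = sum of |w| over the filter's weights (its L1 norm).
    importance = conv.weight.detach().abs().sum(dim=(1, 2, 3))
    n_keep = max(1, int(conv.out_channels * (1.0 - prune_ratio)))
    keep = torch.topk(importance, n_keep).indices.sort().values

    pruned = nn.Conv2d(conv.in_channels, n_keep, conv.kernel_size,
                       stride=conv.stride, padding=conv.padding,
                       bias=conv.bias is not None)
    pruned.weight.data = conv.weight.data[keep].clone()
    if conv.bias is not None:
        pruned.bias.data = conv.bias.data[keep].clone()

    # The following layer must drop the corresponding input channels.
    pruned_next = nn.Conv2d(n_keep, next_conv.out_channels, next_conv.kernel_size,
                            stride=next_conv.stride, padding=next_conv.padding,
                            bias=next_conv.bias is not None)
    pruned_next.weight.data = next_conv.weight.data[:, keep].clone()
    if next_conv.bias is not None:
        pruned_next.bias.data = next_conv.bias.data.clone()

    return pruned, pruned_next, keep  # `keep` doubles as the mask reported to the server
```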
Model Adaptation Strategies
Beyond pruning, model adaptation techniques offer alternative ways to match model complexity with device capabilities.
Model Scaling and Sub-network Extraction
Instead of training a single fixed model, the idea is to train a flexible "supernetwork" that contains many sub-networks of varying sizes and complexities. Clients can then extract and fine-tune a sub-network that fits their resource budget.
- Once-for-All (OFA) Networks: A prominent example where a single large network is trained such that diverse sub-networks (obtained by varying depth, width, kernel size) can be extracted and perform well without requiring retraining.
- In FL: The server could coordinate the training of the OFA supernetwork. In each round, clients sample a sub-network according to their resources, perform local training on that sub-network, and send updates corresponding to the sampled architecture back. The server aggregates these updates back into the supernetwork parameters. This allows clients with different capabilities to contribute effectively to training a shared, scalable model representation. Aggregation requires careful mapping of sub-network updates back to the supernetwork weights.
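The following is a minimal sketch, in the spirit of width-scalable supernetworks, of how a client might extract a narrower sub-layer that shares the leading units of a supernetwork layer. Real OFA training also varies depth and kernel size and uses dedicated sampling and aggregation rules, so this helper is only an illustrative simplification.

```python
import torch
import torch.nn as nn

def extract_sub_layer(super_layer: nn.Linear, width_mult: float) -> nn.Linear:
    """Slice the leading output units of a supernetwork layer into a smaller sub-layer."""
    out_f = max(1, int(super_layer.out_features * width_mult))
    sub = nn.Linear(super_layer.in_features, out_f, bias=super_layer.bias is not None)
    sub.weight.data = super_layer.weight.data[:out_f].clone()
    if super_layer.bias is not None:
        sub.bias.data = super_layer.bias.data[:out_f].clone()
    return sub

# A constrained client might train with width_mult=0.25, a capable one with 1.0.
# After local training, the client's updated slice is copied (or averaged) back
# into the corresponding rows of the supernetwork weights on the server.
```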
Federated Knowledge Distillation (FKD)
Knowledge distillation involves training a smaller "student" model to mimic the behavior of a larger, pre-trained "teacher" model. In FL, the global model can act as the teacher.
- Process:
- The server trains a potentially large, high-accuracy global model $W_{\text{global}}$ (the teacher).
- The server sends $W_{\text{global}}$ (or just its architecture and the parameters needed for inference) to clients.
- Each client k defines a smaller "student" model architecture $W_{\text{student},k}$ suitable for its device.
- During local training, client k trains $W_{\text{student},k}$ not only on its local data $D_k$ but also uses the outputs (e.g., logits) of the teacher model $W_{\text{global}}$ on the same local data as a regularization target. The loss function might look like:

$$\mathcal{L}_k = \alpha \, \mathcal{L}_{CE}(W_{\text{student},k}, D_k) + (1 - \alpha) \, \mathcal{L}_{KD}\big(W_{\text{global}}(x),\, W_{\text{student},k}(x)\big) \quad \text{for } x \in D_k$$

where $\mathcal{L}_{CE}$ is the standard cross-entropy loss and $\mathcal{L}_{KD}$ is the distillation loss (e.g., Kullback-Leibler divergence between teacher and student output distributions); a minimal code sketch of this combined loss follows below.
- Advantages: Allows clients to train models tailored to their device constraints while benefiting from the knowledge encoded in the powerful global model.
- Challenges: Requires clients to run inference with the teacher model, which might still be too large for some devices (though only inference is needed, not backpropagation). Sharing logits can also have privacy implications, although generally considered less sensitive than gradients. Aggregation typically involves updating the global teacher model based on client contributions, which might be indirect (e.g., clients sending their improved student models, and the server using them to refine the teacher).
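Here is a minimal sketch of the combined local loss from the equation above, assuming a classification task where the teacher's logits provide soft targets. The `alpha` weighting and `temperature` are illustrative hyperparameters, not values fixed by any particular FKD method.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels,
                      alpha: float = 0.5, temperature: float = 2.0):
    """alpha * cross-entropy on local labels + (1 - alpha) * KL to the teacher's outputs."""
    ce = F.cross_entropy(student_logits, labels)
    kd = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)     # standard temperature scaling of the distillation term
    return alpha * ce + (1.0 - alpha) * kd

# Client-side local step (the teacher runs inference only, no gradients needed):
# with torch.no_grad():
#     teacher_logits = teacher(x)
# loss = distillation_loss(student(x), teacher_logits, y)
```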
Adaptive Layer/Computation Skipping
Clients can dynamically skip parts of the model during execution based on available resources or real-time constraints.
- Layer Dropping: Randomly dropping layers during training acts as regularization. During inference, clients could deterministically drop specific layers (e.g., later, more complex layers) to meet latency budgets.
- Early Exits: Designing models with intermediate classifiers. Clients can exit computation early at these classifiers if a confident prediction is made or if resource limits are reached (a minimal sketch follows this list).
- In FL: The global model might be designed with droppable layers or early exits. Clients decide locally which layers to use or where to exit based on their constraints. Aggregation needs to handle updates coming from clients that used different computational paths.
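Below is a minimal sketch of the early-exit idea: a toy model with an exit head after every block, where a client stops as soon as the prediction is confident enough or its per-device block budget (`max_blocks`) is exhausted. The architecture, threshold, and batch-level exit rule are assumptions for illustration.

```python
import torch
import torch.nn as nn

class EarlyExitNet(nn.Module):
    def __init__(self, hidden: int = 64, num_classes: int = 10, n_blocks: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU()) for _ in range(n_blocks)]
        )
        # One lightweight classifier ("exit head") after every block.
        self.exits = nn.ModuleList([nn.Linear(hidden, num_classes) for _ in range(n_blocks)])

    def forward(self, x, confidence_threshold: float = 0.9, max_blocks: int = 4):
        for i, (block, exit_head) in enumerate(zip(self.blocks, self.exits)):
            x = block(x)
            logits = exit_head(x)
            confidence = torch.softmax(logits, dim=-1).max(dim=-1).values
            # Exit early if the whole batch is confident, or if the block budget is hit.
            if confidence.min() >= confidence_threshold or (i + 1) >= max_blocks:
                return logits, i + 1
        return logits, len(self.blocks)

# A constrained client might call model(x, max_blocks=2) to cap latency and energy use.
model = EarlyExitNet()
logits, blocks_used = model(torch.randn(8, 64), max_blocks=2)
```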
Integrating Adaptation into the FL System
Successfully implementing pruning and adaptation requires system-level support:
- Capability Reporting: Clients need a mechanism to report their resource constraints (e.g., memory, CPU FLOPS, target latency) to the server. This could happen during an initial handshake or periodically.
- Server Orchestration: The server must manage the potentially heterogeneous model configurations. This might involve:
- Maintaining different pruned versions of the global model.
- Managing the parameters of a supernetwork and mapping sub-network updates.
- Coordinating knowledge distillation processes.
- Heterogeneous Aggregation: The core FL challenge is amplified. Simple averaging might not work if clients have structurally different models or updates. Techniques include:
- Averaging only shared parameters.
- Weighted averaging based on the "size" or contribution of the client's model.
- Using knowledge-based aggregation where client outputs or model representations are aggregated rather than raw parameters.
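As a sketch of the first of these options, the helper below averages each parameter only over the clients whose masks indicate they kept (did not prune) it. The dictionary-of-tensors update format and zero-filled pruned entries are assumptions for illustration.

```python
import torch

def mask_aware_average(updates, masks):
    """Average client updates entry-wise over the clients that kept each parameter.

    updates: list of dicts {param_name: delta_tensor} (zeros where pruned)
    masks:   list of dicts {param_name: 0/1 tensor of the same shape}
    """
    aggregated = {}
    for name in updates[0]:
        delta_sum = torch.zeros_like(updates[0][name])
        mask_sum = torch.zeros_like(updates[0][name])
        for delta, mask in zip(updates, masks):
            delta_sum += delta[name] * mask[name]
            mask_sum += mask[name]
        # Parameters pruned by every participating client receive no update this round.
        aggregated[name] = delta_sum / mask_sum.clamp(min=1.0)
    return aggregated
```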
Trade-offs and Considerations
Implementing these techniques involves balancing several factors:
- Accuracy vs. Efficiency: There is almost always a trade-off. Aggressive pruning or very small adapted models reduce resource usage but typically lower the maximum achievable accuracy; accuracy tends to degrade as model size shrinks, so the operating point must be balanced against application needs and device constraints.
- System Complexity: These methods add significant complexity to the FL system design, implementation, and debugging compared to standard FedAvg with a homogeneous model.
- Communication: While the goal is often smaller models on device, some techniques might introduce new communication overhead (e.g., sending capability profiles, exchanging knowledge distillation targets, communicating pruning masks or architecture choices).
- Fairness: Ensure that clients with severely constrained devices are not left behind. Their adapted models should still provide meaningful utility, and their contributions should be appropriately incorporated into the global learning process.
By carefully selecting and implementing model pruning and adaptation strategies, federated learning systems can become more inclusive of diverse hardware, leading to more practical and efficient deployments in real-world scenarios characterized by system heterogeneity.