Having explored the theoretical underpinnings and architectural considerations of federated learning systems, it's time to translate that knowledge into practice. This section guides you through setting up a fundamental FL simulation using a common framework. This hands-on exercise bridges the gap between the concepts discussed earlier, such as system components and framework roles, and the actual code required to run a federated process.
We'll use the Flower (flwr) framework for this demonstration. Flower is known for its flexibility, allowing integration with various machine learning libraries like TensorFlow and PyTorch, and its relatively straightforward API for defining client and server logic. It abstracts away much of the low-level communication handling, letting you focus on the FL strategy and client-side ML tasks.
Before starting, ensure you have Python installed along with Flower and a deep learning library. Flower can be installed using pip. For this example, we'll also need TensorFlow (or PyTorch, adapting the client code accordingly) and NumPy.
pip install flwr[simulation] tensorflow numpy
# Or, if using PyTorch:
# pip install flwr[simulation] torch torchvision numpy
We assume you have a working Python environment (version 3.8 or higher recommended) and are comfortable with basic ML model definition and training in your chosen deep learning framework.
Our goal is to simulate training a simple Convolutional Neural Network (CNN) on the MNIST dataset distributed across multiple virtual clients. We will implement:

- A Flower client based on NumPyClient that wraps local model training and evaluation.
- A server launched through the start_simulation function with a basic FedAvg strategy.

For simplicity in this initial setup, we'll assume an IID (Independent and Identically Distributed) data partitioning, where each client receives a random subset of the global dataset. Techniques for handling Non-IID data were covered in Chapter 4 and can be incorporated as a next step.
Implementing the Client (NumPyClient)

Flower's NumPyClient provides a convenient abstraction. You implement methods that receive and return model parameters as lists of NumPy arrays, allowing easy integration with frameworks like TensorFlow or PyTorch.
First, let's define a simple CNN model using TensorFlow Keras (adapt if using PyTorch).
import tensorflow as tf

def create_simple_cnn():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(32, (5, 5), activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(64, (5, 5), activation='relu'),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model
# Assume MNIST data (x_train, y_train), (x_test, y_test) is loaded and preprocessed
# (e.g., normalized, reshaped to (num_samples, 28, 28, 1))
# For simulation, we need to partition this data among clients.
# Let's assume a function `load_partition(client_id, num_clients)` exists
# that returns (x_train_cid, y_train_cid), (x_test_cid, y_test_cid) for a client.
Now, implement the NumPyClient subclass:
import flwr as fl
import numpy as np

# Assume create_simple_cnn and load_partition are defined as above

class MNISTClient(fl.client.NumPyClient):
    def __init__(self, client_id, num_clients):
        self.client_id = client_id
        self.num_clients = num_clients
        self.model = create_simple_cnn()
        # Load this client's specific partition
        (self.x_train, self.y_train), (self.x_test, self.y_test) = load_partition(
            self.client_id, self.num_clients
        )

    def get_parameters(self, config):
        # Return model weights as a list of NumPy arrays
        return self.model.get_weights()

    def set_parameters(self, parameters):
        # Update model weights
        self.model.set_weights(parameters)

    def fit(self, parameters, config):
        # Update model with received parameters
        self.set_parameters(parameters)
        # Read training configuration from the `config` dict if provided
        local_epochs = config.get("local_epochs", 1)
        batch_size = config.get("batch_size", 32)
        # Train the model on local data
        history = self.model.fit(
            self.x_train, self.y_train,
            epochs=local_epochs, batch_size=batch_size, verbose=0  # Set verbose=2 for detailed logs
        )
        # Return updated weights, number of training examples, and optional metrics
        # (use the last epoch's values in case local_epochs > 1)
        results = {
            "loss": history.history["loss"][-1],
            "accuracy": history.history["accuracy"][-1],
        }
        return self.get_parameters(config={}), len(self.x_train), results

    def evaluate(self, parameters, config):
        # Update model with received parameters
        self.set_parameters(parameters)
        # Evaluate the model on local test data
        loss, accuracy = self.model.evaluate(self.x_test, self.y_test, verbose=0)
        # Return loss, number of evaluation examples, and metrics
        return loss, len(self.x_test), {"accuracy": accuracy}

# Function to instantiate clients based on ID
def client_fn(cid: str) -> fl.client.NumPyClient:
    # cid is a string; convert to int for the partitioning logic
    client_id = int(cid)
    num_total_clients = 10  # Example: total number of clients in the pool
    return MNISTClient(client_id=client_id, num_clients=num_total_clients)
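Before handing the client to the simulation engine, it can help to exercise it directly. A short optional sanity check, assuming load_partition (defined later in this section) is in scope:

# Optional: exercise one client outside Flower to catch shape or data bugs early
client = MNISTClient(client_id=0, num_clients=10)
initial_params = client.get_parameters(config={})
updated_params, num_examples, metrics = client.fit(initial_params, config={})
print(num_examples, metrics)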
Note: The load_partition function is crucial for simulating data heterogeneity. For a basic IID setup, it might simply divide the shuffled MNIST dataset equally. For Non-IID simulations (as discussed in Chapter 4), it would implement more complex partitioning, like distributing data based on digit labels, as in the sketch below.
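To make the Non-IID idea concrete, here is a minimal sketch of such a label-skew partitioner. The function name load_noniid_partition, the shard scheme, and the shards_per_client parameter are illustrative additions, not part of the example script:

import numpy as np
import tensorflow as tf

def load_noniid_partition(client_id, num_clients, shards_per_client=2):
    # Hypothetical label-skew partition: sort examples by digit label,
    # split into shards, and give each client a few shards so it sees
    # only a couple of digit classes (a common Non-IID simulation scheme).
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    x_train = np.expand_dims(x_train.astype("float32") / 255.0, -1)

    order = np.argsort(y_train)              # group examples by label
    x_train, y_train = x_train[order], y_train[order]

    num_shards = num_clients * shards_per_client
    shard_size = len(x_train) // num_shards
    rng = np.random.default_rng(42)          # fixed seed: identical split on every call
    shard_assignment = rng.permutation(num_shards).reshape(num_clients, shards_per_client)

    # Collect the indices belonging to this client's shards
    idx = np.concatenate(
        [np.arange(s * shard_size, (s + 1) * shard_size) for s in shard_assignment[client_id]]
    )
    return x_train[idx], y_train[idx]

This sketch returns only training data; a drop-in replacement for load_partition would split a test partition analogously.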
Setting up the Server and Running the Simulation (start_simulation)

The server orchestrates the FL process. We define a strategy (like FedAvg) and use start_simulation to run the process with our client function.
import flwr as fl

# Aggregate client-reported evaluation metrics, weighted by example counts.
# Without such a function, FedAvg leaves history.metrics_distributed empty.
def weighted_average(metrics):
    total_examples = sum(num_examples for num_examples, _ in metrics)
    accuracy = sum(num_examples * m["accuracy"] for num_examples, m in metrics) / total_examples
    return {"accuracy": accuracy}

# Define an aggregation strategy (e.g., FedAvg)
# We can customize FedAvg, for example, setting minimum clients for training/evaluation
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,           # Sample 100% of available clients for training
    min_fit_clients=5,          # Minimum number of clients to wait for in federated training
    fraction_evaluate=0.5,      # Sample 50% of available clients for evaluation
    min_evaluate_clients=3,     # Minimum number of clients for federated evaluation
    min_available_clients=5,    # Minimum number of clients available before a round starts
    evaluate_metrics_aggregation_fn=weighted_average,  # Aggregate client metrics
    # on_fit_config_fn=get_on_fit_config_fn(),    # Optional: per-round client config (defined below)
    # evaluate_fn=get_evaluate_fn(server_model),  # Optional: centralized evaluation on the server
)
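The evaluate_fn mentioned in the comments enables centralized evaluation on the server after each round, in addition to federated evaluation on clients. A minimal sketch, assuming the server holds out a preprocessed MNIST test set; the helper name get_evaluate_fn and its (model, x_test, y_test) arguments are illustrative:

def get_evaluate_fn(model, x_test, y_test):
    # Returns a function the strategy calls after each round with the
    # current global weights, evaluating them on the server-held test set.
    def evaluate(server_round, parameters, config):
        model.set_weights(parameters)  # Load the aggregated global weights
        loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
        return loss, {"accuracy": accuracy}
    return evaluate

# Usage (illustrative): pass it when constructing the strategy, e.g.
# evaluate_fn=get_evaluate_fn(create_simple_cnn(), x_test, y_test)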
# Server-side configuration function for clients (optional).
# Pass it to the strategy as FedAvg(on_fit_config_fn=get_on_fit_config_fn()).
def get_on_fit_config_fn():
    def fit_config(server_round: int):
        # Pass round-specific configuration to clients
        config = {
            "server_round": server_round,
            "local_epochs": 2,  # Example: train for 2 epochs locally
            "batch_size": 32,
        }
        return config
    return fit_config
# Start the simulation
NUM_ROUNDS = 5
TOTAL_CLIENTS = 10

history = fl.simulation.start_simulation(
    client_fn=client_fn,                                    # Function to create clients
    num_clients=TOTAL_CLIENTS,                              # Total number of clients available
    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),   # Number of rounds
    strategy=strategy,                                      # Aggregation strategy
    client_resources={"num_cpus": 1, "num_gpus": 0.0},      # Resources per client (adjust if using GPUs)
)
# The 'history' object contains metrics collected during simulation
print("Simulation finished.")
print("History (losses distributed):", history.losses_distributed)
print("History (metrics distributed):", history.metrics_distributed)
# Access centralized metrics if an evaluate_fn was provided on the server strategy
# print("History (losses centralized):", history.losses_centralized)
# print("History (metrics centralized):", history.metrics_centralized)
A simple IID partitioning function, load_partition, might look like this:
import numpy as np
import tensorflow as tf

def load_partition(client_id, num_clients):
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    # Preprocess: normalize and reshape to (num_samples, 28, 28, 1)
    x_train = x_train.astype("float32") / 255.0
    x_test = x_test.astype("float32") / 255.0
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)

    # Simple IID partition: divide into equal contiguous slices.
    # If shuffling, use a fixed seed so every client applies the same permutation:
    # np.random.seed(42)
    # shuffle_indices = np.random.permutation(len(x_train))
    # x_train, y_train = x_train[shuffle_indices], y_train[shuffle_indices]
    partition_size_train = len(x_train) // num_clients
    start_train = client_id * partition_size_train
    end_train = start_train + partition_size_train
    x_train_cid, y_train_cid = x_train[start_train:end_train], y_train[start_train:end_train]

    partition_size_test = len(x_test) // num_clients
    start_test = client_id * partition_size_test
    end_test = start_test + partition_size_test
    x_test_cid, y_test_cid = x_test[start_test:end_test], y_test[start_test:end_test]

    return (x_train_cid, y_train_cid), (x_test_cid, y_test_cid)
Important: This partitioning is basic. Real-world scenarios often involve significantly more complex Non-IID distributions, which require careful handling during simulation setup (refer to Chapter 4).
To run this, save the client implementation, server setup, and data loading logic into a single Python script (e.g., run_simulation.py). Then execute it from your terminal:
python run_simulation.py
You should see output logs from Flower indicating the start and end of each federation round, client training (fit), and client evaluation (evaluate). The server will aggregate the results according to the FedAvg strategy.
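Concretely, FedAvg forms the new global weights as an example-count-weighted average of the client weights. This tiny illustrative snippet (fedavg_layer is a name introduced here, not part of Flower) shows the idea for a single weight tensor:

import numpy as np

# Schematic of FedAvg's weighted parameter averaging for one weight tensor.
# `updates` is a list of (num_examples, layer_weights) pairs from clients.
def fedavg_layer(updates):
    total_examples = sum(n for n, _ in updates)
    return sum(n * w for n, w in updates) / total_examples

# Example: two clients with 100 and 300 examples
layer = fedavg_layer([(100, np.ones((2, 2))), (300, np.zeros((2, 2)))])
print(layer)  # 0.25 everywhere: the larger client dominates the average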
The history object returned by start_simulation contains valuable information. history.losses_distributed shows the average loss reported by clients during the evaluation phases across rounds, and history.metrics_distributed contains aggregated metrics such as accuracy (populated because our strategy supplies an evaluate_metrics_aggregation_fn).
You can use this data to plot the learning progress:
import matplotlib.pyplot as plt
# Example: Plotting distributed evaluation accuracy
rounds = [r for r, _ in history.metrics_distributed["accuracy"]]
accuracies = [acc for _, acc in history.metrics_distributed["accuracy"]]
plt.figure(figsize=(8, 5))
plt.plot(rounds, accuracies, marker='o', color='#228be6') # Blue color
plt.title("Federated Evaluation Accuracy")
plt.xlabel("Federation Round")
plt.ylabel("Accuracy")
plt.grid(True)
plt.xticks(rounds)
plt.ylim(0, 1) # Accuracy ranges from 0 to 1
plt.show()
[Chart: Example plot showing typical improvement in distributed evaluation accuracy across federated learning rounds.]
This basic simulation serves as a starting point. Using the Flower framework and the concepts from previous chapters, you can extend this setup to explore more advanced scenarios:

- Aggregation: replace FedAvg with strategies like FedProx or SCAFFOLD (Chapter 2). Flower allows implementing custom strategies (see the sketch below).
- Privacy: add privacy mechanisms inside the client's fit method or use privacy-preserving strategies (Chapter 3).
- Heterogeneity: adjust client_resources or simulate varying computation times within the fit method (Chapter 4).

This practical exercise demonstrates how FL frameworks simplify the process of simulating complex distributed learning systems, enabling rapid prototyping and evaluation of the advanced techniques covered in this course.
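As a concrete taste of that extensibility, here is a minimal sketch of subclassing FedAvg. The class name LoggingFedAvg is illustrative; real custom strategies (for example, server-side pieces of more advanced aggregation schemes) override the same hooks with more substantive logic:

import flwr as fl

class LoggingFedAvg(fl.server.strategy.FedAvg):
    # Hypothetical minimal custom strategy: keep FedAvg's aggregation,
    # but log how many client results arrived each round.
    def aggregate_fit(self, server_round, results, failures):
        print(f"Round {server_round}: aggregating {len(results)} client updates "
              f"({len(failures)} failures)")
        return super().aggregate_fit(server_round, results, failures)

# Drop-in replacement for the FedAvg instance used earlier, e.g.:
# strategy = LoggingFedAvg(fraction_fit=1.0, min_fit_clients=5, ...)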