PyTorch aggregation methods in Fed-BioMed¶
Difficulty level: advanced
Introduction¶
This tutorial focuses on how to deal with heterogeneous datasets by changing the Aggregator of an Experiment. Fed-BioMed provides different methods for aggregation. Selecting an appropriate aggregation method can be critical when confronted with unbalanced or heterogeneous datasets.
Aggregators provide a way to merge local models sent by Nodes into a global, more generalized model. Please note that designing Node sampling Strategies can also help when working on heterogeneous datasets.
For more information about Aggregator objects in Fed-BioMed, and on how to create your own Aggregator, please see Aggregators in the User Guide.
MedNIST Dataset¶
For this tutorial, we will be using the MedNIST dataset. MedNIST is a collection of 2-D grayscale medical images, gathered from several sets from TCIA, the RSNA Bone Age Challenge, and the NIH Chest X-ray dataset. The dataset is kindly made available by Dr. Bradley J. Erickson M.D., Ph.D. (Department of Radiology, Mayo Clinic) under the Creative Commons CC BY-SA 4.0 license and is distributed by MONAI for teaching and benchmarking simple deep-learning pipelines. For more information regarding the dataset, please see MedNIST Dataset.
Before you start¶
Make sure that you have configured your nodes for training. First, create a node via:
fedbiomed node --path path/to/your/node create
After creation, add the MedNIST dataset:
fedbiomed node dataset add
This command will give you a menu with dataset options similar to below:
1) csv
2) default
3) mednist
4) images
5) medical-folder
6) custom
Choose the 3rd option (mednist) to add the MedNIST dataset. You can select y to add it with the default tags. For the path, select the folder where you want to download (or have already downloaded) the dataset.
Note: tags are important in Fed-BioMed. They are used as identifiers to select the datasets on the nodes that will be used for training. Make sure that you use the same tags when adding a dataset and when defining an Experiment (as shown below).
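For instance, assuming the default MedNIST tags were kept when adding the dataset, the same tags are later passed to the Experiment (illustration only; the full Experiment definition is given below):
tags = ['#MEDNIST', '#dataset']  # must match the tags registered on the node
# later used when defining the Experiment, e.g. exp.set_tags(tags=tags)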
For more information regarding Node Configuration, please refer to the User Guide.
1. Defining an Experiment using FedAverage Aggregator¶
In this example, we reuse the TrainingPlan that was defined in the previous MedNIST tutorial. It uses a DenseNet-121 model, where only the classifier is changed in order to adapt it to our task. We normalize the MedNIST images before feeding them to our model.
The only change from the previous tutorial is the Aggregator. We are going to demonstrate FedAverage, FedProx and Scaffold aggregation, starting with FedAverage. FedAverage (FedAvg) was introduced by McMahan et al. as the first aggregation method in the Federated Learning literature. It computes a weighted average of the Nodes' local model parameters in order to obtain a global model.
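With \( n_k \) samples on node \( k \) and \( n = \sum_k n_k \), the aggregation step can be written as:
$$ w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n} \, w^{(k)}_{t+1} $$
where \( w^{(k)}_{t+1} \) denotes the local model parameters sent by node \( k \) at round \( t \). The TrainingPlan is defined as follows: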
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.datamanager import DataManager
from fedbiomed.common.dataset import MedNistDataset
import torch
import torch.nn as nn
class MyTrainingPlan(TorchTrainingPlan):
def init_model(self, model_args):
model = models.densenet121(weights=None) # here model coefficients are set to random weights
# add the classifier
num_classes = model_args['num_classes']
num_ftrs = model.classifier.in_features
model.classifier= nn.Sequential(
nn.Linear(num_ftrs, 512),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(512, num_classes)
)
return model
def init_dependencies(self):
return [
"from torchvision import transforms, models",
"import torch.optim as optim",
"from torchvision.models import densenet121",
"from fedbiomed.common.dataset import MedNistDataset"
]
def init_optimizer(self, optimizer_args):
return optim.Adam(self.model().parameters(), lr=optimizer_args["lr"])
def training_data(self):
# Transform images and do data augmentation
preprocess = transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
target_transform = transforms.Lambda(lambda y: y.long())
train_data = MedNistDataset(transform = preprocess, target_transform=target_transform)
train_kwargs = { 'shuffle': True}
return DataManager(dataset=train_data, **train_kwargs)
def training_step(self, data, target):
output = self.model().forward(data)
loss_func = nn.CrossEntropyLoss()
loss = loss_func(output, target)
return loss
We define below the Experiment parameters to be used with vanilla FedAverage:
training_args = {
'loader_args': {
'batch_size': 32,
},
'random_seed': 1234,
'optimizer_args': {'lr': 1e-3},
'epochs': 1,
'dry_run': False,
    'num_updates': 50, # fast pass for development: each round only uses (num_updates * batch_size) samples
}
model_args = {
'num_classes': 6, # adapt this number to the number of classes in your dataset
}
We then import the FedAverage Aggregator from Fed-BioMed's aggregators:
from fedbiomed.researcher.federated_workflows import Experiment
from fedbiomed.researcher.aggregators import FedAverage
from fedbiomed.researcher.strategies.default_strategy import DefaultStrategy
tags = ['#MEDNIST', '#dataset']
rounds = 3
exp_fed_avg = Experiment()
exp_fed_avg.set_model_args(model_args=model_args)
exp_fed_avg.set_training_args(training_args=training_args)
exp_fed_avg.set_training_plan_class(training_plan_class=MyTrainingPlan)
exp_fed_avg.set_tags(tags = tags)
exp_fed_avg.set_training_data(training_data=None, from_tags=True)
exp_fed_avg.set_aggregator(aggregator=FedAverage())
exp_fed_avg.set_strategy(node_selection_strategy=DefaultStrategy())
exp_fed_avg.set_round_limit(rounds)
exp_fed_avg.set_tensorboard(True)
Activate Tensorboard
%load_ext tensorboard
fedavg_tensorboard_dir = exp_fed_avg.tensorboard_results_path
%tensorboard --logdir {fedavg_tensorboard_dir}
exp_fed_avg.run(increase=True)
Save trained model to file
exp_fed_avg.training_plan().export_model('./trained_model')
2. Defining an Experiment using FedProx Aggregator¶
The second method we demonstrate is FedProx. FedProx is a modification of FedAvg that adds a proximal regularization term to the local training objective, which prevents local models from deviating too far from the global model. This helps improve convergence on non-IID (non-independent and identically distributed) data.
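Concretely, in the standard FedProx formulation, each node \( k \) minimizes its local loss \( F_k \) augmented with a proximal term:
$$ \min_{w} \; F_k(w) + \frac{\mu}{2} \lVert w - w_t \rVert^2 $$
where \( w_t \) is the current global model and \( \mu \) corresponds to the fedprox_mu parameter used below.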
To implement it in Fed-BioMed, it is sufficient to add the regularization parameter fedprox_mu to the training_args dictionary.
training_args_fedprox = {
'loader_args': {
'batch_size': 32,
},
'random_seed': 1234,
'optimizer_args': {'lr': 1e-3},
'epochs': 1,
'dry_run': False,
'num_updates': 50,
'fedprox_mu': .1, # This parameter indicates that we are going to use FedProx
}
model_args = {
'num_classes': 6, # adapt this number to the number of classes in your dataset
}
from fedbiomed.researcher.federated_workflows import Experiment
from fedbiomed.researcher.aggregators import FedAverage
from fedbiomed.researcher.strategies.default_strategy import DefaultStrategy
tags = ['#MEDNIST', '#dataset']
rounds = 3
exp_fedprox = Experiment()
exp_fedprox.set_model_args(model_args=model_args)
exp_fedprox.set_training_args(training_args=training_args_fedprox)
exp_fedprox.set_training_plan_class(training_plan_class=MyTrainingPlan)
exp_fedprox.set_tags(tags = tags)
exp_fedprox.set_training_data(training_data=None, from_tags=True)
exp_fedprox.set_aggregator(aggregator=FedAverage())
exp_fedprox.set_strategy(node_selection_strategy=DefaultStrategy())
exp_fedprox.set_round_limit(rounds)
exp_fedprox.set_tensorboard(True)
%reload_ext tensorboard
fedprox_tensorboard_dir = exp_fedprox.tensorboard_results_path
%tensorboard --logdir {fedprox_tensorboard_dir}
exp_fedprox.run(increase=True)
Save trained model to file
exp_fedprox.training_plan().export_model('./trained_model')
3. Defining an Experiment using SCAFFOLD Aggregator¶
In traditional federated learning algorithms like FedAvg, each client trains a local model using its own data. However, since clients have different data distributions, the local models may deviate significantly from the global model, causing slow convergence and instability in the aggregation process. This problem is known as client drift.
Scaffold introduces a set of control variates (or auxiliary variables) that help track the differences between local and global updates. These control variates are maintained for each client and used to adjust the local gradients during the training process.
In each training round \( t \), the control variate for client \( k \), \( c^{(k)}_t \), is updated as:
$$ c^{(k)}_{t+1} = c^{(k)}_t + g^{(k)}_t - \bar{g}_t $$
where \( g^{(k)}_t \) is the gradient of client \( k \), and \( \bar{g}_t \) is the global average of the gradients.
The updated local gradient (corrected update) is given by:
$$ \hat{g}^{(k)}_t = g^{(k)}_t - c^{(k)}_t $$
By reducing client drift, Scaffold provides a more stable convergence and better generalization.
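To make the correction concrete, here is a small framework-agnostic sketch (not Fed-BioMed's internal implementation) of an SGD step applying the corrected gradient from the formula above:
import torch

def scaffold_corrected_sgd_step(params, grads, control_variates, lr):
    # Apply the drift-corrected gradient g - c before the usual SGD update,
    # where c is the node's control variate (see the formula above).
    with torch.no_grad():
        for p, g, c in zip(params, grads, control_variates):
            p -= lr * (g - c)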
To use Scaffold in Fed-BioMed, we import it from the fedbiomed.researcher.aggregators module.
The Scaffold aggregator takes two arguments:
server_lr: the server learning rate, used for the gradient descent step on the global model after receiving the corrected updates from the nodes
fds: the Federated Dataset containing information about the Nodes connected to the network after issuing a TrainRequest
Please note that it is also possible to use Scaffold with a regularization parameter as suggested in FedProx. For that, you just have to specify fedprox_mu in the training_args dictionary, as shown in the FedProx example.
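As a sketch, assuming the same training_args structure as above, this would look like:
training_args_scaffold_prox = {
    **training_args,   # reuse the loader, optimizer and num_updates settings defined earlier
    'fedprox_mu': .1,  # add the proximal regularization term, as in the FedProx example
}
# exp_scaffold.set_training_args(training_args=training_args_scaffold_prox)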
Attention: this version of Scaffold exchanges correction terms that are not protected, even when using Secure Aggregation. Please do not use this version of Scaffold under heavy security constraints.
from fedbiomed.researcher.aggregators import Scaffold
from fedbiomed.researcher.strategies.default_strategy import DefaultStrategy
server_lr = .8
exp_scaffold = Experiment()
exp_scaffold.set_model_args(model_args=model_args)
exp_scaffold.set_training_args(training_args=training_args)
exp_scaffold.set_training_plan_class(training_plan_class=MyTrainingPlan)
exp_scaffold.set_tags(tags = tags)
exp_scaffold.set_training_data(training_data=None, from_tags=True)
exp_scaffold.set_aggregator(Scaffold(server_lr=server_lr))
exp_scaffold.set_strategy(node_selection_strategy=DefaultStrategy())
exp_scaffold.set_round_limit(rounds)
exp_scaffold.set_tensorboard(True)
%reload_ext tensorboard
scaffold_tensorboard_dir = exp_scaffold.tensorboard_results_path
%tensorboard --logdir {scaffold_tensorboard_dir}
exp_scaffold.run(increase=True)
Save trained model to file
exp_scaffold.training_plan().export_model('./trained_model')
4. Going further¶
In this tutorial we presented three important Aggregators found in the Federated Learning literature. If you want to create your own custom Aggregator, please check our Aggregation User Guide.
You may have noticed that, thanks to Fed-BioMed's modular structure, it is possible to switch from one aggregator to another while conducting an Experiment. For instance, you may start with the SCAFFOLD Aggregator for the first 3 rounds, and then switch to the FedAverage Aggregator for the remaining rounds, as shown in the example below:
from fedbiomed.researcher.aggregators import Scaffold, FedAverage
from fedbiomed.researcher.strategies.default_strategy import DefaultStrategy
server_lr = .8
exp_multi_agg = Experiment()
# selecting how many rounds of each aggregator we will perform
rounds_scaffold = 3
rounds_fedavg = 1
exp_multi_agg.set_model_args(model_args=model_args)
exp_multi_agg.set_training_args(training_args=training_args)
exp_multi_agg.set_training_plan_class(training_plan_class=MyTrainingPlan)
exp_multi_agg.set_tags(tags = tags)
exp_multi_agg.set_training_data(training_data=None, from_tags=True)
exp_multi_agg.set_aggregator(Scaffold(server_lr=server_lr))
exp_multi_agg.set_strategy(node_selection_strategy=DefaultStrategy())
exp_multi_agg.set_round_limit(rounds_scaffold + rounds_fedavg)
exp_multi_agg.run(rounds=rounds_scaffold)
exp_multi_agg.set_aggregator(FedAverage())
exp_multi_agg.run(rounds=rounds_fedavg)
Save trained model to file
exp_multi_agg.training_plan().export_model('./trained_model')
For more advanced Aggregators and Regularizers, like FedOpt, you may be interested in DecLearn optimizers, which are compatible with Fed-BioMed and provide more options for aggregation and optimization.