Classes
Aggregator
Aggregator()
Defines methods for aggregation strategies (e.g. FedAvg, FedProx, SCAFFOLD, ...).
Source code in fedbiomed/researcher/aggregators/aggregator.py
```python
def __init__(self):
    self._aggregator_args: dict = None
    self._fds: FederatedDataSet = None
    self._training_plan_type: TrainingPlans = None
    self._secagg_crypter = SecaggCrypter()
```
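To illustrate how the base class is meant to be extended, here is a minimal, purely hypothetical subclass. It assumes, as `FedAverage` and `Scaffold` below do, that a subclass sets `aggregator_name` and overrides an `aggregate(model_params, weights)` method; the median rule and the `MedianAggregator` name are illustrative, not part of Fed-BioMed:

```python
import torch
from fedbiomed.researcher.aggregators.aggregator import Aggregator

class MedianAggregator(Aggregator):
    """Hypothetical robust aggregator: coordinate-wise median of node updates."""

    def __init__(self):
        super().__init__()
        self.aggregator_name = "Median"

    def aggregate(self, model_params: list, weights: list, *args, **kwargs) -> dict:
        # `model_params` is a list of {layer name: tensor} dicts, one per node;
        # take the element-wise median of each layer across nodes.
        return {
            key: torch.stack([params[key] for params in model_params]).median(dim=0).values
            for key in model_params[0]
        }
```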
FedAverage
FedAverage()
Bases: Aggregator
Defines the Federated Averaging (FedAvg) strategy.
Source code in fedbiomed/researcher/aggregators/fedavg.py
```python
def __init__(self):
    """Construct `FedAverage` object as an instance of [`Aggregator`]
    [fedbiomed.researcher.aggregators.Aggregator].
    """
    super(FedAverage, self).__init__()
    self.aggregator_name = "FedAverage"
```
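As a usage sketch, the aggregator is typically handed to an `Experiment`. This is hedged: the `Experiment` import path and keyword names vary across Fed-BioMed versions, and `MyTrainingPlan` and the data tag are hypothetical placeholders.

```python
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

# Hypothetical experiment setup: the tag and training plan are placeholders.
exp = Experiment(
    tags=["#my-dataset-tag"],
    training_plan_class=MyTrainingPlan,   # hypothetical training plan class
    round_limit=5,
    aggregator=FedAverage(),              # plug in FedAvg as the strategy
)
exp.run()
```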
Scaffold
Scaffold(server_lr=1.0, fds=None)
Bases: Aggregator
Defines the SCAFFOLD strategy.
Although FedAvg is an algorithm of choice for federated learning, it is observed to suffer from client drift when the data are heterogeneous (non-IID), resulting in unstable and slow convergence. SCAFFOLD uses control variates (variance reduction) to correct for client drift in the local updates. Intuitively, SCAFFOLD estimates the update direction of the server model (\( \mathbf{c} \)) and the update direction of each client (\( \mathbf{c}_i \)); the difference \( \mathbf{c} - \mathbf{c}_i \) is then an estimate of the client drift, which is used to correct the local update.
Fed-BioMed implementation details
Our implementation is heavily influenced by our design choice to prevent storing any state on the nodes between FL rounds. In particular, this means that the computation of the control variates (i.e. the correction states) needs to be performed centrally by the aggregator. Roughly, our implementation follows these steps (following the notation of the original Scaffold paper):
- let \( \delta_i = \mathbf{c}_i - \mathbf{c} \)
- foreach(round):
    - sample \( S \) nodes participating in this round out of \( N \) total
    - the server communicates the global model \( \mathbf{x} \) and the correction states \( \delta_i \) to all clients
    - parallel on each client:
        - initialize local model \( \mathbf{y}_i = \mathbf{x} \)
        - foreach(update) until \( K \) updates have been performed:
            - obtain a data batch
            - compute the gradients for this batch \( g(\mathbf{y}_i) \)
            - apply the correction term to the gradients: \( g(\mathbf{y}_i) \mathrel{-}= \delta_i \)
            - update the model with one optimizer step, e.g. for SGD: \( \mathbf{y}_i \mathrel{-}= \eta_i \, g(\mathbf{y}_i) \)
        - end foreach(update)
        - communicate the updated model \( \mathbf{y}_i \) and learning rate \( \eta_i \)
    - end parallel section on each client
    - the server computes the node-wise model update \( \mathbf{\Delta y}_i = \mathbf{x} - \mathbf{y}_i \)
    - the server updates the node-wise states \( \mathbf{c}_i = \delta_i + \mathbf{\Delta y}_i / (\eta_i K) \)
    - the server updates the global state \( \mathbf{c} = \frac{1}{N} \sum_{i=1}^{N} \mathbf{c}_i \)
    - the server updates the node-wise correction states \( \delta_i = \mathbf{c}_i - \mathbf{c} \)
    - the server updates the global model by averaging: \( \mathbf{x} = \mathbf{x} - \frac{\eta}{|S|} \sum_{i \in S} \mathbf{\Delta y}_i \)
- end foreach(round)
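For concreteness, the server-side portion of one round can be sketched in a few lines of numpy. This is a schematic sketch only: it flattens each model to a single array (the actual aggregator works on per-layer dictionaries), the function and argument names are hypothetical, and it assumes `c_local` holds states for all \( N \) nodes (unsampled nodes keep their previous state).

```python
import numpy as np

def scaffold_server_round(x, y_local, eta, c_local, c_global, server_lr, K, N):
    """One server-side SCAFFOLD round over the sampled nodes in `y_local`.

    x: global model (flat array); y_local: {node_id: y_i} for the sampled
    nodes S; eta: {node_id: eta_i}; c_local: {node_id: c_i} for all N nodes;
    c_global: c; K: local updates per round; N: total number of nodes.
    """
    sampled = list(y_local)
    # Node-wise model updates: Delta y_i = x - y_i
    delta_y = {i: x - y_local[i] for i in sampled}
    # Node-wise states: c_i = delta_i + Delta y_i / (eta_i * K),
    # where delta_i = c_i - c uses the previous round's global state.
    for i in sampled:
        delta_i = c_local[i] - c_global
        c_local[i] = delta_i + delta_y[i] / (eta[i] * K)
    # Global state: c = (1/N) * sum_i c_i
    c_global = sum(c_local.values()) / N
    # Global model: x = x - (eta / |S|) * sum_{i in S} Delta y_i
    x = x - (server_lr / len(sampled)) * sum(delta_y.values())
    # The next round's correction states are delta_i = c_local[i] - c_global.
    return x, c_local, c_global
```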
[Diagram: visual representation of the SCAFFOLD algorithm]
References:
- SCAFFOLD: Stochastic Controlled Averaging for Federated Learning
- TCT: Convexifying Federated Learning using Bootstrapped Neural Tangent Kernels
Attributes:
Name | Type | Description |
---|---|---|
aggregator_name | str | name of the aggregator |
server_lr | float | value of the server learning rate |
global_state | Dict[str, Union[Tensor, ndarray]] | a dictionary representing the global correction state \( \mathbf{c} \) in the format {parameter name: correction value} |
nodes_states | Dict[str, Dict[str, Union[Tensor, ndarray]]] | a nested dictionary of correction parameters obtained for each client, in the format {node id: node-wise corrections}. The node-wise corrections are a dictionary in the format {parameter name: correction value} where the model parameters are those contained in each node's model.named_parameters(). |
nodes_deltas | Dict[str, Dict[str, Union[Tensor, ndarray]]] | a nested dictionary of deltas for each client, in the same format as nodes_states. The deltas are defined as \(\delta_i = \mathbf{c}_i - \mathbf{c} \) |
nodes_lr | Dict[str, Dict[str, float]] | dictionary of learning rates observed at end of the latest round, in the format {node id: learning rate} |
Parameters:
Name | Type | Description | Default |
---|---|---|---|
server_lr | float | server's (or Researcher's) learning rate. Defaults to 1.0. | 1.0 |
fds | FederatedDataset | FederatedDataset obtained after a `search` request. Defaults to None. | None |
Source code in fedbiomed/researcher/aggregators/scaffold.py
```python
def __init__(self, server_lr: float = 1., fds: Optional[FederatedDataSet] = None):
    """Constructs `Scaffold` object as an instance of [`Aggregator`]
    [fedbiomed.researcher.aggregators.Aggregator].

    Args:
        server_lr (float): server's (or Researcher's) learning rate. Defaults to 1..
        fds (FederatedDataset, optional): FederatedDataset obtained after a `search` request. Defaults to None.
    """
    super().__init__()
    self.aggregator_name: str = "Scaffold"
    if server_lr == 0.:
        raise FedbiomedAggregatorError("SCAFFOLD Error: Server learning rate cannot be equal to 0")
    self.server_lr: float = server_lr
    self.global_state: Dict[str, Union[torch.Tensor, np.ndarray]] = {}
    self.nodes_states: Dict[str, Dict[str, Union[torch.Tensor, np.ndarray]]] = {}
    self.nodes_deltas: Dict[str, Dict[str, Union[torch.Tensor, np.ndarray]]] = {}
    self.nodes_lr: Dict[str, Dict[str, float]] = {}
    if fds is not None:
        self.set_fds(fds)
    self._aggregator_args = {}  # we need `_aggregator_args` to be not None
```
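A minimal construction sketch based on the constructor above; the aggregator would then be passed to an `Experiment` just like `FedAverage` in the earlier sketch:

```python
from fedbiomed.researcher.aggregators.scaffold import Scaffold

scaffold = Scaffold(server_lr=0.8)  # server_lr must be non-zero
# Scaffold(server_lr=0.) raises FedbiomedAggregatorError (see the guard above).
# `fds` may be omitted here and supplied later via `set_fds()`, which the
# constructor itself calls when `fds` is not None.
```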
Functions
federated_averaging
federated_averaging(model_params, weights)
Defines Federated Averaging (FedAvg) strategy for model aggregation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_params | List[Dict[str, Union[Tensor, ndarray]]] | list that contains nodes' model parameters; each model is stored as an OrderedDict (maps model layer name to the model weights) | required |
weights | List[float] | weights for performing weighted sum in FedAvg strategy (depending on the dataset size of each node). Items in the list must always sum up to 1 | required |
Returns:
Type | Description |
---|---|
Mapping[str, Union[Tensor, ndarray]] | Final model with aggregated layers, as an OrderedDict object. |
Source code in fedbiomed/researcher/aggregators/functional.py
```python
def federated_averaging(model_params: List[Dict[str, Union[torch.Tensor, np.ndarray]]],
                        weights: List[float]) -> Mapping[str, Union[torch.Tensor, np.ndarray]]:
    """Defines Federated Averaging (FedAvg) strategy for model aggregation.

    Args:
        model_params: list that contains nodes' model parameters; each model is stored as an OrderedDict (maps
            model layer name to the model weights)
        weights: weights for performing weighted sum in FedAvg strategy (depending on the dataset size of each node).
            Items in the list must always sum up to 1

    Returns:
        Final model with aggregated layers, as an OrderedDict object.
    """
    assert len(model_params) > 0, 'An empty list of models was passed.'
    assert len(weights) == len(model_params), 'List with number of observations must have ' \
                                              'the same number of elements that list of models.'

    # Compute proportions
    proportions = [n_k / sum(weights) for n_k in weights]

    return weighted_sum(model_params, proportions)
```
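A worked example with two nodes and hypothetical layer names and values (the import path follows the source file listed above); the weights sum to 1, as the docstring requires:

```python
import torch
from fedbiomed.researcher.aggregators.functional import federated_averaging

# Two nodes, one linear layer; weights proportional to dataset sizes (60/40).
node_a = {"fc.weight": torch.tensor([[1.0, 2.0]]), "fc.bias": torch.tensor([0.0])}
node_b = {"fc.weight": torch.tensor([[3.0, 4.0]]), "fc.bias": torch.tensor([1.0])}

agg = federated_averaging([node_a, node_b], weights=[0.6, 0.4])
print(agg["fc.weight"])  # tensor([[1.8000, 2.8000]])
print(agg["fc.bias"])    # tensor([0.4000])
```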
initialize
initialize(val)
Initialize tensor or array vector.
Source code in fedbiomed/researcher/aggregators/functional.py
```python
def initialize(val: Union[torch.Tensor, np.ndarray]) -> Tuple[str, Union[torch.Tensor, np.ndarray]]:
    """Initialize tensor or array vector. """
    if isinstance(val, torch.Tensor):
        return 'tensor', torch.zeros_like(val).float()

    if isinstance(val, (list, np.ndarray)):
        val = np.array(val)
        return 'array', np.zeros(val.shape, dtype=float)
```
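For example, `initialize` tags the input's kind and returns a zeroed template of the same shape:

```python
import torch
import numpy as np
from fedbiomed.researcher.aggregators.functional import initialize

kind, zeros = initialize(torch.ones(2, 3))
print(kind, zeros.shape)   # tensor torch.Size([2, 3])

kind, zeros = initialize(np.ones(4))
print(kind, zeros)         # array [0. 0. 0. 0.]
```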
weighted_sum
weighted_sum(model_params, proportions)
Performs weighted sum operation
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_params | List[Dict[str, Union[Tensor, ndarray]]] | list that contains nodes' model parameters; each model is stored as an OrderedDict (maps model layer name to the model weights) | required |
proportions | List[float] | weights of all items within the model_params list | required |
Returns:
Type | Description |
---|---|
Mapping[str, Union[Tensor, ndarray]] | model resulting from the weighted sum operation |
Source code in fedbiomed/researcher/aggregators/functional.py
```python
def weighted_sum(model_params: List[Dict[str, Union[torch.Tensor, np.ndarray]]],
                 proportions: List[float]) -> Mapping[str, Union[torch.Tensor, np.ndarray]]:
    """Performs weighted sum operation

    Args:
        model_params (List[Dict[str, Union[torch.Tensor, np.ndarray]]]): list that contains nodes'
            model parameters; each model is stored as an OrderedDict (maps model layer name to the model weights)
        proportions (List[float]): weights of all items within the model_params list

    Returns:
        Mapping[str, Union[torch.Tensor, np.ndarray]]: model resulting from the weighted sum
            operation
    """
    # Empty model parameter dictionary
    avg_params = copy.deepcopy(model_params[0])

    for key, val in avg_params.items():
        (t, avg_params[key]) = initialize(val)

    if t == 'tensor':
        for model, weight in zip(model_params, proportions):
            for key in avg_params.keys():
                avg_params[key] += weight * model[key]

    if t == 'array':
        for key in avg_params.keys():
            matr = np.array([d[key] for d in model_params])
            avg_params[key] = np.average(matr, weights=np.array(proportions), axis=0)

    return avg_params
```
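A small check of the array branch, using hypothetical parameter names and values:

```python
import numpy as np
from fedbiomed.researcher.aggregators.functional import weighted_sum

params = [{"w": np.array([0.0, 10.0])}, {"w": np.array([10.0, 20.0])}]
result = weighted_sum(params, proportions=[0.25, 0.75])
print(result["w"])  # [ 7.5 17.5]
```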