Classes
Aggregator
Aggregator()
Defines methods for an aggregation strategy (e.g. FedAvg, FedProx, SCAFFOLD, ...).
Source code in fedbiomed/researcher/aggregators/aggregator.py
```python
def __init__(self):
    self._aggregator_args: dict = None
    self._fds: FederatedDataSet = None
    self._training_plan_type: TrainingPlans = None
    self._secagg_crypter = SecaggCrypter()
```
Functions
aggregate
aggregate(model_params, weights, *args, **kwargs)
Strategy to aggregate models
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_params | list | List of model parameters received from each node | required |
weights | list | Weight for each node-model-parameter set | required |
Raises:
Type | Description |
---|---|
FedbiomedAggregatorError | If the method is not defined by inheritor |
Source code in fedbiomed/researcher/aggregators/aggregator.py
```python
def aggregate(self, model_params: list, weights: list, *args, **kwargs) -> Dict:
    """
    Strategy to aggregate models

    Args:
        model_params: List of model parameters received from each node
        weights: Weight for each node-model-parameter set

    Raises:
        FedbiomedAggregatorError: If the method is not defined by inheritor
    """
    msg = ErrorNumbers.FB401.value + \
        ": aggregate method should be overloaded by the chosen strategy"
    logger.critical(msg)
    raise FedbiomedAggregatorError(msg)
```
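Because `aggregate` must be overloaded, a custom strategy is defined by subclassing. Below is a minimal, hedged sketch assuming Fed-BioMed is installed; the `MeanAggregator` class is hypothetical and not part of the library:

```python
# Hypothetical custom strategy overloading `aggregate`: an unweighted mean.
from typing import Dict, List

from fedbiomed.researcher.aggregators.aggregator import Aggregator


class MeanAggregator(Aggregator):
    """Averages each named parameter across nodes, ignoring `weights`."""

    def aggregate(self, model_params: List[Dict], weights: List[float], *args, **kwargs) -> Dict:
        return {
            key: sum(params[key] for params in model_params) / len(model_params)
            for key in model_params[0]
        }
```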
check_values
check_values(*args, **kwargs)
Source code in fedbiomed/researcher/aggregators/aggregator.py
```python
def check_values(self, *args, **kwargs) -> True:
    return True
```
create_aggregator_args
create_aggregator_args(*args, **kwargs)
Returns aggregator arguments that are expected by the nodes
Parameters:
Name | Type | Description | Default |
---|---|---|---|
args | | ignored | () |
kwargs | | ignored | {} |
Returns:
Type | Description |
---|---|
Dict | contains `Aggregator` parameters/arguments that will be shared with the nodes |
Source code in fedbiomed/researcher/aggregators/aggregator.py
```python
def create_aggregator_args(self, *args, **kwargs) -> Dict:
    """Returns aggregator arguments that are expected by the nodes

    Args:
        args: ignored
        kwargs: ignored

    Returns:
        contains `Aggregator` parameters/arguments that will be shared with the nodes
    """
    return self._aggregator_args or {}
```
load_state_breakpoint
load_state_breakpoint(state, **kwargs)
Used for breakpoints; loads the aggregator state.
Source code in fedbiomed/researcher/aggregators/aggregator.py
```python
def load_state_breakpoint(self, state: Dict[str, Any], **kwargs) -> None:
    """Used for breakpoints; loads the aggregator state."""
    if not isinstance(state["parameters"], Dict):
        self._aggregator_args = Serializer.load(state['parameters'])
    else:
        self._aggregator_args = state['parameters']
```
save_state_breakpoint
save_state_breakpoint(breakpoint_path=None, **aggregator_args_create)
Used for breakpoints; saves the aggregator state.
Source code in fedbiomed/researcher/aggregators/aggregator.py
```python
def save_state_breakpoint(
    self,
    breakpoint_path: Optional[str] = None,
    **aggregator_args_create: Any,
) -> Dict[str, Any]:
    """Used for breakpoints; saves the aggregator state."""
    aggregator_args = self.create_aggregator_args(**aggregator_args_create)
    if aggregator_args:
        if self._aggregator_args is None:
            self._aggregator_args = {}
        self._aggregator_args.update(aggregator_args)
    if breakpoint_path:
        filename = self._save_arg_to_file(breakpoint_path, 'aggregator_args', uuid.uuid4(), self._aggregator_args)
    state = {
        "class": type(self).__name__,
        "module": self.__module__,
        "parameters": filename if breakpoint_path else self._aggregator_args
    }
    return state
```
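A hedged sketch of the returned state, assuming Fed-BioMed is installed; when a `breakpoint_path` is given, `"parameters"` instead holds the path of a file written via `Serializer`:

```python
from fedbiomed.researcher.aggregators.fedavg import FedAverage

agg = FedAverage()
state = agg.save_state_breakpoint()  # no breakpoint_path: nothing is written to disk
assert state["class"] == "FedAverage"
assert state["module"] == "fedbiomed.researcher.aggregators.fedavg"
# state["parameters"] holds the in-memory aggregator args (None or a dict)
```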
secure_aggregation
secure_aggregation(params, encryption_factors, secagg_random, aggregation_round, total_sample_size, training_plan)
Apply aggregation for encrypted model parameters
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params | List[List[int]] | List containing list of encrypted parameters of each node | required |
encryption_factors | List[Dict[str, List[int]]] | List of encrypted integers to validate encryption | required |
secagg_random | float | Randomly generated float value to validate secure aggregation correctness | required |
aggregation_round | int | The round of the aggregation. | required |
total_sample_size | int | Sum of sample sizes used for training | required |
training_plan | BaseTrainingPlan | Training plan instance used for the training. | required |
Returns:
Type | Description |
---|---|
| Aggregated model parameters |
Source code in fedbiomed/researcher/aggregators/aggregator.py
```python
def secure_aggregation(
    self,
    params: List[List[int]],
    encryption_factors: List[Dict[str, List[int]]],
    secagg_random: float,
    aggregation_round: int,
    total_sample_size: int,
    training_plan: 'BaseTrainingPlan'
):
    """Apply aggregation for encrypted model parameters

    Args:
        params: List containing list of encrypted parameters of each node
        encryption_factors: List of encrypted integers to validate encryption
        secagg_random: Randomly generated float value to validate secure aggregation correctness
        aggregation_round: The round of the aggregation.
        total_sample_size: Sum of sample sizes used for training
        training_plan: Training plan instance used for the training.

    Returns:
        Aggregated model parameters
    """
    # TODO: verify with secagg context number of parties
    num_nodes = len(params)

    # TODO: Use server key here
    key = -(len(params) * 10)

    # IMPORTANT = Keep this key for testing purposes
    key = -4521514305280526329525552501850970498079782904248225896786295610941010325354834129826500373412436986239012584207113747347251251180530850751209537684586944643780840182990869969844131477709433555348941386442841023261287875379985666260596635843322044109172782411303407030194453287409138194338286254652273563418119335656859169132074431378389356392955315045979603414700450628308979043208779867835835935403213000649039155952076869962677675951924910959437120608553858253906942559260892494214955907017206115207769238347962438107202114814163305602442458693305475834199715587932463252324681290310458316249381037969151400784780

    logger.info("Securely aggregating model parameters...")
    aggregate = functools.partial(self._secagg_crypter.aggregate,
                                  current_round=aggregation_round,
                                  num_nodes=num_nodes,
                                  key=key,
                                  total_sample_size=total_sample_size)

    # Validation
    encryption_factors = [f for k, f in encryption_factors.items()]
    validation: List[int] = aggregate(params=encryption_factors)
    if len(validation) != 1 or not math.isclose(validation[0], secagg_random, abs_tol=0.01):
        raise FedbiomedAggregatorError("Aggregation failed due to incorrect decryption.")

    aggregated_params = aggregate(params=params)

    # Convert model params
    model = training_plan.get_model_wrapper_class()
    model_params = model.unflatten(aggregated_params)

    return model_params
```
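The validation step can be illustrated in isolation. In this toy sketch, `mock_aggregate` stands in for `SecaggCrypter.aggregate` (it is not the real API): every node encrypts the same random float, and aggregation must recover it within tolerance, otherwise decryption is considered broken:

```python
import math

secagg_random = 0.42  # value shared with the nodes before encryption

def mock_aggregate(params):
    # Stand-in for SecaggCrypter.aggregate: pretend decryption recovered the value.
    return [0.42]

validation = mock_aggregate(params=[[101], [102], [103]])
assert len(validation) == 1
assert math.isclose(validation[0], secagg_random, abs_tol=0.01)
```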
set_fds
set_fds(fds)
Source code in fedbiomed/researcher/aggregators/aggregator.py
```python
def set_fds(self, fds: FederatedDataSet) -> FederatedDataSet:
    self._fds = fds
    return self._fds
```
set_training_plan_type
set_training_plan_type(training_plan_type)
Source code in fedbiomed/researcher/aggregators/aggregator.py
```python
def set_training_plan_type(self, training_plan_type: TrainingPlans) -> TrainingPlans:
    self._training_plan_type = training_plan_type
    return self._training_plan_type
```
FedAverage
FedAverage()
Bases: Aggregator
Defines the Federated averaging strategy
Source code in fedbiomed/researcher/aggregators/fedavg.py
```python
def __init__(self):
    """Construct `FedAverage` object as an instance of [`Aggregator`]
    [fedbiomed.researcher.aggregators.Aggregator].
    """
    super(FedAverage, self).__init__()
    self.aggregator_name = "FedAverage"
```
Attributes
aggregator_name instance-attribute
aggregator_name = 'FedAverage'
Functions
aggregate
aggregate(model_params, weights, *args, **kwargs)
Aggregates local models sent by participating nodes into a global model, following Federated Averaging strategy.
weights is a dictionary mapping each node id to its aggregation weight. model_params is a dictionary mapping each node id to a framework-specific representation of its model parameters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_params | Dict[str, Dict[str, Union[Tensor, ndarray]]] | model parameters of each node, keyed by node id | required |
weights | Dict[str, float] | aggregation weight of each node, keyed by node id | required |
Returns:
Type | Description |
---|---|
Mapping[str, Union[Tensor, ndarray]] | Aggregated parameters |
Source code in fedbiomed/researcher/aggregators/fedavg.py
```python
def aggregate(
    self,
    model_params: Dict[str, Dict[str, Union['torch.Tensor', 'numpy.ndarray']]],
    weights: Dict[str, float],
    *args,
    **kwargs
) -> Mapping[str, Union['torch.Tensor', 'numpy.ndarray']]:
    """Aggregates local models sent by participating nodes into a global model, following the
    Federated Averaging strategy.

    `weights` is a dictionary mapping each node id to its aggregation weight. `model_params` is a
    dictionary mapping each node id to a framework-specific representation of its model parameters.

    Args:
        model_params: model parameters of each node, keyed by node id
        weights: aggregation weight of each node, keyed by node id

    Returns:
        Aggregated parameters
    """
    model_params_processed = []
    weights_processed = []

    for node_id, params in model_params.items():
        if node_id not in weights:
            raise FedbiomedAggregatorError(
                f"{ErrorNumbers.FB401.value}. Can not find corresponding calculated weight for the "
                f"node {node_id}. Aggregation is aborted."
            )
        weight = weights[node_id]
        model_params_processed.append(params)
        weights_processed.append(weight)

    if any([x < 0. or x > 1. for x in weights_processed]) or sum(weights_processed) == 0:
        raise FedbiomedAggregatorError(
            f"{ErrorNumbers.FB401.value}. Aggregation aborted: weights {weights} must each lie in "
            f"[0, 1] and their sum must be non-zero. Sample sizes received from nodes might be corrupted."
        )

    agg_params = federated_averaging(model_params_processed, weights_processed)
    return agg_params
```
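A usage sketch with toy NumPy parameters, assuming Fed-BioMed is installed; the node ids and values are made up:

```python
import numpy as np

from fedbiomed.researcher.aggregators.fedavg import FedAverage

model_params = {
    "node-1": {"w": np.array([0.0, 2.0])},
    "node-2": {"w": np.array([2.0, 4.0])},
}
weights = {"node-1": 0.5, "node-2": 0.5}  # sample-size proportions, summing to 1

agg_params = FedAverage().aggregate(model_params, weights)
print(agg_params["w"])  # [1. 3.]
```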
Scaffold
Scaffold(server_lr=1.0, fds=None)
Bases: Aggregator
Defines the Scaffold strategy
Despite being an algorithm of choice for federated learning, it is observed that FedAvg suffers from client-drift
when the data is heterogeneous (non-iid), resulting in unstable and slow convergence. SCAFFOLD uses control variates (variance reduction) to correct for the client-drift
in its local updates. Intuitively, SCAFFOLD estimates the update direction for the server model (c) and the update direction for each client (c_i). The difference (c - c_i) is then an estimate of the client-drift which is used to correct the local update.
Fed-BioMed implementation details
Our implementation is heavily influenced by our design choice to prevent storing any state on the nodes between FL rounds. In particular, this means that the computation of the control variates (i.e. the correction states) needs to be performed centrally by the aggregator. Roughly, our implementation follows these steps (following the notation of the original Scaffold paper):
- let \(\delta_i = \mathbf{c}_i - \mathbf{c} \)
- foreach(round):
- sample \( S \) nodes participating in this round out of \( N \) total
- the server communicates the global model \( \mathbf{x} \) and the correction states \( \delta_i \) to all clients
- parallel on each client
- initialize local model \( \mathbf{y}_i = \mathbf{x} \)
- foreach(update) until K updates have been performed
- obtain a data batch
- compute the gradients for this batch \( g(\mathbf{y}_i) \)
- apply correction term to gradients \( g(\mathbf{y}_i) -= \delta_i \)
- update model with one optimizer step e.g. for SGD \( \mathbf{y}_i -= \eta_i g(\mathbf{y}_i) \)
- end foreach(update)
- communicate updated model \( \mathbf{y}_i \) and learning rate \( \eta_i \)
- end parallel section on each client
- the server computes the node-wise model update \( \mathbf{\Delta y}_i = \mathbf{x} - \mathbf{y}_i \)
- the server updates the node-wise states \( \mathbf{c}_i = \delta_i + (\mathbf{\Delta y}_i) / (\eta_i K) \)
- the server updates the global state \( \mathbf{c} = (1/N) \sum_{i \in N} \mathbf{c}_i \)
- the server updates the node-wise correction state \(\delta_i = \mathbf{c}_i - \mathbf{c} \)
- the server updates the global model by averaging \( \mathbf{x} = \mathbf{x} - (\eta/|S|) \sum_{i \in S} \mathbf{\Delta y}_i \)
- end foreach(round)
This diagram provides a visual representation of the algorithm.
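The server-side arithmetic of one round can be sketched in plain NumPy; all values below are made up and `K` denotes the number of local updates:

```python
import numpy as np

eta_g, eta_i, K = 1.0, 0.1, 5              # server lr, node lr, local updates
x = np.array([1.0, 1.0])                   # global model
y = {"n1": np.array([0.8, 1.2]),           # local models returned after K updates
     "n2": np.array([0.6, 0.9])}
delta = {nid: np.zeros(2) for nid in y}    # correction states (zero at round 0)

# Node-wise model updates: Delta y_i = x - y_i
upd = {nid: x - y_i for nid, y_i in y.items()}
# Node-wise states: c_i = delta_i + Delta y_i / (eta_i * K)
c = {nid: delta[nid] + upd[nid] / (eta_i * K) for nid in y}
# Global state c = mean(c_i), then new corrections delta_i = c_i - c
c_glob = sum(c.values()) / len(c)
delta = {nid: c_i - c_glob for nid, c_i in c.items()}
# Global model: x <- x - (eta_g / |S|) * sum(Delta y_i)
x = x - (eta_g / len(upd)) * sum(upd.values())
```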
References:
- Scaffold: Stochastic Controlled Averaging for Federated Learning
- TCT: Convexifying Federated Learning using Bootstrapped Neural Tangent Kernels
Attributes:
Name | Type | Description |
---|---|---|
aggregator_name | str | name of the aggregator |
server_lr | float | value of the server learning rate |
global_state | Dict[str, Union[Tensor, ndarray]] | a dictionary representing the global correction state \( \mathbf{c} \) in the format {parameter name: correction value} |
nodes_states | Dict[str, Dict[str, Union[Tensor, ndarray]]] | a nested dictionary of correction parameters obtained for each client, in the format {node id: node-wise corrections}. The node-wise corrections are a dictionary in the format {parameter name: correction value} where the model parameters are those contained in each node's model.named_parameters(). |
nodes_deltas | Dict[str, Dict[str, Union[Tensor, ndarray]]] | a nested dictionary of deltas for each client, in the same format as nodes_states. The deltas are defined as \(\delta_i = \mathbf{c}_i - \mathbf{c} \) |
nodes_lr | Dict[str, Dict[str, float]] | dictionary of learning rates observed at end of the latest round, in the format {node id: learning rate} |
Parameters:
Name | Type | Description | Default |
---|---|---|---|
server_lr | float | server's (or Researcher's) learning rate. Defaults to 1.0. | 1.0 |
fds | FederatedDataset | FederatedDataset obtained after a `search` request. Defaults to None. | None |
Source code in fedbiomed/researcher/aggregators/scaffold.py
```python
def __init__(self, server_lr: float = 1., fds: Optional[FederatedDataSet] = None):
    """Constructs `Scaffold` object as an instance of [`Aggregator`]
    [fedbiomed.researcher.aggregators.Aggregator].

    Args:
        server_lr (float): server's (or Researcher's) learning rate. Defaults to 1.0.
        fds (FederatedDataset, optional): FederatedDataset obtained after a `search` request. Defaults to None.
    """
    super().__init__()
    self.aggregator_name: str = "Scaffold"
    if server_lr == 0.:
        raise FedbiomedAggregatorError("SCAFFOLD Error: Server learning rate cannot be equal to 0")
    self.server_lr: float = server_lr
    self.global_state: Dict[str, Union[torch.Tensor, np.ndarray]] = {}
    self.nodes_states: Dict[str, Dict[str, Union[torch.Tensor, np.ndarray]]] = {}
    # FIXME: `nodes_states` is mis-named, because it can conflict with the `node_state`s
    # that are saved within two rounds
    self.nodes_deltas: Dict[str, Dict[str, Union[torch.Tensor, np.ndarray]]] = {}
    self.nodes_lr: Dict[str, Dict[str, float]] = {}
    if fds is not None:
        self.set_fds(fds)
    self._aggregator_args = {}  # we need `_aggregator_args` to be not None
```
Attributes
aggregator_name instance-attribute
aggregator_name = 'Scaffold'
global_state instance-attribute
global_state = {}
nodes_deltas instance-attribute
nodes_deltas = {}
nodes_lr instance-attribute
nodes_lr = {}
nodes_states instance-attribute
nodes_states = {}
server_lr instance-attribute
server_lr = server_lr
Functions
aggregate
aggregate(model_params, weights, global_model, training_plan, training_replies, n_updates=1, n_round=0, *args, **kwargs)
Aggregates local models coming from nodes into a global model, using the SCAFFOLD algorithm (2nd option) from [Scaffold: Stochastic Controlled Averaging for Federated Learning](https://arxiv.org/abs/1910.06378).
Performed computations:
- Compute participating nodes' model update:
- update_i = y_i - x
- Compute aggregated model parameters:
- x(+) = x - eta_g sum_S(update_i)
- Update participating nodes' state:
- c_i = delta_i + 1/(K*eta_i) * update_i
- Update the global state and all nodes' correction state:
- c = 1/N sum_{i=1}^n c_i
- delta_i = (c_i - c)
where, according to the paper's notations:
- c_i: local state variable for node i
- c: global state variable
- delta_i: (c_i - c), correction state for node i
- eta_g: server's learning rate
- eta_i: node i's learning rate
- N: total number of nodes participating in federated learning
- S: number of nodes considered during the current round (S <= N)
- K: number of updates done during the round (i.e. number of data batches)
- x: global model parameters
- y_i: node i's local model parameters at the end of the round
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_params | Dict | model parameters received from nodes, keyed by node id | required |
weights | Dict[str, float] | weights depicting sample proportions available on each node. Unused for Scaffold. | required |
global_model | Dict[str, Union[Tensor, ndarray]] | global model, ie aggregated model | required |
training_plan | BaseTrainingPlan | instance of TrainingPlan | required |
training_replies | Dict | Training replies from each node that participates in the current round | required |
n_updates | int | number of updates (number of batches performed). Defaults to 1. | 1 |
n_round | int | current round. Defaults to 0. | 0 |
Returns:
Type | Description |
---|---|
Dict | Aggregated parameters, as a dict mapping weight names and values. |
Raises:
Type | Description |
---|---|
FedbiomedAggregatorError | If no FederatedDataset is attached to this Scaffold instance, or if `node_ids` do not belong to the dataset attached to it. |
Source code in fedbiomed/researcher/aggregators/scaffold.py
```python
def aggregate(self,
              model_params: Dict,
              weights: Dict[str, float],
              global_model: Dict[str, Union[torch.Tensor, np.ndarray]],
              training_plan: BaseTrainingPlan,
              training_replies: Dict,
              n_updates: int = 1,
              n_round: int = 0,
              *args, **kwargs) -> Dict:
    """Aggregates local models coming from nodes into a global model, using the SCAFFOLD algorithm
    (2nd option) from [Scaffold: Stochastic Controlled Averaging for Federated Learning](https://arxiv.org/abs/1910.06378).

    Performed computations:
    -----------------------
    - Compute participating nodes' model update:
        * update_i = y_i - x
    - Compute aggregated model parameters:
        * x(+) = x - eta_g * sum_S(update_i)
    - Update participating nodes' state:
        * c_i = delta_i + 1/(K * eta_i) * update_i
    - Update the global state and all nodes' correction state:
        * c = 1/N * sum_{i=1}^N c_i
        * delta_i = (c_i - c)

    where, according to the paper's notations:
        c_i: local state variable for node `i`
        c: global state variable
        delta_i: (c_i - c), correction state for node `i`
        eta_g: server's learning rate
        eta_i: node i's learning rate
        N: total number of nodes participating in federated learning
        S: number of nodes considered during the current round (S <= N)
        K: number of updates done during the round (i.e. number of data batches)
        x: global model parameters
        y_i: node i's local model parameters at the end of the round

    Args:
        model_params: model parameters received from nodes, keyed by node id
        weights: weights depicting sample proportions available on each node. Unused for Scaffold.
        global_model: global model, i.e. the aggregated model
        training_plan (BaseTrainingPlan): instance of TrainingPlan
        training_replies: training replies from each node that participates in the current round
        n_updates: number of updates (number of batches performed). Defaults to 1.
        n_round: current round. Defaults to 0.

    Returns:
        Aggregated parameters, as a dict mapping weight names and values.

    Raises:
        FedbiomedAggregatorError: If no FederatedDataset is attached to this
            Scaffold instance, or if `node_ids` do not belong to the dataset
            attached to it.
    """
    # Gather the learning rates used by nodes, updating `self.nodes_lr`.
    self.set_nodes_learning_rate_after_training(training_plan, training_replies)
    # At round 0, initialize zero-valued correction states.
    if n_round == 0:
        self.init_correction_states(global_model)
    # Check that the input node_ids match known ones.
    if not set(model_params).issubset(self._fds.node_ids()):
        raise FedbiomedAggregatorError(
            "Received updates from nodes that are unknown to this aggregator."
        )
    # Compute the node-wise model update: (x^t - y_i^t).
    model_updates = {
        node_id: {
            key: (global_model[key] - local_value)
            for key, local_value in params.items()
        }
        for node_id, params in model_params.items()
    }
    # Update all Scaffold state variables.
    self.update_correction_states(model_updates, n_updates)
    # Compute and return the aggregated model parameters.
    global_new = {}  # type: Dict[str, Union[torch.Tensor, np.ndarray]]
    for key, val in global_model.items():
        upd = sum(model_updates[node_id][key] for node_id in model_params)
        global_new[key] = val - upd * (self.server_lr / len(model_params))
    return global_new
```
check_values
check_values(n_updates, training_plan)
Check if all values/parameters are correct and have been set before using the aggregator; raise an error otherwise.
This can prove useful if the user has set wrong hyperparameter values, as errors will be raised before the first round of training.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n_updates | int | number of updates. Must be non-zero and an integer. | required |
training_plan | BaseTrainingPlan | training plan, used for checking whether the optimizer is SGD; otherwise a warning is triggered. | required |
Raises:
Type | Description |
---|---|
FedbiomedAggregatorError | triggered if the `num_updates` entry is missing (needed for the Scaffold aggregator) |
FedbiomedAggregatorError | triggered if any of the learning rate(s) equals 0 |
FedbiomedAggregatorError | triggered if number of updates equals 0 or is not an integer |
FedbiomedAggregatorError | triggered if [FederatedDataset][fedbiomed.researcher.datasets.FederatedDataset] has not been set. |
Source code in fedbiomed/researcher/aggregators/scaffold.py
```python
def check_values(self, n_updates: int, training_plan: BaseTrainingPlan) -> True:
    """Check if all values/parameters are correct and have been set before using the aggregator.

    Raises an error otherwise. This can prove useful if the user has set wrong hyperparameter
    values, as errors will be raised before the first round of training.

    Args:
        n_updates: number of updates. Must be a non-zero integer.
        training_plan: training plan, used for checking whether the optimizer is SGD; otherwise
            a warning is triggered.

    Raises:
        FedbiomedAggregatorError: triggered if the `num_updates` entry is missing (needed for
            the Scaffold aggregator)
        FedbiomedAggregatorError: triggered if any of the learning rate(s) equals 0
        FedbiomedAggregatorError: triggered if the number of updates equals 0 or is not an integer
        FedbiomedAggregatorError: triggered if [FederatedDataset][fedbiomed.researcher.datasets.FederatedDataset]
            has not been set.
    """
    if n_updates is None:
        raise FedbiomedAggregatorError("Cannot perform Scaffold: missing 'num_updates' entry in the training_args")
    elif n_updates <= 0 or int(n_updates) != float(n_updates):
        raise FedbiomedAggregatorError(
            "n_updates should be a positive non zero integer, but got "
            f"n_updates: {n_updates} in SCAFFOLD aggregator"
        )
    if self._fds is None:
        raise FedbiomedAggregatorError(
            "Federated Dataset not provided, but needed for Scaffold. Please use setter `set_fds()`."
        )
    return True
```
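A hedged failure-mode sketch; the exception import path (`fedbiomed.common.exceptions`) is an assumption:

```python
from fedbiomed.common.exceptions import FedbiomedAggregatorError  # assumed path
from fedbiomed.researcher.aggregators.scaffold import Scaffold

try:
    # 0.5 is not an integer number of updates, so validation fails early.
    Scaffold().check_values(n_updates=0.5, training_plan=None)
except FedbiomedAggregatorError as exc:
    print(exc)
```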
create_aggregator_args
create_aggregator_args(global_model, node_ids)
Return correction states that are to be sent to the nodes.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
global_model | Dict[str, Union[Tensor, ndarray]] | parameters of the global model, formatted as a dict mapping weight tensors to their names. | required |
node_ids | Collection[str] | identifiers of the nodes that are to receive messages. | required |
Returns:
Type | Description |
---|---|
Dict[str, Dict[str, Any]] | Aggregator arguments to share with the nodes for the next round |
Source code in fedbiomed/researcher/aggregators/scaffold.py
```python
def create_aggregator_args(
    self,
    global_model: Dict[str, Union[torch.Tensor, np.ndarray]],
    node_ids: Collection[str]
) -> Dict[str, Dict[str, Any]]:
    """Return correction states that are to be sent to the nodes.

    Args:
        global_model: parameters of the global model, formatted as a dict
            mapping weight tensors to their names.
        node_ids: identifiers of the nodes that are to receive messages.

    Returns:
        Aggregator arguments to share with the nodes for the next round
    """
    # Optionally initialize states, and verify that nodes are known.
    if not self.nodes_deltas:
        self.init_correction_states(global_model)
    if not set(node_ids).issubset(self._fds.node_ids()):
        raise FedbiomedAggregatorError(
            "Scaffold cannot create aggregator args for nodes that are not "
            "covered by its attached FederatedDataset."
        )
    aggregator_dat = {}
    for node_id in node_ids:
        # If a node was late-added to the FederatedDataset, create states.
        if node_id not in self.nodes_deltas:
            zeros = {key: initialize(val)[1] for key, val in self.global_state.items()}
            self.nodes_deltas[node_id] = zeros
            self.nodes_states[node_id] = copy.deepcopy(zeros)
        # Add information for the current node to the message dicts.
        aggregator_dat[node_id] = {
            'aggregator_name': self.aggregator_name,
            'aggregator_correction': self.nodes_deltas[node_id]
        }
    return aggregator_dat
```
init_correction_states
init_correction_states(global_model)
Initialize Scaffold state variables.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
global_model | Dict[str, Union[Tensor, ndarray]] | parameters of the global model, formatted as a dict mapping weight tensors to their names. | required |
Raises:
Type | Description |
---|---|
FedbiomedAggregatorError | if no FederatedDataset is attached to this aggregator. |
Source code in fedbiomed/researcher/aggregators/scaffold.py
```python
def init_correction_states(
    self,
    global_model: Dict[str, Union[torch.Tensor, np.ndarray]],
) -> None:
    """Initialize Scaffold state variables.

    Args:
        global_model: parameters of the global model, formatted as a dict
            mapping weight tensors to their names.

    Raises:
        FedbiomedAggregatorError: if no FederatedDataset is attached to
            this aggregator.
    """
    # Gather node ids from the attached FederatedDataset.
    if self._fds is None:
        raise FedbiomedAggregatorError(
            "Cannot initialize correction states: Scaffold aggregator does "
            "not have a FederatedDataset attached."
        )
    node_ids = self._fds.node_ids()
    # Initialize nodes states with zero scalars, that will be summed into actual tensors.
    init_params = {key: initialize(tensor)[1] for key, tensor in global_model.items()}
    self.nodes_deltas = {node_id: copy.deepcopy(init_params) for node_id in node_ids}
    self.nodes_states = copy.deepcopy(self.nodes_deltas)
    self.global_state = init_params
```
load_state_breakpoint
load_state_breakpoint(state=None)
Source code in fedbiomed/researcher/aggregators/scaffold.py
```python
def load_state_breakpoint(self, state: Dict[str, Any] = None):
    super().load_state_breakpoint(state)
    self.server_lr = self._aggregator_args['server_lr']

    # Load the global state from file.
    global_state_filename = self._aggregator_args['global_state_filename']
    self.global_state = Serializer.load(global_state_filename)

    for node_id in self._aggregator_args['nodes']:
        self.nodes_deltas[node_id] = self._aggregator_args[node_id]['aggregator_correction']
    self.nodes_states = copy.deepcopy(self.nodes_deltas)
```
save_state_breakpoint
save_state_breakpoint(breakpoint_path, global_model)
Source code in fedbiomed/researcher/aggregators/scaffold.py
```python
def save_state_breakpoint(
    self,
    breakpoint_path: str,
    global_model: Mapping[str, Union[torch.Tensor, np.ndarray]]
) -> Dict[str, Any]:
    # Add aggregator parameters to the breakpoint that won't be sent to nodes.
    self._aggregator_args['server_lr'] = self.server_lr

    # Save the global state variable into a file.
    filename = os.path.join(breakpoint_path, f"global_state_{uuid.uuid4()}.mpk")
    Serializer.dump(self.global_state, filename)
    self._aggregator_args['global_state_filename'] = filename
    self._aggregator_args["nodes"] = self._fds.node_ids()

    # Add aggregator parameters that will be sent to nodes afterwards.
    return super().save_state_breakpoint(
        breakpoint_path, global_model=global_model, node_ids=self._fds.node_ids()
    )
```
set_nodes_learning_rate_after_training
set_nodes_learning_rate_after_training(training_plan, training_replies)
Retrieves the learning rate of the optimizer from each node (useful if a learning rate scheduler is used).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
training_plan | BaseTrainingPlan | training plan instance | required |
training_replies | Dict | training replies that must contain an `optimizer_args` entry and a learning rate | required |
Raises:
Type | Description |
---|---|
FedbiomedAggregatorError | raised when setting learning rate has been unsuccessful |
Returns:
Type | Description |
---|---|
Dict[str, List[float]] | dictionary mapping each node_id to a list of floats, one per layer of the model (in PyTorch, each layer can have a specific learning rate). |
Source code in fedbiomed/researcher/aggregators/scaffold.py
```python
def set_nodes_learning_rate_after_training(
    self,
    training_plan: BaseTrainingPlan,
    training_replies: Dict,
) -> Dict[str, List[float]]:
    """Retrieves the learning rate of the optimizer from each node (useful if a learning
    rate scheduler is used).

    Args:
        training_plan: training plan instance
        training_replies: training replies that must contain an `optimizer_args`
            entry and a learning rate

    Raises:
        FedbiomedAggregatorError: raised when setting the learning rate has been unsuccessful

    Returns:
        Dict[str, List[float]]: dictionary mapping each node_id to a list of floats, one per
            layer of the model (in PyTorch, each layer can have a specific learning rate).
    """
    n_model_layers = len(training_plan.get_model_params(
        only_trainable=False,
        exclude_buffers=True)
    )
    for node_id in self._fds.node_ids():
        lrs: Dict[str, float] = {}

        node = training_replies.get(node_id, None)
        if node is not None:
            lrs = training_replies[node_id]["optimizer_args"].get('lr')
        if node is None or lrs is None:
            # Fall back to the default value if no lr information was provided.
            lrs = training_plan.optimizer().get_learning_rate()

        if len(lrs) != n_model_layers:
            raise FedbiomedAggregatorError(
                "Error when setting node learning rate for SCAFFOLD: cannot extract node learning rate."
            )
        self.nodes_lr[node_id] = lrs
    return self.nodes_lr
```
set_training_plan_type
set_training_plan_type(training_plan_type)
Overrides `set_training_plan_type` from the parent class. Checks the training plan type; if it is SKLearnTrainingPlan, raises an error, otherwise calls the parent method.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
training_plan_type | TrainingPlans | training_plan type | required |
Raises:
Type | Description |
---|---|
FedbiomedAggregatorError | raised if training_plan type has been set to SKLearn training plan |
Returns:
Name | Type | Description |
---|---|---|
TrainingPlans | TrainingPlans | training plan type |
Source code in fedbiomed/researcher/aggregators/scaffold.py
```python
def set_training_plan_type(self, training_plan_type: TrainingPlans) -> TrainingPlans:
    """Overrides `set_training_plan_type` from the parent class.

    Checks the training plan type; if it is SKLearnTrainingPlan, raises an error,
    otherwise calls the parent method.

    Args:
        training_plan_type (TrainingPlans): training plan type

    Raises:
        FedbiomedAggregatorError: raised if the training plan type has been set to a SKLearn training plan

    Returns:
        TrainingPlans: training plan type
    """
    if training_plan_type == TrainingPlans.SkLearnTrainingPlan:
        raise FedbiomedAggregatorError("Aggregator SCAFFOLD not implemented for SKlearn")
    training_plan_type = super().set_training_plan_type(training_plan_type)
    # TODO: trigger a warning if user is trying to use scaffold with something else than SGD
    return training_plan_type
```
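A hedged sketch of the guard; the `TrainingPlans` enum is assumed to live in `fedbiomed.common.constants`:

```python
from fedbiomed.common.constants import TrainingPlans  # assumed path
from fedbiomed.researcher.aggregators.scaffold import Scaffold

scaffold = Scaffold()
scaffold.set_training_plan_type(TrainingPlans.TorchTrainingPlan)      # accepted
# scaffold.set_training_plan_type(TrainingPlans.SkLearnTrainingPlan)  # raises FedbiomedAggregatorError
```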
update_correction_states
update_correction_states(model_updates, n_updates)
Update all Scaffold state variables based on node-wise model updates.
Performed computations:
- Update participating nodes' state:
- c_i = delta_i + 1/(K*eta_i) * update_i
- Update the global state and all nodes' correction state:
- c = 1/N sum_{i=1}^n c_i
- delta_i = (c_i - c)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_updates | Dict[str, Dict[str, Union[ndarray, Tensor]]] | node-wise model weight updates. | required |
n_updates | int | number of local optimization steps. | required |
Source code in fedbiomed/researcher/aggregators/scaffold.py
```python
def update_correction_states(
    self,
    model_updates: Dict[str, Dict[str, Union[np.ndarray, torch.Tensor]]],
    n_updates: int,
) -> None:
    """Update all Scaffold state variables based on node-wise model updates.

    Performed computations:
    ----------------------
    - Update participating nodes' state:
        * c_i = delta_i + 1/(K * eta_i) * update_i
    - Update the global state and all nodes' correction state:
        * c = 1/N * sum_{i=1}^N c_i
        * delta_i = (c_i - c)

    Args:
        model_updates: node-wise model weight updates.
        n_updates: number of local optimization steps.
    """
    # Update the node-wise states for participating nodes:
    # c_i^{t+1} = delta_i^t + (x^t - y_i^t) / (M * eta)
    for node_id, updates in model_updates.items():
        d_i = self.nodes_deltas[node_id]
        for key, val in updates.items():
            if self.nodes_lr[node_id].get(key) is not None:
                self.nodes_states[node_id].update(
                    {key: d_i[key] + val / (self.nodes_lr[node_id][key] * n_updates)}
                )
    # Update the global state: c^{t+1} = average(c_i^{t+1})
    for key in self.global_state:
        self.global_state[key] = 0
        for state in self.nodes_states.values():
            if state.get(key) is not None:
                self.global_state[key] = (
                    sum(state[key] for state in self.nodes_states.values())
                    / len(self.nodes_states)
                )
    # Compute the new node-wise correction states:
    # delta_i^{t+1} = c_i^{t+1} - c^{t+1}
    self.nodes_deltas = {
        node_id: {
            key: val - self.global_state[key] for key, val in state.items()
        }
        for node_id, state in self.nodes_states.items()
    }
```
Functions
federated_averaging
federated_averaging(model_params, weights)
Defines Federated Averaging (FedAvg) strategy for model aggregation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_params | List[Dict[str, Union[Tensor, ndarray]]] | list that contains nodes' model parameters; each model is stored as an OrderedDict (maps model layer name to the model weights) | required |
weights | List[float] | weights for performing weighted sum in FedAvg strategy (depending on the dataset size of each node). Items in the list must always sum up to 1 | required |
Returns:
Type | Description |
---|---|
Mapping[str, Union[Tensor, ndarray]] | Final model with aggregated layers, as an OrderedDict object. |
Source code in fedbiomed/researcher/aggregators/functional.py
```python
def federated_averaging(model_params: List[Dict[str, Union[torch.Tensor, np.ndarray]]],
                        weights: List[float]) -> Mapping[str, Union[torch.Tensor, np.ndarray]]:
    """Defines Federated Averaging (FedAvg) strategy for model aggregation.

    Args:
        model_params: list that contains nodes' model parameters; each model is stored as an
            OrderedDict (maps model layer name to the model weights)
        weights: weights for performing the weighted sum in the FedAvg strategy (depending on the
            dataset size of each node). Items in the list must always sum up to 1

    Returns:
        Final model with aggregated layers, as an OrderedDict object.
    """
    assert len(model_params) > 0, 'An empty list of models was passed.'
    assert len(weights) == len(model_params), 'List with number of observations must have ' \
        'the same number of elements as the list of models.'

    # Compute proportions
    proportions = [n_k / sum(weights) for n_k in weights]
    return weighted_sum(model_params, proportions)
```
initialize
initialize(val)
Initialize tensor or array vector.
Source code in fedbiomed/researcher/aggregators/functional.py
```python
def initialize(val: Union[torch.Tensor, np.ndarray]) -> Tuple[str, Union[torch.Tensor, np.ndarray]]:
    """Initialize a tensor or array vector."""
    if isinstance(val, torch.Tensor):
        return 'tensor', torch.zeros_like(val).float()
    if isinstance(val, (list, np.ndarray)):
        val = np.array(val)
        return 'array', np.zeros(val.shape, dtype=float)
```
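A quick illustration of the returned `(kind, zeros)` pair, assuming Fed-BioMed is installed:

```python
import numpy as np
import torch

from fedbiomed.researcher.aggregators.functional import initialize

print(initialize(torch.ones(2)))  # ('tensor', tensor([0., 0.]))
print(initialize(np.ones(3)))     # ('array', array([0., 0., 0.]))
```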
weighted_sum
weighted_sum(model_params, proportions)
Performs weighted sum operation
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_params | List[Dict[str, Union[Tensor, ndarray]]] | list that contains nodes' model parameters; each model is stored as an OrderedDict (maps model layer name to the model weights) | required |
proportions | List[float] | weights of all items within the model_params list | required |
Returns:
Type | Description |
---|---|
Mapping[str, Union[Tensor, ndarray]] | model resulting from the weighted sum operation |
Source code in fedbiomed/researcher/aggregators/functional.py
```python
def weighted_sum(model_params: List[Dict[str, Union[torch.Tensor, np.ndarray]]],
                 proportions: List[float]) -> Mapping[str, Union[torch.Tensor, np.ndarray]]:
    """Performs a weighted sum operation.

    Args:
        model_params (List[Dict[str, Union[torch.Tensor, np.ndarray]]]): list that contains nodes'
            model parameters; each model is stored as an OrderedDict (maps model layer name to the
            model weights)
        proportions (List[float]): weights of all items within the model_params list

    Returns:
        Mapping[str, Union[torch.Tensor, np.ndarray]]: model resulting from the weighted sum
            operation
    """
    # Start from a zero-initialized copy of the first model's parameter dictionary.
    avg_params = copy.deepcopy(model_params[0])
    for key, val in avg_params.items():
        (t, avg_params[key]) = initialize(val)
    if t == 'tensor':
        for model, weight in zip(model_params, proportions):
            for key in avg_params.keys():
                avg_params[key] += weight * model[key]
    if t == 'array':
        for key in avg_params.keys():
            matr = np.array([d[key] for d in model_params])
            avg_params[key] = np.average(matr, weights=np.array(proportions), axis=0)
    return avg_params
```
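A usage sketch with made-up values, assuming Fed-BioMed is installed:

```python
import numpy as np

from fedbiomed.researcher.aggregators.functional import weighted_sum

model_params = [{"w": np.array([0.0, 4.0])}, {"w": np.array([4.0, 8.0])}]
print(weighted_sum(model_params, [0.25, 0.75])["w"])  # [3. 7.]
```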