Provide a way to easily manage training arguments.
Attributes
DPArgsValidator module-attribute
DPArgsValidator = SchemeValidator({'type': {'rules': [str, _validate_dp_type], 'required': True, 'default': 'central'}, 'sigma': {'rules': [float], 'required': True}, 'clip': {'rules': [float], 'required': True}})
Classes
TrainingArgs
TrainingArgs(ta=None, extra_scheme=None, only_required=True)
Provide a container to manage training arguments.
This class uses the Validator and SchemeValidator classes and provides a default scheme, which describes the arguments necessary to train/validate a TrainingPlan.
It also permits extending the TrainingArgs for testing new features by supplying an extra_scheme at TrainingArgs instantiation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ta | Dict | dictionary describing the TrainingArgs scheme. if empty dict or None, a minimal instance of TrainingArgs will be initialized with default values for required keys | None |
extra_scheme | Dict | user provided scheme extension, which add new rules or update the scheme of the default training args. Warning: this is a dangerous feature, provided to developers, to ease the test of future Fed-Biomed features | None |
only_required | bool | if True, the object is initialized only with required values defined in the default_scheme (+ extra_scheme). If False, then all default values will also be returned (not only the required key/value pairs). | True |
Raises:
Type | Description |
---|---|
FedbiomedUserInputError | in case of bad value or bad extra_scheme |
Source code in fedbiomed/common/training_args.py
def __init__(self, ta: Dict = None, extra_scheme: Dict = None, only_required: bool = True):
    """
    Create a TrainingArgs from a Dict with input validation.

    Args:
        ta: dictionary describing the TrainingArgs scheme.
            if empty dict or None, a minimal instance of TrainingArgs
            will be initialized with default values for required keys
        extra_scheme: user provided scheme extension, which adds new rules or
            updates the scheme of the default training args.
            Warning: this is a dangerous feature, provided to
            developers, to ease the test of future Fed-Biomed features
        only_required: if True, the object is initialized only with required
            values defined in the default_scheme (+ extra_scheme).
            If False, then all default values will also be returned
            (not only the required key/value pairs).

    Raises:
        FedbiomedUserInputError: in case of bad value or bad extra_scheme
    """
    self._scheme = TrainingArgs.default_scheme()
    # NOTE(review): a non-dict extra_scheme (e.g. a list) is silently
    # ignored here instead of being rejected -- confirm this is intended.
    if not isinstance(extra_scheme, dict):
        extra_scheme = {}
    # extra_scheme entries override or extend the default scheme
    self._scheme.update(extra_scheme)

    try:
        self._sc = SchemeValidator(self._scheme)
    except RuleError as e:
        # internal error (invalid scheme)
        msg = ErrorNumbers.FB414.value + f": {e}"
        logger.critical(msg)
        raise FedbiomedUserInputError(msg) from e

    # scheme is validated from here
    if ta is None:
        ta = {}

    try:
        self._ta = self._sc.populate_with_defaults(ta, only_required=only_required)
    except ValidatorError as e:
        # scheme has required keys without defined default value
        msg = ErrorNumbers.FB414.value + f": {e}"
        logger.critical(msg)
        raise FedbiomedUserInputError(msg) from e

    try:
        self._sc.validate(self._ta)
    except ValidateError as e:
        # transform to a Fed-BioMed error
        msg = ErrorNumbers.FB414.value + f": {e}"
        logger.critical(msg)
        raise FedbiomedUserInputError(msg) from e

    # Validate DP arguments if present in the training arguments.
    # .get() avoids a raw KeyError in case an extra_scheme removed the
    # "dp_args" entry from the populated arguments.
    if self._ta.get("dp_args") is not None:
        try:
            self._ta["dp_args"] = DPArgsValidator.populate_with_defaults(
                self._ta["dp_args"], only_required=False)
            DPArgsValidator.validate(self._ta["dp_args"])
        except ValidateError as e:
            msg = f"{ErrorNumbers.FB414.value}: {e}"
            logger.critical(msg)
            raise FedbiomedUserInputError(msg) from e
Functions
default_scheme classmethod
default_scheme()
Returns the default (base) scheme for TrainingArgs.
A summary of the semantics of each argument is given below. Please refer to the source code of this function for additional information on typing and constraints.
argument | meaning |
---|---|
optimizer_args | supplemental arguments for initializing the optimizer |
loader_args | supplemental arguments passed to the data loader |
epochs | the number of epochs performed during local training on each node |
num_updates | the number of model updates performed during local training on each node. Supersedes epochs if both are specified |
use_gpu | toggle requesting the use of GPUs for local training on the node when available |
dry_run | perform a single model update for testing on each node and correctly handle GPU execution |
batch_maxnum | prematurely break after batch_maxnum model updates for each epoch (useful for testing) |
test_ratio | the proportion of validation samples to total number of samples in the dataset |
test_batch_size | batch size used for testing trained model wrt a set of metric |
test_on_local_updates | toggles validation after local training |
test_on_global_updates | toggles validation before local training |
test_metric | metric to be used for validation |
test_metric_args | supplemental arguments for the validation metric |
log_interval | output a training logging entry every log_interval model updates |
fedprox_mu | set the value of mu and enable FedProx correction |
dp_args | arguments for Differential Privacy |
share_persistent_buffers | toggle whether nodes share the full state_dict (when True) or only trainable parameters (False) in a TorchTrainingPlan |
random_seed | set random seed at the beginning of each round |
Source code in fedbiomed/common/training_args.py
@classmethod
def default_scheme(cls) -> Dict:
    """Return the default (base) validation scheme for TrainingArgs.

    The table below summarizes the semantics of each argument; see the
    scheme literal below for the exact typing rules and constraints.

    | argument | meaning |
    | -------- | ------- |
    | optimizer_args | supplemental arguments for initializing the optimizer |
    | loader_args | supplemental arguments passed to the data loader |
    | epochs | the number of epochs performed during local training on each node |
    | num_updates | the number of model updates performed during local training on each node. Supersedes epochs if both are specified |
    | use_gpu | toggle requesting the use of GPUs for local training on the node when available |
    | dry_run | perform a single model update for testing on each node and correctly handle GPU execution |
    | batch_maxnum | prematurely break after batch_maxnum model updates for each epoch (useful for testing) |
    | test_ratio | the proportion of validation samples to total number of samples in the dataset |
    | test_batch_size | batch size used for testing trained model wrt a set of metric |
    | test_on_local_updates | toggles validation after local training |
    | test_on_global_updates | toggles validation before local training |
    | test_metric | metric to be used for validation |
    | test_metric_args | supplemental arguments for the validation metric |
    | log_interval | output a training logging entry every log_interval model updates |
    | fedprox_mu | set the value of mu and enable FedProx correction |
    | dp_args | arguments for Differential Privacy |
    | share_persistent_buffers | toggle whether nodes share the full state_dict (when True) or only trainable parameters (False) in a TorchTrainingPlan |
    | random_seed | set random seed at the beginning of each round |
    """
    # Local aliases keep the scheme literal readable below.
    nonneg = cls._nonnegative_integer_value_validator_hook
    opt = cls.optional_type
    return {
        "optimizer_args": {"rules": [dict], "required": True, "default": {}},
        "loader_args": {"rules": [dict], "required": True, "default": {}},
        "epochs": {"rules": [nonneg("epochs")], "required": True, "default": None},
        "num_updates": {"rules": [nonneg("num_updates")], "required": True, "default": None},
        "dry_run": {"rules": [bool], "required": True, "default": False},
        "batch_maxnum": {"rules": [nonneg("batch_maxnum")], "required": True, "default": None},
        "test_ratio": {"rules": [float, cls._test_ratio_hook], "required": False, "default": 0.0},
        "test_batch_size": {
            "rules": [opt(typespec=int, argname="test_batch_size")],
            "required": False,
            "default": 0,
        },
        "test_on_local_updates": {"rules": [bool], "required": False, "default": False},
        "test_on_global_updates": {"rules": [bool], "required": False, "default": False},
        "test_metric": {"rules": [cls._metric_validation_hook], "required": False, "default": None},
        "test_metric_args": {"rules": [dict], "required": False, "default": {}},
        "log_interval": {"rules": [int], "required": False, "default": 10},
        "fedprox_mu": {"rules": [cls._fedprox_mu_validator], "required": False, "default": None},
        "use_gpu": {"rules": [bool], "required": False, "default": False},
        "dp_args": {"rules": [cls._validate_dp_args], "required": True, "default": None},
        "share_persistent_buffers": {"rules": [bool], "required": False, "default": True},
        "random_seed": {
            "rules": [opt(typespec=int, argname="random_seed")],
            "required": True,
            "default": None,
        },
    }
default_value
default_value(key)
Returns the default value for the key.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | str | key | required |
Returns:
Name | Type | Description |
---|---|---|
value | Any | the default value associated to the key |
Raises:
Type | Description |
---|---|
FedbiomedUserInputError | in case of problem (invalid key or value) |
Source code in fedbiomed/common/training_args.py
def default_value(self, key: str) -> Any:
    """
    Returns the default value for the key.

    Args:
        key: name of the scheme entry whose default is requested

    Returns:
        value: a deep copy of the default value associated to the key

    Raises:
        FedbiomedUserInputError: in case of problem (invalid key or value)
    """
    # Hoisted: the original called self._sc.scheme() three times,
    # re-fetching (and possibly re-copying) the scheme on each call.
    scheme = self._sc.scheme()
    if key not in scheme:
        msg = ErrorNumbers.FB410.value + \
            f"no such key: {key}"
        logger.critical(msg)
        raise FedbiomedUserInputError(msg)
    if "default" not in scheme[key]:
        msg = ErrorNumbers.FB410.value + \
            f"no default value defined for key: {key}"
        logger.critical(msg)
        raise FedbiomedUserInputError(msg)
    # deepcopy so callers cannot mutate the scheme's stored default
    return deepcopy(scheme[key]["default"])
dict
dict()
Returns a copy of the training_args as a dictionary.
Source code in fedbiomed/common/training_args.py
def dict(self) -> dict:
    """Return a deep copy of the training arguments as a plain dictionary."""
    # deepcopy prevents callers from mutating the internal state
    return deepcopy(self._ta)
dp_arguments
dp_arguments()
Extracts the arguments for differential privacy
Returns:
Type | Description |
---|---|
Contains differential privacy arguments |
Source code in fedbiomed/common/training_args.py
def dp_arguments(self):
    """Extracts the arguments for differential privacy.

    Returns:
        Contains differential privacy arguments
    """
    dp_args = self["dp_args"]
    return dp_args
get
get(key, default=None)
Mimics the get() method of dict, provided for backward compatibility.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | str | a key for retrieving data from the dictionary | required |
default | Any | default value to return if key does not belong to dictionary | None |
Source code in fedbiomed/common/training_args.py
def get(self, key: str, default: Any = None) -> Any:
    """Mimics the get() method of dict, provided for backward compatibility.

    Args:
        key: a key for retrieving data from the dictionary
        default: default value to return if key does not belong to dictionary
    """
    # TODO: test if provided default value is compliant with the scheme
    if key in self._ta:
        # deepcopy so callers cannot mutate the stored value
        return deepcopy(self._ta[key])
    return default
get_state_breakpoint
get_state_breakpoint()
Returns JSON serializable dict as state for breakpoints
Source code in fedbiomed/common/training_args.py
def get_state_breakpoint(self):
    """Returns JSON serializable dict as state for breakpoints"""
    # TODO: This method is a temporary solution for JSON
    # serialize error during breakpoint save operation
    state = self.dict()
    metric = state.get('test_metric')
    # MetricTypes members are not JSON serializable: keep only the name
    if metric and isinstance(metric, MetricTypes):
        state['test_metric'] = metric.name
    return state
load_state_breakpoint classmethod
load_state_breakpoint(state)
Loads training arguments state
Source code in fedbiomed/common/training_args.py
@classmethod
def load_state_breakpoint(cls, state: Dict) -> 'TrainingArgs':
    """Loads training arguments state"""
    metric_name = state.get('test_metric')
    if metric_name:
        # Restore the MetricTypes member from its serialized name
        state['test_metric'] = MetricTypes.get_metric_type_by_name(metric_name)
    return cls(state)
loader_arguments
loader_arguments()
Extracts data loader arguments
Returns:
Type | Description |
---|---|
Dict | The dictionary of arguments for dataloader |
Source code in fedbiomed/common/training_args.py
def loader_arguments(self) -> Dict:
    """Extracts data loader arguments.

    Returns:
        The dictionary of arguments for dataloader
    """
    loader_args = self["loader_args"]
    return loader_args
optimizer_arguments
optimizer_arguments()
Source code in fedbiomed/common/training_args.py
def optimizer_arguments(self) -> Dict:
    """Extracts the optimizer arguments.

    Returns:
        The dictionary of supplemental arguments for the optimizer.
    """
    optimizer_args = self["optimizer_args"]
    return optimizer_args
optional_type staticmethod
optional_type(typespec, argname)
Utility factory function to generate functions that check for an optional type(s).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
typespec | Union[Type, Tuple[Type, ...]] | type specification which will be passed to the isinstance function | required |
argname | str | the name of the training argument for outputting meaningful error messages | required |
Returns:
Name | Type | Description |
---|---|---|
type_check | a callable that takes a single argument and checks whether it is either None or the required type(s) |
Source code in fedbiomed/common/training_args.py
@staticmethod
def optional_type(typespec: Union[Type, Tuple[Type, ...]], argname: str):
    """Utility factory function to generate functions that check for an optional type(s).

    Args:
        typespec: type specification which will be passed to the `isinstance` function
        argname: the name of the training argument for outputting meaningful error messages

    Returns:
        type_check: a callable that takes a single argument and checks whether it is either None
            or the required type(s)
    """
    @validator_decorator
    def _check(value):
        # None is always accepted; otherwise the value must match typespec
        if value is None or isinstance(value, typespec):
            return True
        return False, f"Invalid type: {argname} must be {typespec} or None"
    return _check
pure_training_arguments
pure_training_arguments()
Extracts the arguments that are only necessary for training_routine
Returns:
Type | Description |
---|---|
Contains training argument for training routine |
Source code in fedbiomed/common/training_args.py
def pure_training_arguments(self):
    """Extracts the arguments that are only necessary for training_routine.

    Returns:
        Contains training argument for training routine
    """
    return self._extract_args([
        "batch_maxnum",
        "fedprox_mu",
        "log_interval",
        "share_persistent_buffers",
        "dry_run",
        "epochs",
        "use_gpu",
        "num_updates",
    ])
scheme
scheme()
Returns the scheme of a TrainingArgs instance.
The scheme is not necessarily the default_scheme (returned by TrainingArgs.default_scheme()).
Returns:
Name | Type | Description |
---|---|---|
scheme | Dict | the current scheme used for validation |
Source code in fedbiomed/common/training_args.py
def scheme(self) -> Dict:
    """
    Returns the scheme of a TrainingArgs instance.

    The scheme is not necessarily the default_scheme (returned by
    TrainingArgs.default_scheme()).

    Returns:
        scheme: a deep copy of the current scheme used for validation
    """
    current_scheme = deepcopy(self._scheme)
    return current_scheme
testing_arguments
testing_arguments()
Extract testing arguments from training arguments
Returns:
Type | Description |
---|---|
Dict | Testing arguments as dictionary |
Source code in fedbiomed/common/training_args.py
def testing_arguments(self) -> Dict:
""" Extract testing arguments from training arguments
Returns:
Testing arguments as dictionary
"""
keys = ['test_ratio', 'test_on_local_updates', 'test_on_global_updates',
'test_metric', 'test_metric_args', 'test_batch_size']
return self._extract_args(keys)
update
update(values)
Update multiple keys of the training arguments.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
values | Dict | a dictionary of (key, value) to validate/update | required |
Returns:
Type | Description |
---|---|
TypeVar(TrainingArgs) | the object itself after modification |
Raises:
Type | Description |
---|---|
FedbiomedUserInputError | in case of bad key or value in values |
Source code in fedbiomed/common/training_args.py
def update(self, values: Dict) -> TypeVar("TrainingArgs"):
    """
    Update multiple keys of the training arguments.

    Args:
        values: a dictionary of (key, value) to validate/update

    Returns:
        the object itself after modification

    Raises:
        FedbiomedUserInputError: in case of bad key or value in values
    """
    # Each assignment goes through __setitem__, which validates the value.
    for key, value in values.items():
        self[key] = value
    return self