How to Create Your Custom PyTorch Training Plan¶
Fed-BioMed allows you to train a model without completely rewriting your PyTorch training plan class. Integrating your PyTorch model into Fed-BioMed only requires adding a few extra attributes and methods so that the model can be trained in a federated setting. In this tutorial, you will learn how to define your TrainingPlan (wrapping your model) in Fed-BioMed for the PyTorch framework.
Note: Before starting this tutorial, we highly recommend that you follow the previous tutorials to understand the basics of Fed-BioMed.
In this tutorial, we will use the CelebA (CelebFaces) dataset to train the model. You can see details of the dataset here. The following sections give instructions for downloading and configuring the CelebA dataset for the Fed-BioMed framework.
1. Fed-BioMed Training Plan¶
In this section, you will learn how to write your custom training plan.
What is a Training Plan?¶
The training plan is the class where all the methods and attributes needed to train your model on the nodes are defined. Each training plan should inherit from the base training plan class that Fed-BioMed provides for the corresponding ML framework. For more details, you can visit the documentation for training plans. The following code snippet shows a basic training plan skeleton that can be defined in Fed-BioMed for the PyTorch framework.
from fedbiomed.common.training_plans import TorchTrainingPlan


class CustomTrainingPlan(TorchTrainingPlan):
    def init_model(self, model_args):
        # Define here your model
        # ...
        return

    def init_dependencies(self):
        # Add here the dependencies / third party libraries to be loaded
        # ...
        return

    def init_optimizer(self, optimizer_args):
        # Define here your optimizer
        # ...
        return

    def training_data(self, batch_size=48):
        # Define here how data are processed before feeding it to the model
        # ...
        return

    def training_step(self, data, target):
        # Define here the loss function
        # ...
        return
init_model Method of Training Plan¶
The init_model method of the training plan is where you initialize your neural network module, as in a classical PyTorch model class. The network should be defined inside the training plan class, and init_model should instantiate this network (Module) and return it.
In this tutorial, we will train a classification model on the CelebA image dataset that predicts whether a given face is smiling.
def init_model(self, model_args: dict = {}):
    return self.Net(model_args)

class Net(nn.Module):
    def __init__(self, model_args):
        super().__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1)
        self.conv3 = nn.Conv2d(32, 32, 3, 1)
        self.conv4 = nn.Conv2d(32, 32, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # Classifier (3168 = 32 channels * 11 * 9 for 218x178 input images)
        self.fc1 = nn.Linear(3168, 128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(x)

        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(x)

        x = self.conv3(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(x)

        x = self.conv4(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(x)

        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        # Log-probabilities over the two classes
        output = F.log_softmax(x, dim=1)
        return output
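The input size of the first linear layer (3168) depends on the image size and on the four convolution and pooling stages. The sketch below redoes that arithmetic in plain Python, without PyTorch. It assumes that the aligned CelebA images are 218 pixels high and 178 pixels wide, that the convolutions use no padding, and that pooling rounds down; the helper names are hypothetical and only serve this illustration.

```python
def conv_out(size: int, kernel: int = 3, stride: int = 1) -> int:
    """Output length of an unpadded convolution along one dimension."""
    return (size - kernel) // stride + 1


def pool_out(size: int, window: int = 2) -> int:
    """Output length of max pooling whose stride equals its window (rounded down)."""
    return size // window


def flattened_features(height: int, width: int, channels: int = 32, blocks: int = 4) -> int:
    """Number of values reaching the first linear layer after `blocks` conv+pool stages."""
    for _ in range(blocks):
        height = pool_out(conv_out(height))
        width = pool_out(conv_out(width))
    return channels * height * width


# Assumption: input images are 218 pixels high and 178 pixels wide.
print(flattened_features(218, 178))  # 3168
```

If you resize the images or change the number of convolutional stages, the first argument of fc1 has to be recomputed in the same way.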
init_dependencies Method¶
Next, you should define the init_dependencies method to declare the modules that are used in the training plan. The modules should be supported by Fed-BioMed.
def init_dependencies(self):
    # Here we define the custom dependencies that will be needed by our custom Dataloader
    deps = ["from torch.utils.data import Dataset, DataLoader",
            "from torchvision import transforms",
            "import pandas as pd",
            "from PIL import Image",
            "import os",
            "import numpy as np"]
    return deps
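Because the dependencies are plain strings, a typo in one of them is easy to miss until the training plan is loaded on a node. The following sketch is a hypothetical local check, not part of Fed-BioMed: it uses Python's standard ast module to confirm that every entry parses as a single import statement. It does not import anything and does not tell you whether Fed-BioMed accepts the module.

```python
import ast

deps = ["from torch.utils.data import Dataset, DataLoader",
        "from torchvision import transforms",
        "import pandas as pd",
        "from PIL import Image",
        "import os",
        "import numpy as np"]


def is_import_statement(line: str) -> bool:
    """Return True if `line` parses as exactly one `import` or `from ... import` statement."""
    try:
        tree = ast.parse(line)
    except SyntaxError:
        return False
    return len(tree.body) == 1 and isinstance(tree.body[0], (ast.Import, ast.ImportFrom))


# Collect the entries that are not valid import statements.
invalid = [dep for dep in deps if not is_import_statement(dep)]
print(invalid)  # []
```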
init_optimizer Method¶
To optimize your model, you will need an optimizer. This is where the init_optimizer method comes into play. In this method, you may choose the optimizer you want to use, add a PyTorch learning rate Scheduler, or provide your custom optimizer.
init_optimizer takes optimizer_args as argument, an entry from training_args (more details later), which is a dictionary containing parameters that may be needed for optimizer initialization (such as the learning rate, Adagrad weight decay, Adam beta parameters, ...). The init_optimizer method should return the initialized optimizer, which will be used to optimize the model.
Defining an optimizer in Fed-BioMed is very similar to PyTorch, as shown in the example below (using PyTorch's SGD optimizer):
def init_optimizer(self, optimizer_args):
    return torch.optim.SGD(self.model().parameters(), lr=optimizer_args['lr'])
By default (if this method is not specified in the TrainingPlan), the model will be optimized using the default Adam optimizer.
training_data() and Custom Dataset¶
training_data is the method where the data is loaded for training on the node side. During each round of training, each node that participates in federated training builds the model, loads the dataset using the training_data method, and performs the training_step on the loaded data.
The dataset that we will be using in this tutorial is an image dataset. Therefore, your custom PyTorch Dataset should be able to load images by a given index. Please see the details of custom PyTorch datasets.
class CelebaDataset(Dataset):
    """Custom Dataset for loading CelebA face images"""

    def __init__(self, txt_path, img_dir, transform=None):
        # Read the csv file that includes classes for each image
        df = pd.read_csv(txt_path, sep="\t", index_col=0)
        self.img_dir = img_dir
        self.txt_path = txt_path
        self.img_names = df.index.values
        self.y = df['Smiling'].values
        self.transform = transform

    def __getitem__(self, index):
        # Load a single image from disk and convert it to a tensor
        img = np.asarray(Image.open(os.path.join(self.img_dir, self.img_names[index])))
        img = transforms.ToTensor()(img)
        label = self.y[index]
        return img, label

    def __len__(self):
        return self.y.shape[0]
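A map-style dataset only needs two things: a length and a way to return the item at a given index. The sketch below shows that contract without PyTorch, pandas, or any image files. The file content is a hypothetical three-row excerpt laid out like the tab-separated target.csv used above (image name in the first column, a Smiling column for the label), and the LabelIndex class is an illustration only; it returns the image name instead of the loaded image.

```python
import csv
import io

# Hypothetical excerpt of a target.csv file: tab-separated, image name in the first column.
TARGET_TSV = "\tSmiling\n000001.jpg\t1\n000002.jpg\t0\n000003.jpg\t1\n"


class LabelIndex:
    """Map-style container: index -> (image file name, label)."""

    def __init__(self, text: str):
        rows = list(csv.reader(io.StringIO(text), delimiter="\t"))
        header, body = rows[0], rows[1:]
        column = header.index("Smiling")
        self.img_names = [row[0] for row in body]
        self.y = [int(row[column]) for row in body]

    def __getitem__(self, index: int):
        # A real Dataset would open the image here instead of returning its name.
        return self.img_names[index], self.y[index]

    def __len__(self) -> int:
        return len(self.y)


labels = LabelIndex(TARGET_TSV)
print(len(labels), labels[1])  # 3 ('000002.jpg', 0)
```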
Now, you need to define a training_data method that will create a Fed-BioMed DataManager using the custom CelebaDataset class.
def training_data(self):
    # The training_data creates the dataset and returns DataManager to be used for training in the general class Torchnn of Fed-BioMed
    dataset = self.CelebaDataset(self.dataset_path + "/target.csv", self.dataset_path + "/data/")
    loader_arguments = {'shuffle': True}
    return DataManager(dataset, **loader_arguments)
training_step()¶
The last method that needs to be defined is training_step. This method is responsible for executing the forward pass and calculating the loss value for the backward pass of the network. To access the forward method of the torch.nn.Module that is defined in init_model, the getter method model() of the training plan class should be used.
def training_step(self, data, target):
    output = self.model().forward(data)
    loss = torch.nn.functional.nll_loss(output, target)
    return loss
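The network ends with log_softmax and the training step uses nll_loss, and the two belong together: the loss for one sample is simply the negated log-probability that the model assigns to the correct class. The following plain-Python sketch reproduces that computation for a single sample. The logits are made-up values and the two helper functions are illustrative stand-ins, not the PyTorch implementations (which also average over the batch by default).

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of a list of raw scores."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return [v - log_norm for v in logits]


def nll_loss(log_probs, target):
    """Negative log-likelihood of the target class for one sample."""
    return -log_probs[target]


# Hypothetical raw scores for class 0 (not smiling) and class 1 (smiling).
logits = [2.0, 0.5]
loss_right = nll_loss(log_softmax(logits), target=0)  # model favours the true class
loss_wrong = nll_loss(log_softmax(logits), target=1)  # model favours the other class
print(loss_right < loss_wrong)  # True
```

The loss is small when the model puts most of its probability on the true class and grows as that probability shrinks.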
You are now ready to create your training plan class. All you need to do is place every method explained in the previous sections in your training plan class. In the next steps, we will:
- download the CelebA dataset and deploy it on the nodes
- define our complete training plan
- create an experiment and run it
- evaluate our model using a testing dataset
2. Configuring Nodes¶
We will be working with the CelebA (CelebFaces) dataset. Therefore, please visit here and download the files img/img_align_celeba.zip and Anno/list_attr_celeba.txt. After the download is completed:
- Go to ./notebooks/data/Celeba in the Fed-BioMed project.
- Create the Celeba_raw/raw directory and copy the list_attr_celeba.txt file into it.
- Extract the zip file img_align_celeba.zip.
Your folder should look the same as the tree below:
Celeba
    README.md
    create_node_data.py
    .gitignore
    Celeba_raw
        raw
            list_attr_celeba.txt
            img_align_celeba.zip
            img_align_celeba
                lots of images
The dataset has to be processed and split to create three distinct datasets for Node 1, Node 2, and Node 3. You can do this by running the following script in your notebook. If you are working in a directory other than fedbiomed/notebooks, please make sure that you define the correct paths in the following example.
Running the following script might take some time, so please be patient.
import os
import pandas as pd
import shutil
from fedbiomed.researcher.environ import environ

# Celeba folder
parent_dir = os.path.join(environ["ROOT_DIR"], "notebooks", "data", "Celeba")
celeba_raw_folder = os.path.join("Celeba_raw", "raw")
img_dir = os.path.join(parent_dir, celeba_raw_folder, 'img_align_celeba') + os.sep
out_dir = os.path.join(parent_dir, "celeba_preprocessed")

# Read the attribute file and only load the Smiling column
df = pd.read_csv(os.path.join(parent_dir, celeba_raw_folder, 'list_attr_celeba.txt'),
                 sep=r"\s+", skiprows=1, usecols=['Smiling'])

# Labels are 1 if the person is smiling and -1 otherwise; we set all -1 to 0
# so that the labels are valid class indices (0 or 1) for the loss function
df.loc[df['Smiling'] == -1, 'Smiling'] = 0

# Split the table into 3 parts
length = len(df)
data_node_1 = df.iloc[:int(length/3)]
data_node_2 = df.iloc[int(length/3):int(length/3) * 2]
data_node_3 = df.iloc[int(length/3) * 2:]

# Create a folder for each node
if not os.path.exists(os.path.join(out_dir, "data_node_1")):
    os.makedirs(os.path.join(out_dir, "data_node_1", "data"))
if not os.path.exists(os.path.join(out_dir, "data_node_2")):
    os.makedirs(os.path.join(out_dir, "data_node_2", "data"))
if not os.path.exists(os.path.join(out_dir, "data_node_3")):
    os.makedirs(os.path.join(out_dir, "data_node_3", "data"))

# Save each node's target CSV to the correct folder
data_node_1.to_csv(os.path.join(out_dir, 'data_node_1', 'target.csv'), sep='\t')
data_node_2.to_csv(os.path.join(out_dir, 'data_node_2', 'target.csv'), sep='\t')
data_node_3.to_csv(os.path.join(out_dir, 'data_node_3', 'target.csv'), sep='\t')

# Copy all images of each node into the correct folder
for im in data_node_1.index:
    shutil.copy(img_dir+im, os.path.join(out_dir, "data_node_1", "data", im))
print("data for node 1 successfully created")

for im in data_node_2.index:
    shutil.copy(img_dir+im, os.path.join(out_dir, "data_node_2", "data", im))
print("data for node 2 successfully created")

for im in data_node_3.index:
    shutil.copy(img_dir+im, os.path.join(out_dir, "data_node_3", "data", im))
print("data for node 3 successfully created")
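The script above cuts the table at int(length/3) and twice that value. When the number of rows is not divisible by 3, the leftover rows end up in the third part. The small sketch below reproduces only the index arithmetic, using a hypothetical row count of 10 and a helper name chosen for this illustration, so that you can check the boundaries without touching the dataset.

```python
def three_way_bounds(length: int):
    """Return (start, stop) row ranges matching the slicing used for the three nodes."""
    third = int(length / 3)
    return [(0, third), (third, third * 2), (third * 2, length)]


# Hypothetical row count that is not divisible by 3.
for node, (start, stop) in enumerate(three_way_bounds(10), start=1):
    print(f"node {node}: rows {start}..{stop - 1} ({stop - start} rows)")
# node 1: rows 0..2 (3 rows)
# node 2: rows 3..5 (3 rows)
# node 3: rows 6..9 (4 rows)
```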
Now, if you go to the ${FEDBIOMED_DIR}/notebooks/data/Celeba directory, you can see the folder called celeba_preprocessed. It contains three folders, each holding the image dataset for one of the 3 nodes. The next step is to configure the nodes and deploy the datasets. In the next steps, we will configure only two nodes. The dataset for the third node is going to be used for testing.
Create 2 nodes for training :
${FEDBIOMED_DIR}/scripts/fedbiomed_run node --config node1.ini start
${FEDBIOMED_DIR}/scripts/fedbiomed_run node --config node2.ini start
Add data to each node :
${FEDBIOMED_DIR}/scripts/fedbiomed_run node --config node1.ini dataset add
${FEDBIOMED_DIR}/scripts/fedbiomed_run node --config node2.ini dataset add
Note: ${FEDBIOMED_DIR} is the path to the base directory of the cloned Fed-BioMed repository. You can set it by running the command export FEDBIOMED_DIR=/path/to/fedbiomed. This is not required for Fed-BioMed to work but enables you to run the tutorials more easily.
2.1. Configuration Steps¶
You first need to configure at least one node:
${FEDBIOMED_DIR}/scripts/fedbiomed_run node --config (ini file) dataset add
- Select option 4 (images) to add an image dataset to the node.
- Add a name and the tag for the dataset (the tag should contain '#celeba', as it is the tag used for this training) and finally add the description.
- Pick a data folder from the 3 generated datasets inside data/Celeba/celeba_preprocessed (e.g. data_node_1).
- The data should now be added (if you get a warning saying that data must be unique, it is because it has already been added).
- Check that your data has been added by executing ${FEDBIOMED_DIR}/scripts/fedbiomed_run node --config (ini file) dataset list
- Run the node using ${FEDBIOMED_DIR}/scripts/fedbiomed_run node --config <ini file> start. Wait until you get Starting task manager; it means you are online.
After the steps above are completed, you will be ready to train your classification model on two different nodes.
3. Defining Custom PyTorch Model and Training Plan¶
The next step is to create our Net class based on the methods that have been explained in the previous sections. This class is part of the training plan that will be passed to the Experiment. Afterwards, the nodes will receive the training plan and perform the training by retrieving training data and passing it to training_step.
import torch
import torch.nn as nn
from fedbiomed.common.training_plans import TorchTrainingPlan
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset
from fedbiomed.common.data import DataManager
import pandas as pd
import numpy as np
from PIL import Image
import os


class CelebaTrainingPlan(TorchTrainingPlan):

    # Defines model
    def init_model(self):
        model = self.Net()
        return model

    # Here we define the custom dependencies that will be needed by our custom Dataloader
    def init_dependencies(self):
        deps = ["from torch.utils.data import Dataset",
                "from torchvision import transforms",
                "import pandas as pd",
                "from PIL import Image",
                "import os",
                "import numpy as np"]
        return deps

    # Torch modules class
    class Net(nn.Module):

        def __init__(self):
            super().__init__()
            # convolution layers
            self.conv1 = nn.Conv2d(3, 32, 3, 1)
            self.conv2 = nn.Conv2d(32, 32, 3, 1)
            self.conv3 = nn.Conv2d(32, 32, 3, 1)
            self.conv4 = nn.Conv2d(32, 32, 3, 1)
            self.dropout1 = nn.Dropout(0.25)
            self.dropout2 = nn.Dropout(0.5)
            # classifier
            self.fc1 = nn.Linear(3168, 128)
            self.fc2 = nn.Linear(128, 2)

        def forward(self, x):
            x = self.conv1(x)
            x = F.max_pool2d(x, 2)
            x = F.relu(x)

            x = self.conv2(x)
            x = F.max_pool2d(x, 2)
            x = F.relu(x)

            x = self.conv3(x)
            x = F.max_pool2d(x, 2)
            x = F.relu(x)

            x = self.conv4(x)
            x = F.max_pool2d(x, 2)
            x = F.relu(x)

            x = self.dropout1(x)
            x = torch.flatten(x, 1)
            x = self.fc1(x)
            x = F.relu(x)
            x = self.dropout2(x)
            x = self.fc2(x)
            output = F.log_softmax(x, dim=1)
            return output

    class CelebaDataset(Dataset):
        """Custom Dataset for loading CelebA face images"""

        # We don't load the full data of the images; we retrieve each image in __getitem__.
        # In our case, each image is 218*178 * 3 colors and there are 67533 images, which takes at least 7 GB of RAM.
        # Loading images when needed takes more time during training, but it won't impact
        # the RAM usage as much as loading everything.
        def __init__(self, txt_path, img_dir, transform=None):
            df = pd.read_csv(txt_path, sep="\t", index_col=0)
            self.img_dir = img_dir
            self.txt_path = txt_path
            self.img_names = df.index.values
            self.y = df['Smiling'].values
            self.transform = transform
            print("celeba dataset finished")

        def __getitem__(self, index):
            img = np.asarray(Image.open(os.path.join(self.img_dir,
                                                     self.img_names[index])))
            img = transforms.ToTensor()(img)
            label = self.y[index]
            return img, label

        def __len__(self):
            return self.y.shape[0]

    # The training_data creates the Dataloader to be used for training in the
    # general class Torchnn of fedbiomed
    def training_data(self):
        dataset = self.CelebaDataset(self.dataset_path + "/target.csv", self.dataset_path + "/data/")
        loader_arguments = {'shuffle': True}
        return DataManager(dataset, **loader_arguments)

    # This function must return the loss so that it can be backpropagated
    def training_step(self, data, target):
        output = self.model().forward(data)
        loss = torch.nn.functional.nll_loss(output, target)
        return loss
This group of arguments corresponds respectively to:
- model_args: a dictionary with the arguments related to the model (e.g. number of layers, features, etc.). This will be passed to the model class on the node side.
- training_args: a dictionary containing the arguments for the training routine (e.g. batch size, learning rate, epochs, etc.). This will be passed to the routine on the node side.
Note: Typos and/or lack of positional (required) arguments might raise an error.
training_args = {
    'loader_args': {'batch_size': 32},
    'optimizer_args': {
        'lr': 1e-3
    },
    'epochs': 1,
    'dry_run': False,
    'batch_maxnum': 100  # Fast pass for development: only use (batch_maxnum * batch_size) samples
}
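To get a feeling for how much data the fast pass touches, you can compute the cap mentioned in the comment on batch_maxnum. The sketch below only does the multiplication described there. The function name and the example dataset size are hypothetical, and how Fed-BioMed applies the cap internally is not covered here.

```python
training_args = {
    'loader_args': {'batch_size': 32},
    'optimizer_args': {'lr': 1e-3},
    'epochs': 1,
    'dry_run': False,
    'batch_maxnum': 100,
}


def fast_pass_sample_cap(args: dict, dataset_size: int) -> int:
    """Upper bound on samples used: batch_maxnum * batch_size, limited by the dataset size."""
    cap = args['batch_maxnum'] * args['loader_args']['batch_size']
    return min(cap, dataset_size)


# Hypothetical dataset size for one node.
print(fast_pass_sample_cap(training_args, dataset_size=67533))  # 3200
```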
4. Training Federated Model¶
To orchestrate training over two nodes, we need to define an experiment which:
- searches for nodes serving data for the tags,
- defines the local training on nodes with the training plan class given in training_plan_class, and federates all local updates at each round with the aggregator,
- runs training for round_limit rounds.
You can visit the user guide to learn more about experiments.
from fedbiomed.researcher.federated_workflows import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

tags = ['#celeba']
rounds = 3

exp = Experiment(tags=tags,
                 training_plan_class=CelebaTrainingPlan,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None)
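The aggregator combines the parameters returned by the nodes at the end of each round. The sketch below illustrates the general idea of federated averaging in plain Python: each parameter is averaged across nodes, weighted by the number of samples each node trained on. It is a simplified illustration under stated assumptions, not Fed-BioMed's FedAverage implementation, whose weighting and data structures may differ. The parameter name, values, and sample counts are made up.

```python
def federated_average(node_params, node_sample_counts):
    """Average per-node parameter dicts (name -> list of floats), weighted by sample count."""
    total = sum(node_sample_counts)
    averaged = {}
    for name in node_params[0]:
        size = len(node_params[0][name])
        averaged[name] = [
            sum(params[name][i] * count
                for params, count in zip(node_params, node_sample_counts)) / total
            for i in range(size)
        ]
    return averaged


# Hypothetical parameters reported by two nodes after one round.
node_1 = {"fc2.bias": [0.0, 2.0]}
node_2 = {"fc2.bias": [4.0, 6.0]}
print(federated_average([node_1, node_2], [100, 300]))  # {'fc2.bias': [3.0, 5.0]}
```

The node with more samples pulls the average towards its own values, which is why the result is closer to the second node's parameters.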
Let's start the experiment.
By default, exp.run() doesn't stop until all the round_limit rounds are done for all the nodes. While the experiment runs, you can open the terminals where you have started the nodes and see the training progress. In addition, the loss values obtained from each node during the training will be printed as output in real time. Since we are working on an image dataset, training might take some time.
exp.run()
Save the trained model to a file:
exp.training_plan().export_model('./trained_model')
Loading Training Parameters¶
After all the rounds have been completed, you retrieve the aggregated parameters from the last round and load them.
fed_model = exp.training_plan().model()
fed_model.load_state_dict(exp.aggregated_params()[rounds - 1]['params'])
5. Testing Federated Model¶
We will define a testing routine to extract the accuracy metrics on the testing dataset. We will use the dataset that has been extracted into data_node_3.
import torch
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
from PIL import Image
import os


def testing_Accuracy(model, data_loader):
    model.eval()
    test_loss = 0
    correct = 0
    n_samples = 0  # number of samples actually evaluated
    device = "cpu"

    loader_size = len(data_loader)
    with torch.no_grad():
        for idx, (data, target) in enumerate(data_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            n_samples += target.size(0)

            # only uses about 10% of the dataset, results are similar but faster
            if idx >= loader_size / 10:
                break

    # normalize by the number of evaluated samples, not by the full dataset size
    test_loss /= n_samples
    accuracy = 100 * correct / n_samples

    return (test_loss, accuracy)
We also need to define a custom Dataset class for the test dataset in order to load it using PyTorch's DataLoader. This will be the same class that has already been defined in the training plan.
from fedbiomed.researcher.environ import environ
from torch.utils.data import DataLoader

test_dataset_path = os.path.join(environ["ROOT_DIR"],
                                 "notebooks",
                                 "data",
                                 "Celeba",
                                 "celeba_preprocessed",
                                 "data_node_3")


class CelebaDataset(Dataset):
    """Custom Dataset for loading CelebA face images"""

    def __init__(self, txt_path, img_dir, transform=None):
        df = pd.read_csv(txt_path, sep="\t", index_col=0)
        self.img_dir = img_dir
        self.txt_path = txt_path
        self.img_names = df.index.values
        self.y = df['Smiling'].values
        self.transform = transform
        print("celeba dataset finished")

    def __getitem__(self, index):
        img = np.asarray(Image.open(os.path.join(self.img_dir,
                                                 self.img_names[index])))
        img = transforms.ToTensor()(img)
        label = self.y[index]
        return img, label

    def __len__(self):
        return self.y.shape[0]


dataset = CelebaDataset(test_dataset_path + "/target.csv", test_dataset_path + "/data/")
train_kwargs = {'shuffle': True}
data_loader = DataLoader(dataset, **train_kwargs)

acc_federated = testing_Accuracy(fed_model, data_loader)
acc_federated[1]
Conclusions¶
In this tutorial, we explained how to run a custom model on Fed-BioMed (by wrapping it in a custom training plan) for the PyTorch framework. Because the examples are designed for the development environment, we have been running the nodes on the same host machine. In production, the nodes that you use to train your model will run on remote servers. Please check out how to deploy nodes in a production environment.