Brain Segmentation¶
This tutorial will show how to use Fed-BioMed to perform image segmentation on 3D medical MRI images of brains, using the publicly available IXI dataset. It uses a 3D U-Net model for the segmentation, trained on data from 3 separate centers.
Here we present a complex use case that relies on advanced Fed-BioMed functionalities such as:
- loading a MedicalFolderDataset
- implementing a custom Node Selection Strategy
- setting a non-default Optimizer
- monitoring training loss with Tensorboard
This tutorial is based on TorchIO's tutorial.
Automatic download and wrangling for the impatient¶
If you're not interested in the details, you may simply execute the download_and_split_ixi.py script that we provide, as explained below:
mkdir -p ${FEDBIOMED_DIR}/notebooks/data
pip install tqdm
download_and_split_ixi.py -f ${FEDBIOMED_DIR}
After successfully running the command, follow the printed instructions to add the datasets and run the nodes. The tag used for this experiment is ixi-train.
Details about data preparation¶
If you just want to run the notebook, you may skip this section and go directly to Define a new strategy.
First, download the IXI dataset from the Mendeley archive.
In this tutorial we are going to use the MedicalFolderDataset class provided by the Fed-BioMed library to load medical images in NIFTI format. Using this dataset class for image segmentation problems guarantees maximum compatibility with the rest of the Fed-BioMed functionalities and features.
Folder structure for MedicalFolderDataset¶
The MedicalFolderDataset is heavily inspired by PyTorch's ImageFolder Dataset, and requires you to manually prepare the image folders so that they respect a precise structure. The format assumes that you are dealing with imaging data, possibly acquired through multiple modalities, for different study subjects. Hence, you should provide one folder per subject, containing one subfolder for each image acquisition modality. Optionally, you may provide a csv file containing additional tabular data associated with each subject. This file is typically used for demographics data, and by default is called participants.csv.
_ root-folder
|_ participants.csv
|_ subject-1
| |_ modality-1
| |_ modality-2
|_ subject-2
| |_ modality-1
| |_ modality-2
|_ subject-3
| |_ modality-1
. .
. .
. .
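As a rough illustration of the layout above, the following sketch scans a root folder and lists the modality subfolders found for each subject. It uses only the Python standard library and is not part of Fed-BioMed: the function names `list_subject_modalities` and `build_demo_tree`, as well as the demo folder names, are hypothetical.

```python
import tempfile
from pathlib import Path


def list_subject_modalities(root):
    """Map each subject folder under `root` to its sorted modality subfolders."""
    root = Path(root)
    subjects = {}
    for subject_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        modalities = sorted(m.name for m in subject_dir.iterdir() if m.is_dir())
        subjects[subject_dir.name] = modalities
    return subjects


def build_demo_tree(root):
    """Create a tiny folder tree that follows the structure shown above (hypothetical names)."""
    root = Path(root)
    (root / "participants.csv").touch()  # empty placeholder for the optional csv file
    demo = {"subject-1": ["modality-1", "modality-2"], "subject-2": ["modality-1"]}
    for subject, modalities in demo.items():
        for modality in modalities:
            (root / subject / modality).mkdir(parents=True)


with tempfile.TemporaryDirectory() as tmp:
    build_demo_tree(tmp)
    demo_layout = list_subject_modalities(tmp)
    print(demo_layout)
```

A check of this kind can help spot subjects with a missing modality before the dataset is added to a node.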
Folder structure for this tutorial¶
In the specific case of this tutorial, we encourage you to further divide your images into additional subfolders, according to two criteria: the hospital that generated the data (there are three: Guys, HH and IOP) and a random train/holdout split.
Note that each subject's folder will have a name with the following structure: IXI<SUBJECT_ID>-<HOSPITAL>-<RANDOM_ID>, for example IXI002-Guys-0828.
In conclusion, combining the splits above with the structure required by the MedicalFolderDataset, your folder tree should look like this:
_root-folder
|_ Guys
| |_ train
| | |_ participants.csv
| | |_ IXI002-Guys-0828
| | | |_ T1 <-- T1 is the first imaging modality
| | | |_ T2
| | | |_ label
| | |_ IXI022-Guys-0701
| | | |_ T1
| | | |_ T2
. . .
. . .
. . .
| |_ holdout
| | |_ participants.csv
| | |_ IXI004-Guys-0321
| | | |_ T1
| | | |_ T2
| | | |_ label
. . .
. . .
. . .
|_ HH
| |_ train
. . .
. . .
. . .
| |_ holdout
. . .
. . .
. . .
|_ IOP
. . .
. . .
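The naming convention and the hospital/split layout above can be sketched in a few lines of Python. This is only an illustration and not the logic of the provided download script: the helper names `parse_subject_folder` and `destination` are hypothetical, and the pattern assumes that the subject and random identifiers are purely numeric, as in the examples above.

```python
import re
from pathlib import PurePosixPath

# Pattern for folder names such as "IXI002-Guys-0828" (numeric IDs are an assumption)
SUBJECT_PATTERN = re.compile(
    r"^IXI(?P<subject_id>\d+)-(?P<hospital>Guys|HH|IOP)-(?P<random_id>\d+)$"
)


def parse_subject_folder(name):
    """Split a subject folder name into its three components."""
    match = SUBJECT_PATTERN.match(name)
    if match is None:
        raise ValueError(f"Unexpected subject folder name: {name!r}")
    return match.groupdict()


def destination(root, name, split):
    """Return <root>/<hospital>/<split>/<subject folder> for a given subject."""
    if split not in ("train", "holdout"):
        raise ValueError(f"Unknown split: {split!r}")
    hospital = parse_subject_folder(name)["hospital"]
    return PurePosixPath(root) / hospital / split / name


print(parse_subject_folder("IXI002-Guys-0828"))
print(destination("root-folder", "IXI004-Guys-0321", "holdout"))
```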
Add the IXI dataset to the federated nodes¶
For each of the three hospitals, create a federated node and add the corresponding train dataset by selecting the medical-folder data type, and inputting ixi-train as the tag. Then start the nodes.
Dataset for demographics of the subjects
After you select the folder that contains the patients for training, the CLI will ask for the CSV file where the demographics of the patients are stored. These CSV files are named `participants.csv`, and you can find them in the folder where the subject folders are located, e.g. `Guys/train/participants.csv`.
If you don't know how to add datasets to a node, or start a node, please read our user guide or follow the basic tutorial.
Create a Training Plan¶
We create a training plan that incorporates the UNet model. We rely on the unet package for simplicity. Please refer to the original package for more details about UNet: Pérez-García, Fernando. (2020). fepegar/unet: PyTorch implementation of 2D and 3D U-Net (v0.7.5). Zenodo. https://doi.org/10.5281/zenodo.3697931
Define the model via the init_model function¶
The init_model function must return the model instance, which here is a small torch module wrapping a UNet. Please refer to the TrainingPlan documentation for more details.
Define the loss function via the training_step function¶
The loss function is based on the Dice loss.
Carole H Sudre, Wenqi Li, Tom Vercauteren, Sebastien Ourselin, and M Jorge Cardoso. Generalised dice overlap as a deep learning loss function for highly unbalanced segmentations. In Deep learning in medical image analysis and multimodal learning for clinical decision support, pages 240–248. Springer, 2017.
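To make the quantity concrete, here is a minimal pure-Python sketch of a plain soft Dice loss, 1 - 2·TP / (2·TP + FP + FN + ε), computed on flattened masks. It is a simplified single-channel illustration rather than the batched tensor implementation used in the training plan, and the function name `soft_dice_loss` is hypothetical.

```python
def soft_dice_loss(prediction, target, epsilon=1e-9):
    """Soft Dice loss for flat sequences of predicted probabilities and binary labels."""
    if len(prediction) != len(target):
        raise ValueError("prediction and target must have the same length")
    pairs = list(zip(prediction, target))
    tp = sum(p * g for p, g in pairs)        # soft true positives
    fp = sum(p * (1 - g) for p, g in pairs)  # soft false positives
    fn = sum((1 - p) * g for p, g in pairs)  # soft false negatives
    dice_score = 2 * tp / (2 * tp + fp + fn + epsilon)
    return 1.0 - dice_score


perfect = soft_dice_loss([1.0, 0.0, 1.0, 0.0], [1, 0, 1, 0])   # full overlap
disjoint = soft_dice_loss([0.0, 1.0, 0.0, 1.0], [1, 0, 1, 0])  # no overlap
partial = soft_dice_loss([1.0, 1.0, 0.0, 0.0], [1, 0, 1, 0])   # half overlap
print(perfect, disjoint, partial)
```

The loss is close to 0 when the prediction matches the mask and equals 1 when they do not overlap at all, which is why it remains informative even when the foreground is a small fraction of the volume.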
Define data loading and transformations via the training_data function¶
Within the training_data function, we create an instance of MedicalFolderDataset and pass it to Fed-BioMed's DataManager class.
To preprocess images, we define the image transformations for the input images and the labels leveraging MONAI's transforms. Note that we also include the corresponding imports in the init_dependencies function.
Additionally, we define a transformation for the demographics data contained in the associated csv file. In order to be able to use information extracted from the demographics data as inputs to UNet, we must convert it to a torch.Tensor object. To achieve this, we exploit the demographics_transform argument of the MedicalFolderDataset. The transformation defined in this tutorial is for illustration purposes only: it does little more than extract some variables from the tabular data and convert them to the appropriate format.
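A transformation of this kind can be sketched without any dependencies. The example below assumes that the csv file provides the columns `HEIGHT`, `WEIGHT` and `SITE_NAME`, and it returns a plain list of floats, which `torch.as_tensor()` is able to convert. The function name `encode_demographics` is hypothetical and the sample record contains made-up values.

```python
SITE_NAMES = ["Guys", "IOP", "HH"]  # known sites; any other value goes to an extra "other" slot


def encode_demographics(record):
    """Turn one subject's demographics dict into a flat list of floats."""
    # keep only the numeric columns of interest
    numeric = [float(record[key]) for key in ("HEIGHT", "WEIGHT")]
    # one-hot encode the site name, with a final slot for unknown sites
    one_hot = [0.0] * (len(SITE_NAMES) + 1)
    site = record.get("SITE_NAME")
    index = SITE_NAMES.index(site) if site in SITE_NAMES else len(SITE_NAMES)
    one_hot[index] = 1.0
    return numeric + one_hot


# made-up record, for illustration only
sample = {"HEIGHT": "170", "WEIGHT": "65", "SITE_NAME": "HH", "EXTRA_COLUMN": "ignored"}
encoded = encode_demographics(sample)
print(encoded)
```

As the comments in the tutorial code also point out, hard-coding the site names is not ideal, since it requires knowing them in advance.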
Define training step¶
Here we take as input one batch of (data, target), train the model and compute the loss function.
Note that the MedicalFolderDataset class returns data as a tuple of (images, demographics), where:
- images is a dict of {modality: image} (after image transformations)
- demographics is a dict of {column_name: values} where the column names are taken from the demographics csv file

while the target is a dict of {modality: image} (after target transformations).
In our case, the modality used is T1 for the input images, while the modality used for the target is label. In this tutorial, we ignore the values of the demographics data during training because the UNet model only takes images as input. However, the code is provided for illustration purposes as it shows the recommended way to handle the associated tabular data.
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.logger import logger
from fedbiomed.common.data import DataManager, MedicalFolderDataset
import torch.nn as nn
from torch.optim import AdamW
from unet import UNet


class UNetTrainingPlan(TorchTrainingPlan):

    def init_model(self, model_args):
        model = self.Net(model_args)
        return model

    def init_optimizer(self):
        optimizer = AdamW(self.model().parameters())
        return optimizer

    def init_dependencies(self):
        # Here we define the custom dependencies that will be needed by our custom Dataloader
        deps = ["from monai.transforms import (Compose, NormalizeIntensity, AddChannel, Resize, AsDiscrete)",
                "import torch.nn as nn",
                'import torch.nn.functional as F',
                "from fedbiomed.common.data import MedicalFolderDataset",
                'import numpy as np',
                'from torch.optim import AdamW',
                'from unet import UNet']
        return deps

    class Net(nn.Module):
        # Init of UNetTrainingPlan
        def __init__(self, model_args: dict = {}):
            super().__init__()
            self.CHANNELS_DIMENSION = 1

            self.unet = UNet(
                in_channels = model_args.get('in_channels',1),
                out_classes = model_args.get('out_classes',2),
                dimensions = model_args.get('dimensions',2),
                num_encoding_blocks = model_args.get('num_encoding_blocks',5),
                out_channels_first_layer = model_args.get('out_channels_first_layer',64),
                normalization = model_args.get('normalization', None),
                pooling_type = model_args.get('pooling_type', 'max'),
                upsampling_type = model_args.get('upsampling_type','conv'),
                preactivation = model_args.get('preactivation',False),
                residual = model_args.get('residual',False),
                padding = model_args.get('padding',0),
                padding_mode = model_args.get('padding_mode','zeros'),
                activation = model_args.get('activation','ReLU'),
                initial_dilation = model_args.get('initial_dilation',None),
                dropout = model_args.get('dropout',0),
                monte_carlo_dropout = model_args.get('monte_carlo_dropout',0)
            )

        def forward(self, x):
            x = self.unet.forward(x)
            x = F.softmax(x, dim=self.CHANNELS_DIMENSION)
            return x

    @staticmethod
    def get_dice_loss(output, target, epsilon=1e-9):
        SPATIAL_DIMENSIONS = 2, 3, 4
        p0 = output
        g0 = target
        p1 = 1 - p0
        g1 = 1 - g0
        tp = (p0 * g0).sum(dim=SPATIAL_DIMENSIONS)
        fp = (p0 * g1).sum(dim=SPATIAL_DIMENSIONS)
        fn = (p1 * g0).sum(dim=SPATIAL_DIMENSIONS)
        num = 2 * tp
        denom = 2 * tp + fp + fn + epsilon
        dice_score = num / denom
        return 1. - dice_score

    @staticmethod
    def demographics_transform(demographics: dict):
        """Transforms dict of demographics into data type for ML.

        This function is provided for demonstration purposes, but
        note that if you intend to use demographics data as part
        of your model's input, you **must** provide a
        `demographics_transform` function which at the very least
        converts the demographics dict into a torch.Tensor.

        Must return either a torch Tensor or something Tensor-like
        that can be easily converted through the torch.as_tensor()
        function."""

        if isinstance(demographics, dict) and len(demographics) == 0:
            # when input is empty dict, we don't want to transform anything
            return demographics

        # simple example: keep only some keys
        keys_to_keep = ['HEIGHT', 'WEIGHT']
        out = np.array([float(val) for key, val in demographics.items() if key in keys_to_keep])

        # more complex: generate dummy variables for site name
        # not ideal as it requires knowing the site names in advance
        # could be better implemented with some preprocess
        site_names = ['Guys', 'IOP', 'HH']
        len_dummy_vars = len(site_names) + 1
        dummy_vars = np.zeros(shape=(len_dummy_vars,))
        site_name = demographics['SITE_NAME']
        if site_name in site_names:
            site_idx = site_names.index(site_name)
        else:
            site_idx = len_dummy_vars - 1
        dummy_vars[site_idx] = 1.

        return np.concatenate((out, dummy_vars))

    def training_data(self):
        # The training_data creates the Dataloader to be used for training in the general class Torchnn of fedbiomed
        common_shape = (48, 60, 48)
        training_transform = Compose([AddChannel(), Resize(common_shape), NormalizeIntensity(),])
        target_transform = Compose([AddChannel(), Resize(common_shape), AsDiscrete(to_onehot=2)])

        dataset = MedicalFolderDataset(
            root=self.dataset_path,
            data_modalities='T1',
            target_modalities='label',
            transform=training_transform,
            target_transform=target_transform,
            demographics_transform=UNetTrainingPlan.demographics_transform)
        loader_arguments = { 'shuffle': True}
        return DataManager(dataset, **loader_arguments)

    def training_step(self, data, target):
        #this function must return the loss to backward it
        img = data[0]['T1']
        demographics = data[1]
        output = self.model().forward(img)
        loss = UNetTrainingPlan.get_dice_loss(output, target['label'])
        avg_loss = loss.mean()
        return avg_loss

    def testing_step(self, data, target):
        img = data[0]['T1']
        demographics = data[1]
        target = target['label']
        prediction = self.model().forward(img)
        loss = UNetTrainingPlan.get_dice_loss(prediction, target)
        avg_loss = loss.mean()  # average per batch
        return avg_loss
Prepare the experiment¶
model_args = {
    'in_channels': 1,
    'out_classes': 2,
    'dimensions': 3,
    'num_encoding_blocks': 3,
    'out_channels_first_layer': 8,
    'normalization': 'batch',
    'upsampling_type': 'linear',
    'padding': True,
    'activation': 'PReLU',
}

training_args = {
    'loader_args': { 'batch_size': 16, },
    'epochs': 2,
    'dry_run': False,
    'log_interval': 2,
    'test_ratio' : 0.1,
    'test_on_global_updates': True,
    'test_on_local_updates': True,
}
from fedbiomed.researcher.federated_workflows import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

tags = ['ixi-train']
num_rounds = 3

exp = Experiment(tags=tags,
                 model_args=model_args,
                 training_plan_class=UNetTrainingPlan,
                 training_args=training_args,
                 round_limit=num_rounds,
                 aggregator=FedAverage(),
                 tensorboard=True
                )
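The FedAverage aggregator passed to the experiment performs federated averaging of the node updates. As a rough intuition only, and assuming the common formulation in which each node's parameters are weighted by its share of the training samples, the aggregation step can be sketched in plain Python. The function name `federated_average`, the node labels and all the numbers below are illustrative and do not reflect Fed-BioMed's actual implementation.

```python
def federated_average(node_params, node_sample_counts):
    """Sample-weighted average of per-node parameter dicts (values are lists of floats)."""
    total = sum(node_sample_counts.values())
    reference = next(iter(node_params.values()))  # any node, used to read names and sizes
    averaged = {}
    for name, values in reference.items():
        averaged[name] = [
            sum(node_params[node][name][i] * node_sample_counts[node] / total
                for node in node_params)
            for i in range(len(values))
        ]
    return averaged


# illustrative parameters and sample counts for two hypothetical nodes
params = {
    "node-1": {"weight": [1.0, 2.0], "bias": [0.0]},
    "node-2": {"weight": [3.0, 4.0], "bias": [1.0]},
}
counts = {"node-1": 30, "node-2": 10}
global_params = federated_average(params, counts)
print(global_params)
```

With these numbers, node-1 contributes three quarters of each averaged value because it holds three quarters of the samples.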
Tensorboard setup¶
%load_ext tensorboard
from fedbiomed.researcher.environ import environ
tensorboard_dir = environ['TENSORBOARD_RESULTS_DIR']
%tensorboard --logdir "$tensorboard_dir"
On a MacBook Pro from 2015 with a 2.5 GHz Quad-Core Intel Core i7 processor and 16GB of DRAM, training for 3 rounds of 2 epochs each took about 30 minutes. The final training curves look like this:
Run the experiment¶
exp.run()
Save trained model to file
exp.training_plan().export_model('./trained_model')
Validate on a local holdout set¶
To ensure consistency and simplify our life, we try to reuse the already-available code as much as possible. Note that this process assumes that the held-out data is stored locally on the machine.
Create an instance of the global model¶
First, we create an instance of the model using the parameters from the latest aggregation round.
local_training_plan = UNetTrainingPlan()
local_model = local_training_plan.init_model(model_args)
for dependency_statement in local_training_plan.init_dependencies():
    exec(dependency_statement)
local_model.load_state_dict(exp.aggregated_params()[exp.round_current()-1]['params'])
<All keys matched successfully>
Define a validation data loader¶
We extract the validation data loader from the training plan as well. This requires some knowledge about the internals of the MedicalFolderDataset class. At the end of the process, calling the split function with a ratio of 0 will return a data loader that loads all of the data.
from torch.utils.data import DataLoader

data_loaders = []

datasets = [{
        'dataset_path' : '<fedbiomed-dir>/notebooks/data/Hospital-Centers/Guys/holdout/',
        'dataset_parameters': {
            'tabular_file': '<fedbiomed-dir>/notebooks/data/Hospital-Centers/Guys/holdout/participants.csv',
            'index_col': 14
        }
    },
    {
        'dataset_path' : '<fedbiomed-dir>/notebooks/data/Hospital-Centers/HH/holdout/',
        'dataset_parameters': {
            'tabular_file': '<fedbiomed-dir>/notebooks/data/Hospital-Centers/HH/holdout/participants.csv',
            'index_col': 14
        }
    },
    {
        'dataset_path' : '<fedbiomed-dir>/notebooks/data/Hospital-Centers/IOP/holdout/',
        'dataset_parameters': {
            'tabular_file': '<fedbiomed-dir>/notebooks/data/Hospital-Centers/IOP/holdout/participants.csv',
            'index_col': 14
        }
    },
]

for dataset in datasets:
    # point the training plan to the local holdout folder and reuse its dataset definition
    local_training_plan.dataset_path = dataset['dataset_path']
    val_data_manager = local_training_plan.training_data()
    val_data_manager._dataset.set_dataset_parameters(dataset['dataset_parameters'])
    data_loaders.append(DataLoader(val_data_manager._dataset))
Compute the loss on validation images¶
import torch
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

local_model.eval()
losses = []
labels = []
for i, dl in enumerate(data_loaders):
    losses_ = []
    labels.append(f"Center {i+1}")
    with torch.no_grad():
        # no index is needed here, which also avoids shadowing the outer `i`
        for (images, demographics), targets in dl:
            image = images['T1']
            target = targets['label']
            prediction = local_model.forward(image)
            loss = UNetTrainingPlan.get_dice_loss(prediction, target)
            # store a plain Python float so that matplotlib can plot it directly
            losses_.append(loss.mean().item())
    losses.append(losses_)

plt.subplot(111)
bxplt = plt.boxplot(losses,
                    vert=True,
                    patch_artist=True)
plt.title("Mean `dice loss` values on validation images")
plt.gca().xaxis.set_ticklabels(labels)

colors = ['pink', 'lightblue', 'lightgreen']
for patch, color in zip(bxplt['boxes'], colors):
    patch.set_facecolor(color)
plt.show()
Visualize Training Loss and Testing Metrics¶
# Visualize training loss
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

monitor = exp.monitor()
metrics = monitor._metric_store
training_metrics = [('training', k) for k in list(metrics[list(metrics.keys())[0]]['training'].keys())]
testing_global_metrics = [('testing_global_updates', k) for k in list(metrics[list(metrics.keys())[0]]['testing_global_updates'].keys())]
testing_local_metrics = [('testing_local_updates', k) for k in list(metrics[list(metrics.keys())[0]]['testing_local_updates'].keys())]
metrics_ = [*training_metrics, *testing_local_metrics, *testing_global_metrics]

cols = len(metrics_)
fig, axes = plt.subplots(1, cols, figsize=( cols * 4, len(metrics) * 1.5))
for i, (node, store) in enumerate(metrics.items()):
    title = ""
    for k, (for_, m_) in enumerate(metrics_):
        title = f"Metrics {for_}" if title != f"Metrics {for_}" else title
        data = [i for k, l in store[for_].get(m_, {}).items() for i in l["values"]]
        smoothed = gaussian_filter1d(data, sigma=1.5)
        axes[k].plot(smoothed, label=f"Node {i+1}")
        axes[k].set_title(title)
        axes[k].set_ylabel(m_)
        axes[k].set_xlabel("Iterations")
        axes[k].legend()
fig.tight_layout()
plt.show()