Multi-Channel Variational Autoencoder¶
Goal of this Tutorial¶
This tutorial serves as an example of how to train a model with a custom number of channels. More specifically, it uses a Multi-Channel Variational Autoencoder (MCVAE) to encode and decode medical data with 5 different modalities.
VAE¶
The Variational Autoencoder is a latent variable model composed of one encoder and one decoder associated with a single channel. The latent distribution and the decoding distribution are defined as follows:
$$q(\mathbf{z|x}) = \mathcal{N}(\mathbf{z|\mu_x; \Sigma_x})$$
$$p(\mathbf{x|z}) = \mathcal{N}(\mathbf{x|\mu_z; \Sigma_z})$$
They are Gaussians whose moments are parametrized by neural networks (or, in the simplest case, a single linear layer).
For the variance network's output, it is more common and convenient to work with $\log{\sigma^2}$. This is because a neural network can output any real number, while the variance is strictly positive (${\sigma^2}>0$).
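As a quick sanity check, the snippet below (a standalone sketch, not part of the mcvae library) shows why exponentiating a log-variance output always yields a valid, strictly positive variance:

```python
import torch

# Hypothetical raw network outputs interpreted as log-variances:
# they can be any real number, including negative values.
log_var = torch.tensor([-3.0, 0.0, 2.5])

# Exponentiating recovers the variance, which is always strictly positive.
sigma2 = torch.exp(log_var)
print(sigma2)  # all entries > 0; a log-variance of 0 maps to a variance of 1
```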
MCVAE¶
The last part of this tutorial concerns the use of the Multi-Channel Variational Autoencoder (MCVAE), a more advanced method for the joint analysis and prediction of several modalities.
The Multi-Channel VAE is built by stacking one VAE per channel and allowing each decoding distribution to be computed from every input channel.
The source code can be found here: https://gitlab.inria.fr/epione_ML/mcvae
Installing the Requirements¶
Below, we install the mcvae package, which is necessary for this tutorial. The seaborn library is optional and used only for plotting and illustration purposes.
%pip install -q git+https://gitlab.inria.fr/epione_ML/mcvae.git
%pip install seaborn
import pandas as pd
import os
import torch
Downloading the data¶
The data contains 5 different modalities which are:
- Volume: Structural MRI Brain Volumes of the patient.
- Demographics: Age, sex and years of education of the patient.
- Cognition: Cognitive scores of the patient. It contains scores from Clinical Dementia Rating, Alzheimer's Disease Assessment Scale, Mini-Mental State Examination, Rey Auditory Verbal Learning Test and Functional Activities Questionnaire.
- Apoe (Genetic risk): The count of APOE ε4 alleles (0, 1 or 2) of the patient, where higher indicates more risk.
- Fluid: Cerebrospinal fluid biomarkers of the patient: the baseline values of Amyloid-beta 42 (ABETA), Tau (TAU) and Phospho-tau (PTAU) proteins.
adni = pd.read_csv('https://gitlab.inria.fr/ssilvari/flhd/-/raw/master/heterogeneous_data/pseudo_adni.csv?inline=false')
print(f'Loaded {len(adni)} samples.')
normalize = lambda x: (x - x.mean(0))/x.std(0)
volume_cols = ['WholeBrain.bl', 'Ventricles.bl', 'Hippocampus.bl', 'MidTemp.bl', 'Entorhinal.bl']
demog_cols = ['SEX', 'AGE', 'PTEDUCAT']
cognition_cols = ['CDRSB.bl', 'ADAS11.bl', 'MMSE.bl', 'RAVLT.immediate.bl', 'RAVLT.learning.bl', 'RAVLT.forgetting.bl', 'FAQ.bl']
apoe_cols = ['APOE4']
fluid_cols = ['ABETA.MEDIAN.bl', 'PTAU.MEDIAN.bl', 'TAU.MEDIAN.bl']
adni_cols = [volume_cols, demog_cols, cognition_cols, apoe_cols, fluid_cols]
for cols in adni_cols:
    adni[cols] = normalize(adni[cols])
# Creating a list with multimodal data
data_adni = [adni[cols].values for cols in adni_cols]
# Transform as a pytorch Tensor for compatibility
data_adni = [torch.Tensor(d) for d in data_adni]
print(f'We have {len(data_adni)} channels in total as an input for the model')
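As a quick check on synthetic data (not the ADNI file), the column-wise z-scoring used above leaves every column with mean ≈ 0 and standard deviation ≈ 1:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
df = pd.DataFrame({'a': rng.normal(10, 3, 100), 'b': rng.normal(-5, 0.5, 100)})

# Same column-wise z-score normalization as above
df_norm = (df - df.mean()) / df.std()
print(df_norm.mean().abs().max(), df_norm.std().tolist())
```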
A utility function divides the data among the n data centers (hospitals), leaving a certain ratio of each center's samples as a holdout set for later validation.
train_data_path = './data/train'
holdout_data_path = './data/holdout'
def prepare_data_nth_center(n: int, offset: int, n_samples_train: int, n_samples_holdout: int):
    os.makedirs(train_data_path, exist_ok=True)
    os.makedirs(holdout_data_path, exist_ok=True)
    train_data_df = adni.iloc[offset:offset+n_samples_train, :]
    train_data_df.to_csv(train_data_path + f'/dataset{n}.csv')
    test_data_df = adni.iloc[offset+n_samples_train:offset+n_samples_train+n_samples_holdout, :]
    test_data_df.to_csv(holdout_data_path + f'/dataset{n}.csv')
# Number of centers to divide the data
n_centers = 2
n_samples_total = len(adni)
n_samples_per_center = n_samples_total // n_centers
# Holdout ratio
holdout_ratio = 0.1
n_holdout_samples_per_center = int(n_samples_per_center*holdout_ratio)
n_train_samples_per_center = n_samples_per_center - n_holdout_samples_per_center
last_offset = 0
for i in range(n_centers):
    prepare_data_nth_center(n=i,
                            offset=last_offset,
                            n_samples_train=n_train_samples_per_center,
                            n_samples_holdout=n_holdout_samples_per_center)
    last_offset += n_train_samples_per_center + n_holdout_samples_per_center
    print(f'Center {i}: {n_train_samples_per_center} train samples')
    print(f'Current offset: {last_offset}')
Add a dataset to the first node (hospital) with the following command:
fedbiomed node -p CUSTOM/PATH/TO/NODE dataset add
When prompted for data type, select 1) csv
Please select the data type that you're configuring:
1) csv
2) default
3) mednist
4) images
5) medical-folder
6) flamby
select: 1
For the name and description you may input whatever you want.
For the tags it is VERY important to input adni-train. The Experiment will later search for the available data using the tag(s) provided.
For the path of the file, input
/PATH/TO/NODE/data/train/dataset0.csv
Likewise, repeat the same steps for each of the N nodes that you want to add.
fedbiomed node -p CUSTOM/PATH/TO/NODE_N dataset add
/PATH/TO/NODE_N/data/train/datasetN.csv
Finally, start the nodes using the command:
fedbiomed node -p CUSTOM/PATH/TO/NODE start
Creating the Training Plan¶
To train our custom MCVAE, we initialize the model in the init_model function and wrap our data in a Dataset class in the training_data function. To support both, we define an auxiliary function get_channels that specifies the channels our data contains.
Next, we define our second helper function, which turns the 5 channels we have into Torch tensors. We initialize the model and its parameters, and create dummy data, again with 5 channels, to set the input dimensionality of the model.
For the training_data function, we inherit from the Dataset class to create our own dataset. This is done chiefly to override the __getitem__ method, which is fundamental to our training plan: it defines which data item is retrieved at each training step. These samples are then batched according to the batch_size parameter.
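This behaviour can be sketched with a minimal, self-contained dataset (synthetic data, separate from the training plan): each sample is a list of per-channel tensors, and PyTorch's default collation batches each channel separately according to batch_size:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class MultiChannelDataset(Dataset):
    """Each item is a list of per-channel feature tensors for one subject."""
    def __init__(self, channels):
        self._channels = channels  # list of (n_samples, n_features_i) tensors
    def __len__(self):
        return self._channels[0].shape[0]
    def __getitem__(self, idx):
        return [c[idx] for c in self._channels]

data = [torch.randn(10, d) for d in (5, 3)]  # 10 subjects, 2 toy channels
loader = DataLoader(MultiChannelDataset(data), batch_size=4)
batch = next(iter(loader))
print([tuple(b.shape) for b in batch])  # [(4, 5), (4, 3)]
```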
Finally, the training step computes the loss using:
- q: The approximate posterior $q(\mathbf{z|x})$ that the encoder computes, a (generally Gaussian) distribution over the latent variable $\mathbf{z}$.
- x: The input data in tensor format.
- p: The likelihood $p(\mathbf{x|z})$ that the decoder computes when reconstructing from $\mathbf{z}$.
- KL: the Kullback-Leibler divergence
- LL: the log-likelihood
The loss is calculated as the difference kl - ll (the negative evidence lower bound). Their formulas can be seen below:
$$\mathcal{L}_{\text{KL}} = \frac{1}{2} \sum_{i=1}^c \left( \mu_i^2 + \sigma_i^2 - \ln \sigma_i^2 - 1 \right)$$
$$\mathcal{L}_{\text{LL}} = -\frac{1}{2\sigma^2} \| x - \hat{x} \|^2 + \text{const}$$
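The closed-form KL term above can be checked numerically against torch.distributions (a standalone sketch with arbitrary example values):

```python
import torch

mu = torch.tensor([0.5, -1.0])       # example posterior means
log_var = torch.tensor([0.2, -0.3])  # example log-variances
sigma2 = torch.exp(log_var)

# Closed-form KL( N(mu, sigma^2) || N(0, 1) ), summed over latent dimensions
kl_closed_form = 0.5 * (mu**2 + sigma2 - log_var - 1).sum()

# Same quantity via torch.distributions
q = torch.distributions.Normal(mu, sigma2.sqrt())
prior = torch.distributions.Normal(torch.zeros(2), torch.ones(2))
kl_torch = torch.distributions.kl_divergence(q, prior).sum()

print(torch.allclose(kl_closed_form, kl_torch))  # True
```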
Important Warning¶
The mcvae module tries to detect and use the GPU in the system by default. If that is not desired, the DEVICE variable can be set to cpu as shown below.
from mcvae.gpu import DEVICE
print(DEVICE)
#DEVICE = torch.device('cpu')
from fedbiomed.common.training_plans import TorchTrainingPlan

class MCVAETrainingPlan(TorchTrainingPlan):
    @staticmethod
    def get_channels():
        channel_1 = ['WholeBrain.bl', 'Ventricles.bl', 'Hippocampus.bl', 'MidTemp.bl', 'Entorhinal.bl']
        channel_2 = ['SEX', 'AGE', 'PTEDUCAT']
        channel_3 = ['CDRSB.bl', 'ADAS11.bl', 'MMSE.bl', 'RAVLT.immediate.bl', 'RAVLT.learning.bl', 'RAVLT.forgetting.bl', 'FAQ.bl']
        channel_4 = ['APOE4']
        channel_5 = ['ABETA.MEDIAN.bl', 'PTAU.MEDIAN.bl', 'TAU.MEDIAN.bl']
        return channel_1, channel_2, channel_3, channel_4, channel_5

    @staticmethod
    def get_data_as_multichannel_tensor_dataset(df):
        """Takes a dataframe, splits it into multiple channels and parses each channel as a tensor."""
        channel_1, channel_2, channel_3, channel_4, channel_5 = MCVAETrainingPlan.get_channels()
        df = (df - df.mean())/df.std()
        def as_tensor(cols):
            tensor = torch.tensor(df[cols].values).float()
            return tensor
        return [as_tensor(channel_1), as_tensor(channel_2), as_tensor(channel_3), as_tensor(channel_4), as_tensor(channel_5)]

    def init_model(self, model_args):
        channels = MCVAETrainingPlan.get_channels()
        # Dummy data with one sample per channel, used only to set the input dimensionality
        dummy_data = [torch.zeros((1, len(ch))).to('cpu') for ch in channels]
        return Mcvae(data=dummy_data,
                     lat_dim=model_args.get('lat_dim', 1),
                     vaeclass=VAE,
                     sparse=model_args.get('sparse', False))

    def init_optimizer(self, optimizer_args):
        optimizer = Adam(self.model().parameters(), lr=optimizer_args.get('lr', 0.001))
        return optimizer

    def init_dependencies(self):
        deps = [
            'from mcvae.models import Mcvae, ThreeLayersVAE, VAE',
            'from torch.optim import Adam',
            'from torchvision import datasets, transforms',
            'from torch.utils.data import Dataset',
            'from fedbiomed.common.data import DataManager',
            'from fedbiomed.common.logger import logger',
            'import numpy as np',
            'import pandas as pd']
        return deps

    def training_data(self):
        df = pd.read_csv(self.dataset_path)
        class MyDataset(Dataset):
            def __init__(self, data):
                self._data = data
            def __len__(self):
                return len(self._data)
            def __getitem__(self, idx):
                df_ = self._data.iloc[idx, :]
                return MCVAETrainingPlan.get_data_as_multichannel_tensor_dataset(df_), []
        return DataManager(MyDataset(df))

    def training_step(self, data, target):
        output = self.model().forward(data)
        q = output['q']  # approximate posterior q(z|x)
        x = output['x']  # input data
        p = output['p']  # decoding distributions p(x|z)
        kl = self.model().compute_kl(q)
        kl *= self.model().beta
        ll = self.model().compute_ll(p=p, x=x)
        return kl - ll
We initialize the model arguments for the MCVAE and the training arguments.
model_args = {
'lat_dim': 1,
'sparse': False
}
training_args = {
'loader_args': { 'batch_size': 64, },
'optimizer_args': {'lr': 1e-4},
'num_updates': 50,
'log_interval': 25,
'test_ratio': 0.0,
'test_on_global_updates': False,
'test_on_local_updates': False,
'random_seed': 424242,
}
We create an Experiment, select Federated Averaging as the aggregation method, and use the tags that we initially assigned to our dataset.
from fedbiomed.researcher.federated_workflows import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage
tags = ['adni-train']
num_rounds = 50
exp = Experiment(tags=tags,
model_args=model_args,
training_plan_class=MCVAETrainingPlan,
training_args=training_args,
round_limit=num_rounds,
aggregator=FedAverage(),
tensorboard=True
)
%load_ext tensorboard
from fedbiomed.researcher.config import config
tensorboard_dir = './tensorboard_results'
%tensorboard --logdir "$tensorboard_dir"
exp.run()
We import some additional libraries for plotting.
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
sns.set()
# Retrieve the global model holding the final aggregated parameters
# (assuming Fed-BioMed's Experiment API for accessing aggregated parameters)
aggregated_model = exp.training_plan().model()
aggregated_model.load_state_dict(exp.aggregated_params()[num_rounds - 1]['params'])
decoding_weights_dict = {k: w.detach().numpy() for k, w in aggregated_model.state_dict().items() if 'W_out.weight' in k}
We plot the decoding weights of the latent variable(s) $Z$ for the biomarkers of each channel.
lat_dim_names = [f'$Z_{{{i}}}$' for i in range(model_args['lat_dim'])]
col_names = lat_dim_names + ["biomarker"]
weights = pd.DataFrame()
channels = MCVAETrainingPlan.get_channels()
for channel_i, weights_i in enumerate(decoding_weights_dict.values()):
channel_df = pd.DataFrame(np.concatenate((weights_i, np.array(channels[channel_i]).reshape(-1, 1)), axis=1),
columns=lat_dim_names + ["biomarker"])
channel_df['channel'] = channel_i + 1
weights = pd.concat((weights, channel_df))
weights["$Z_{0}$"] = weights["$Z_{0}$"].astype('float32')
weights.head()
weights_melt = weights.melt(id_vars=['biomarker', 'channel'], var_name='latent_var')
weights_melt.sample()
sns.catplot(data=weights_melt, x='biomarker', y='value', hue='latent_var', kind='bar', col='channel', col_wrap=1, aspect=2.5, sharex=False, palette='Blues_r')
plt.show()
We present two alternative methods for prediction.
The first one is to predict a channel/modality from the whole data.
The second is to predict a channel from a specific channel.
# Predict volumes (channel 0) from cognition (channel 2)
# Solution 1
with torch.no_grad():
    # Encode every channel
    q = aggregated_model.encode(data_adni)
    # Take the mean of every encoded distribution q
    z = [qi.loc for qi in q]
    # Decode every latent through every decoder
    p = aggregated_model.decode(z)
    # Extract what you need: p[decoder output channel][encoder input channel]
    decoding_volume_from_cognition = p[0][2].loc.data.numpy()
plt.figure(figsize=(12, 28))
for i in range(len(volume_cols)):
    plt.subplot(5, 1, i+1)
    plt.scatter(decoding_volume_from_cognition[:, i], data_adni[0][:, i])
    plt.title('reconstruction ' + volume_cols[i])
    plt.xlabel('predicted')
    plt.ylabel('target')
plt.show()
Alternatively, predict the Volume directly from the Cognition channel, going through a single encoder and decoder.
# Solution 2
# Encode the cognition (ch 2)
q2 = aggregated_model.vae[2].encode(data_adni[2])
# Take the mean of q (location in pytorch jargon)
z2 = q2.loc
# Decode through the brain volumes decoder (ch 0)
p0 = aggregated_model.vae[0].decode(z2)
# Take the mean
decoding_volume_from_cognition = p0.loc.data.numpy()
plt.figure(figsize=(12, 28))
for i in range(len(volume_cols)):
    plt.subplot(5, 1, i+1)
    plt.scatter(decoding_volume_from_cognition[:, i], data_adni[0][:, i])
    plt.title('reconstruction ' + volume_cols[i])
    plt.xlabel('predicted')
    plt.ylabel('target')
plt.show()
Save the model.
torch.save({
    'model_state_dict': aggregated_model.state_dict(),
    'model_args': model_args,
    'training_args': {
        'num_rounds': num_rounds,
        'num_updates': training_args['num_updates'],
        'loader_args': training_args['loader_args'],
        'optimizer_args': training_args['optimizer_args'],
    }
}, "model.pth")