End-to-end Privacy Preserving Training and Inference on Medical Data¶
Introduction¶
In this tutorial, we will demonstrate the process of privacy-preserving training using Fed-BioMed, which leverages federated learning with secure aggregation. Subsequently, we will deploy the final model obtained through federated learning in a privacy-preserving manner using the Concrete-ML library. This approach allows us to achieve privacy-preserving inference through a software-as-a-service (SaaS) model.
The selected dataset comes from FLamby's benchmark of federated medical tasks, specifically the Fed-Heart-Disease dataset. For more detailed information about the dataset, please refer to the provided link.
Install Concrete-ML¶
This tutorial assumes you have Concrete-ML installed in your Fed-BioMed researcher environment.
If needed, you may install it with: pip install concrete-ml
!!! important "Version mismatch" Installing concrete-ml can force a different PyTorch version, which may be incompatible with torchvision or other packages. If you encounter version-mismatch errors after installing concrete-ml, recreate or fix your environment so that PyTorch and torchvision match. For example, reinstall fedbiomed (or reinstall PyTorch and torchvision with compatible versions), pinning specific package versions if necessary.
Note for MacOS users with ARM chips¶
If you have a recent Mac machine with Apple Silicon (ARM chips), then you may experience kernel failure in Section 3.2 of this notebook.
To overcome this issue, you need to rebuild your conda environment with a native python executable, before installing Concrete-ML.
After setting export CONDA_SUBDIR=osx-arm64, you can reinstall fedbiomed and Concrete-ML.
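As an illustrative sketch, the rebuild could look like the following (the environment name and the exact Python version are assumptions; adapt them to your setup):

```shell
# Force conda to build a native arm64 (Apple Silicon) environment
export CONDA_SUBDIR=osx-arm64

# Create and activate a fresh environment with a native Python
conda create -n fedbiomed-arm python=3.11 -y
conda activate fedbiomed-arm

# Reinstall the packages inside the native environment
pip install fedbiomed
pip install concrete-ml
```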
!pip install concrete-ml
FLamby configuration - Download FedHeart¶
The configuration of nodes for the FedHeart dataset in FLamby is explained in the FLamby integration tutorial. Please check that tutorial before continuing with the steps below.
!!! important "Compatible Python version" Please use Python 3.11 or newer. Using Python 3.11+ helps avoid package version conflicts when installing the required libraries (for example, Fed-BioMed and Concrete-ML).
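You can check the interpreter version from within the notebook with a quick sketch:

```python
import sys

# This tutorial assumes Python 3.11+ for Fed-BioMed and Concrete-ML
ok = sys.version_info >= (3, 11)
print(f"Python {sys.version.split()[0]} - {'OK' if ok else 'please upgrade to 3.11+'}")
```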
1. Fed-BioMed¶
Configuring the Fed-BioMed training plan involves specifying the machine learning model, defining the loss function, and identifying the necessary dependencies. This ensures a clear and well-defined setup for the training process.
%load_ext tensorboard
import torch
import numpy as np
import torch.nn as nn
from concrete.ml.torch.compile import compile_torch_model
from flamby.datasets.fed_heart_disease import FedHeartDisease
from torch.utils.data import DataLoader
from torch.nn.modules.loss import _Loss
from fedbiomed.common.dataset import CustomDataset
from fedbiomed.common.datamanager import DataManager
from fedbiomed.common.training_plans import TorchTrainingPlan
class FedHeartTrainingPlan(TorchTrainingPlan):

    class Baseline(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(13, 16)
            self.fc2 = nn.Linear(16, 2)
            self.act = nn.LeakyReLU()

        def forward(self, x):
            x = self.act(self.fc1(x))
            x = self.fc2(x)
            return x

    class BaselineLoss(_Loss):
        def __init__(self):
            super().__init__()
            self.ce = torch.nn.CrossEntropyLoss()

        def forward(self, prediction: torch.Tensor, target: torch.Tensor):
            target = torch.squeeze(target, dim=1).type(torch.long)
            return self.ce(prediction, target)

    def init_model(self, model_args):
        return self.Baseline()

    def init_optimizer(self, optimizer_args):
        return torch.optim.AdamW(self.model().parameters(), lr=optimizer_args["lr"])

    def init_dependencies(self):
        return [
            "from flamby.datasets.fed_heart_disease import FedHeartDisease",
            "from torch.nn.modules.loss import _Loss",
            "from fedbiomed.common.datamanager import DataManager",
            "from fedbiomed.common.dataset import CustomDataset",
        ]

    def training_step(self, data, target):
        logits = self.model().forward(data)
        return self.BaselineLoss().forward(logits, target)

    class MyFedHeartDataset(CustomDataset):
        def read(self):
            """Read FLamby data."""
            # Read the JSON file that is deployed on the node
            import json
            with open(self.path) as f:
                flamby_data = json.load(f)

            # Instantiate the FLamby dataset for this node's center
            self.data = FedHeartDisease(
                center=flamby_data["center"],
                data_path=flamby_data["dataset-path"]
            )
            print(self.data)

        def get_item(self, item):
            """Get item"""
            return self.data[item]

        def __len__(self):
            """Dataset length"""
            return len(self.data)

    def training_data(self, batch_size=2):
        dataset = self.MyFedHeartDataset()
        train_kwargs = {'batch_size': batch_size, 'shuffle': True}
        return DataManager(dataset, **train_kwargs)
batch_size = 8
num_updates = 10
num_rounds = 50

training_args = {
    'optimizer_args': {
        'lr': 5e-4,
    },
    'loader_args': {
        'batch_size': batch_size,
    },
    'random_seed': 42,
    'num_updates': num_updates,
    'dry_run': False,
    'log_interval': 2,
    'test_ratio': 0.0,
    'test_on_global_updates': False,
    'test_on_local_updates': False,
}

model_args = {}
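To make the architecture concrete: the Baseline network maps the 13 Fed-Heart-Disease features to 2 class logits through a single 16-unit hidden layer with a LeakyReLU activation. A minimal NumPy sketch of that forward pass, with random weights standing in for the trained parameters:

```python
import numpy as np

rng = np.random.default_rng(42)

# Shapes mirror the Baseline model: fc1 (13 -> 16), fc2 (16 -> 2)
W1, b1 = rng.normal(size=(13, 16)), np.zeros(16)
W2, b2 = rng.normal(size=(16, 2)), np.zeros(2)

def leaky_relu(x, slope=0.01):
    # PyTorch's default LeakyReLU negative slope is 0.01
    return np.where(x >= 0, x, slope * x)

x = rng.normal(size=(1, 13))  # one patient record (13 features)
logits = leaky_relu(x @ W1 + b1) @ W2 + b2

print(logits.shape)  # (1, 2): one logit per class
```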
2. Federated Learning Training with SecAgg¶
In this section, we create an experiment object and run federated training using secure aggregation. Before running your experiment, please make sure that the two nodes with the deployed datasets are up and running:
fedbiomed node -p ./node-1 start
fedbiomed node -p ./node-2 start
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage
tags = ['flamby', 'fed_heart_disease']
exp_sec_agg = Experiment(
    tags=tags,
    training_plan_class=FedHeartTrainingPlan,
    training_args=training_args,
    model_args=model_args,
    round_limit=num_rounds,
    aggregator=FedAverage(),
    secagg=True,
    tensorboard=True
)
tensorboard_dir = exp_sec_agg.tensorboard_results_path
tensorboard --logdir "$tensorboard_dir"
In our example, the training loss curve for the 4 nodes looks like this:
exp_sec_agg.run()
3. Inference¶
We now have access to the weights of the final model, obtained through federated training with secure aggregation.
fed_sec_agg_model = exp_sec_agg.training_plan().model()
fed_sec_agg_model.eval()
DATASET_TEST_PATH = "<path-to-your-flamby-dataset>/fed_heart_disease"
test_dataset = FedHeartDisease(center=0, pooled=True, train=False, data_path=DATASET_TEST_PATH)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
3.1. Inference with torch using the plaintext model¶
First, we establish a baseline by evaluating the plaintext (unencrypted) model on a held-out test dataset.
def test_torch(net, test_loader):
    """Test the network: measure accuracy on the test set."""
    # Freeze normalization layers
    net.eval()

    all_y_pred = np.zeros((len(test_loader)), dtype=np.int64)
    all_targets = np.zeros((len(test_loader)), dtype=np.int64)

    # Iterate over the batches
    idx = 0
    for data, target in test_loader:
        # Accumulate the ground truth labels
        endidx = idx + target.shape[0]
        all_targets[idx:endidx] = target.numpy()

        # Run forward and get the predicted class id
        logits = torch.sigmoid(net(data))
        output = logits.argmax(1).detach().numpy()
        all_y_pred[idx:endidx] = output

        idx += target.shape[0]

    # Print out the accuracy as a percentage
    n_correct = np.sum(all_targets == all_y_pred)
    print(
        f"Test accuracy over plaintext model: "
        f"{n_correct / len(test_loader) * 100:.2f}%"
    )
test_torch(fed_sec_agg_model, test_dataloader)
In our example, we reach an accuracy of 77.56% with the plaintext model.
3.2. Inference with Concrete-ML using the encrypted model¶
Using Zama's Concrete-ML library, we now show that similar performance can be achieved by running inference under fully homomorphic encryption (FHE): the model is compiled to an encrypted-computation equivalent and evaluated on encrypted inputs.
We have therefore achieved fully secure, end-to-end encrypted training and inference.
def test_with_concrete(quantized_module, test_loader, use_sim):
    """Test a neural network that is quantized and compiled with Concrete ML."""
    # Casting the inputs into int64 is recommended
    all_y_pred = np.zeros((len(test_loader)), dtype=np.int64)
    all_targets = np.zeros((len(test_loader)), dtype=np.int64)

    # Iterate over the test batches and accumulate predictions
    # and ground truth labels in a vector
    idx = 0
    for data, target in test_loader:
        data = data.numpy()
        target = target.numpy()

        fhe_mode = "simulate" if use_sim else "execute"

        # Run the forward pass in FHE (or FHE simulation)
        logits = torch.tensor(
            quantized_module.forward(data, fhe=fhe_mode), requires_grad=False
        )

        endidx = idx + target.shape[0]
        all_targets[idx:endidx] = target

        # Get the predicted class id and accumulate the predictions
        y_pred = torch.sigmoid(logits).argmax(1).numpy()
        all_y_pred[idx:endidx] = y_pred

        # Update the index
        idx += target.shape[0]

    n_correct = np.sum(all_targets == all_y_pred)
    print(
        f"Test accuracy over encrypted model: "
        f"{n_correct / len(test_loader) * 100:.2f}%"
    )
# Concrete-ML relies on Python tracebacks, while Fed-BioMed limits the traceback
# depth to 3 for logging purposes. Reset it to a larger value before using Concrete-ML.
import sys
sys.tracebacklimit = 30
n_bits = 6
compile_set = np.random.randint(0, 10, (100, 13)).astype(float)
q_module = compile_torch_model(
    fed_sec_agg_model, compile_set, n_bits=n_bits, rounding_threshold_bits=6
)
test_with_concrete(q_module, test_dataloader, True)
In our example, we reach an accuracy of 76.38% with the encrypted model.
The loss of accuracy due to encrypted inference is negligible!
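This small gap is what you would expect from quantization: compiling for FHE quantizes weights and activations to a few bits (6 here), which perturbs each value only slightly. The following toy NumPy sketch of uniform quantization (a simplification, not Concrete-ML's exact scheme) illustrates the scale of the error:

```python
import numpy as np

def quantize_dequantize(x, n_bits):
    # Uniform (affine) quantization to 2**n_bits levels, then back to floats
    lo, hi = x.min(), x.max()
    scale = (hi - lo) / (2**n_bits - 1)
    q = np.round((x - lo) / scale)
    return q * scale + lo

rng = np.random.default_rng(0)
w = rng.normal(size=1000)  # stand-in for a layer's weights

# With 6 bits, the worst-case error is half a quantization step
err = np.abs(w - quantize_dequantize(w, n_bits=6)).max()
print(f"max absolute error with 6-bit quantization: {err:.4f}")
```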