
End-to-end Privacy Preserving Training and Inference on Medical Data

Introduction

In this tutorial, we demonstrate privacy-preserving training with Fed-BioMed, which leverages federated learning with secure aggregation. We then deploy the final model obtained through federated learning using the Concrete-ML library, enabling privacy-preserving inference through a software-as-a-service (SaaS) model.

The selected dataset originates from a medical task provided by FLamby, specifically the FedHeart Disease dataset. For more detailed information about the dataset, please refer to the provided link.

Install Concrete-ML

This tutorial assumes you have Concrete-ML installed in your Fed-BioMed researcher environment.

If needed, you may install it with: pip install concrete-ml

!!! important "Version mismatch" The installation of concrete-ml can force a different PyTorch version, which may be incompatible with torchvision or other packages. If you encounter version-mismatch errors after installing concrete-ml, recreate or fix your environment so that PyTorch and torchvision match.For example, reinstall fedbiomed (or reinstall PyTorch and torchvision with compatible versions). Pin specific package versions if necessary.

Note for MacOS users with ARM chips

If you have a recent Mac machine with Apple Silicon (ARM chips), you may experience a kernel failure in Section 3.2 of this notebook.

To overcome this issue, you need to rebuild your conda environment with a native Python executable before installing Concrete-ML.

After setting export CONDA_SUBDIR=osx-arm64, you can reinstall fedbiomed and Concrete-ML.

In [ ]:
!pip install concrete-ml

Identify patients with heart disease

We rely on the FedHeart dataset and task from FLamby, in which tabular patient data are used as input to predict the presence or absence of heart disease. More details can be found in FLamby's paper.
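
As an optional local sanity check, the snippet below loads a single FLamby center and inspects one sample. It assumes the dataset has already been downloaded as described in the FLamby integration tutorial referenced below; it is a minimal sketch, not part of the federated workflow.

from flamby.datasets.fed_heart_disease import FedHeartDisease

ds = FedHeartDisease(center=0, train=True)
x, y = ds[0]
print(x.shape, y.shape)  # 13 tabular features per patient, one binary label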

FLamby configuration - Download FedHeart

The configuration of nodes for the FedHeart dataset in FLamby is explained in the FLamby integration tutorial. Please check that tutorial before continuing with the steps below.

!!! important "Compatible Python version"
    Please use Python 3.11 or newer. This helps avoid package version conflicts when installing the required libraries (for example, Fed-BioMed and Concrete-ML).

1. Fed-BioMed

Configuring the Fed-BioMed training plan involves specifying the machine learning model, defining the loss function, and declaring the necessary dependencies. This ensures a clear and well-defined setup for the training process.

In [ ]:
%load_ext tensorboard
In [ ]:
import torch 
import numpy as np
import torch.nn as nn


from concrete.ml.torch.compile import compile_torch_model
from flamby.datasets.fed_heart_disease import FedHeartDisease

from torch.utils.data import DataLoader
from torch.nn.modules.loss import _Loss

from fedbiomed.common.dataset import CustomDataset             
from fedbiomed.common.datamanager import DataManager               
from fedbiomed.common.training_plans import TorchTrainingPlan

class FedHeartTrainingPlan(TorchTrainingPlan):
    
    class Baseline(nn.Module):
        
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(13, 16)
            self.fc2 = nn.Linear(16, 2)
            self.act = nn.LeakyReLU()
        def forward(self, x):
            x = self.act(self.fc1(x))
            x = self.fc2(x)
            return x
        
    class BaselineLoss(_Loss):
        def __init__(self):
            super().__init__()
            self.ce = torch.nn.CrossEntropyLoss()

        def forward(self, prediction: torch.Tensor, target: torch.Tensor):
            target = torch.squeeze(target, dim=1).type(torch.long)
            return self.ce(prediction, target)
    
    def init_model(self, model_args):
        return self.Baseline()

    def init_optimizer(self, optimizer_args):
        return torch.optim.AdamW(self.model().parameters(), lr=optimizer_args["lr"])

    def init_dependencies(self):
        return ["from flamby.datasets.fed_heart_disease import FedHeartDisease",
                "from torch.nn.modules.loss import _Loss",
                "from fedbiomed.common.datamanager import DataManager",
                "from fedbiomed.common.dataset import CustomDataset"
                ]

    def training_step(self, data, target):
        logits = self.model().forward(data)
        return self.BaselineLoss().forward(logits, target)

    class MyFedHeartDataset(CustomDataset):

        def read(self):
            """Read FLamby data"""            
            
            # Read json file that is deployed on the node
            import json
            with open(self.path) as f:
                flamby_data = json.load(f)

            # Create data file
            self.data = FedHeartDisease(
                center=flamby_data["center"], 
                data_path=flamby_data["dataset-path"]
            )

            print(self.data)

        def get_item(self, item):
            """Get item"""
            return self.data[item]
        
        def __len__(self):
            """Dataset length"""
            return len(self.data)

    def training_data(self, batch_size=2):
        dataset = self.MyFedHeartDataset()
        train_kwargs = {'batch_size': batch_size, 'shuffle': True}
        return DataManager(dataset, **train_kwargs)
In [ ]:
batch_size = 8
num_updates = 10
num_rounds = 50
In [ ]:
training_args = {
    'optimizer_args': {
        'lr': 5e-4,
    },
    'loader_args': {
        'batch_size': batch_size,
    },
    'random_seed': 42,
    'num_updates': num_updates,
    'dry_run': False,
    'log_interval': 2,
    'test_ratio': 0.0,
    'test_on_global_updates': False,
    'test_on_local_updates': False,
}

model_args = {}
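
As a quick sanity check of the training budget implied by these values (plain arithmetic, no Fed-BioMed calls involved): each node performs num_updates optimizer steps per round, so over the whole experiment:

total_local_steps = num_rounds * num_updates   # 50 * 10 = 500 optimizer steps per node
samples_per_round = num_updates * batch_size   # 10 * 8 = 80 training samples drawn per node per round
print(total_local_steps, samples_per_round)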

2. Federated Learning Training with SecAgg

In this section, we create an experiment object and run federated training with secure aggregation. Please make sure that two nodes with the dataset deployed are up and running before launching the experiment:

fedbiomed node -p ./node-1 start
fedbiomed node -p ./node-2 start
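
As MyFedHeartDataset.read() above shows, each node is expected to host a small JSON file with the keys center and dataset-path. A minimal sketch of producing such a file, with a hypothetical path and center id:

import json

node_config = {
    "center": 0,                                # FLamby center id served by this node
    "dataset-path": "/data/fed_heart_disease",  # hypothetical local path to the downloaded data
}
with open("fed_heart_disease.json", "w") as f:
    json.dump(node_config, f)

The actual file location and contents depend on how you deployed the dataset on each node; see the FLamby integration tutorial.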
In [ ]:
from fedbiomed.researcher.experiment import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

tags = ['flamby', 'fed_heart_disease']

exp_sec_agg = Experiment(tags=tags,
                 training_plan_class=FedHeartTrainingPlan,
                 training_args=training_args,
                 model_args=model_args,
                 round_limit=num_rounds,
                 aggregator=FedAverage(),
                 secagg=True,
                 tensorboard=True
                )
In [ ]:
tensorboard_dir = exp_sec_agg.tensorboard_results_path
%tensorboard --logdir "$tensorboard_dir"

In our example, the training loss curves of the participating nodes look like this:

[Image: per-node training loss curves from Tensorboard]

In [ ]:
exp_sec_agg.run()

3. Inference

We now have access to the weights of the final model, trained with secure aggregation.

In [ ]:
fed_sec_agg_model = exp_sec_agg.training_plan().model()
fed_sec_agg_model.eval()
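
Optionally, you can persist the aggregated weights at this point so that inference can be reproduced independently of the researcher session (a minimal sketch; the file name is hypothetical):

torch.save(fed_sec_agg_model.state_dict(), "fed_heart_model.pt")
# Later, rebuild the architecture and reload the weights:
# model = FedHeartTrainingPlan.Baseline()
# model.load_state_dict(torch.load("fed_heart_model.pt"))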
In [ ]:
DATASET_TEST_PATH = "<path-to-your-falmby-dataset>/fed_heart_disease"
test_dataset = FedHeartDisease(center=0,pooled=True, train=False, data_path=DATASET_TEST_PATH)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
DATASET_TEST_PATH = "/fed_heart_disease" test_dataset = FedHeartDisease(center=0,pooled=True, train=False, data_path=DATASET_TEST_PATH) test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

3.1. Inference with torch using the plaintext model

First, we establish a baseline by evaluating the plaintext (unencrypted) model on a held-out test dataset.

In [ ]:
def test_torch(net, test_loader):
    """Test the network: measure accuracy on the test set."""

    # Freeze normalization layers
    net.eval()

    # One slot per sample: the loader uses batch_size=1, so the number
    # of batches equals the number of samples
    all_y_pred = np.zeros((len(test_loader)), dtype=np.int64)
    all_targets = np.zeros((len(test_loader)), dtype=np.int64)

    # Iterate over the batches
    idx = 0
    for data, target in test_loader:
        # Accumulate the ground truth labels
        endidx = idx + target.shape[0]
        all_targets[idx:endidx] = target.numpy()

        # Run forward and get the predicted class id (sigmoid is monotonic,
        # so the argmax over the two logits is unchanged)
        logits = torch.sigmoid(net(data))
        output = logits.argmax(1).detach().numpy()
        all_y_pred[idx:endidx] = output

        idx += target.shape[0]

    # Print out the accuracy as a percentage
    n_correct = np.sum(all_targets == all_y_pred)
    print(
        f"Test accuracy over plaintext model: "
        f"{n_correct / len(test_loader) * 100:.2f}%"
    )
In [ ]:
test_torch(fed_sec_agg_model, test_dataloader)

In our example, the plaintext model reaches a test accuracy of 77.56%.

3.2. Inference with Concrete-ML using the encrypted model

Using Zama's Concrete-ML library, we now show that similar performance can be achieved by performing inference with the encrypted model.

We have therefore achieved end-to-end privacy-preserving training and inference.
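
The introduction mentioned serving the model as SaaS. Concrete-ML provides a client/server deployment API for this. The sketch below outlines the flow, assuming your installed version exposes FHEModelDev, FHEModelClient and FHEModelServer in concrete.ml.deployment with this interface and accepts the quantized module compiled below; check the Concrete-ML documentation for your version. Paths and the x_sample variable are hypothetical.

from concrete.ml.deployment import FHEModelClient, FHEModelDev, FHEModelServer

# Developer side: export the compiled model artifacts (run after the compilation step below)
FHEModelDev("deployment/", q_module).save()

# Client side: generate keys and encrypt one (1, 13) input sample
client = FHEModelClient("deployment/", key_dir="keys/")
client.generate_private_and_evaluation_keys()
evaluation_keys = client.get_serialized_evaluation_keys()
encrypted_input = client.quantize_encrypt_serialize(x_sample)

# Server side: run blind inference on the ciphertext
server = FHEModelServer("deployment/")
server.load()
encrypted_output = server.run(encrypted_input, evaluation_keys)

# Client side: decrypt and dequantize the result locally
prediction = client.deserialize_decrypt_dequantize(encrypted_output)

In this setup the server never sees plaintext patient data, which is what makes the SaaS deployment privacy-preserving.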

In [ ]:
def test_with_concrete(quantized_module, test_loader, use_sim):
    """Test a neural network that is quantized and compiled with Concrete ML."""

    # One slot per sample (the loader uses batch_size=1)
    all_y_pred = np.zeros((len(test_loader)), dtype=np.int64)
    all_targets = np.zeros((len(test_loader)), dtype=np.int64)

    # Iterate over the test batches and accumulate predictions and ground truth labels in a vector
    idx = 0
    for data, target in test_loader:
        data = data.numpy()
        target = target.numpy()

        fhe_mode = "simulate" if use_sim else "execute"

        # The quantized module quantizes the input, runs the circuit
        # (simulated or real FHE), and dequantizes the output
        logits = torch.tensor(quantized_module.forward(data, fhe=fhe_mode), requires_grad=False)

        endidx = idx + target.shape[0]

        all_targets[idx:endidx] = target

        # Get the predicted class id and accumulate the predictions
        y_pred = torch.sigmoid(logits).argmax(1).numpy()
        all_y_pred[idx:endidx] = y_pred

        # Update the index
        idx += target.shape[0]
    n_correct = np.sum(all_targets == all_y_pred)
    print(
        f"Test accuracy over encrypted model: "
        f"{n_correct / len(test_loader) * 100:.2f}%"
    )
In [ ]:
# Concrete-ML relies on traceback inspection; Fed-BioMed caps
# sys.tracebacklimit at 3 for logging purposes, so we raise it again here
import sys
sys.tracebacklimit = 30
In [ ]:
n_bits = 6
# Representative inputs for calibration: random values stand in for real
# patient features here; input quantization uses n_bits bits
compile_set = np.random.randint(0, 10, (100, 13)).astype(float)
q_module = compile_torch_model(fed_sec_agg_model, compile_set, n_bits=n_bits, rounding_threshold_bits=6)
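
In the call below, use_sim=True maps to fhe="simulate", which evaluates the compiled circuit without actual encryption and is therefore fast while still reflecting quantization effects. To get a feel for real FHE latency, you can run a single sample in execute mode (a minimal sketch; expect it to be orders of magnitude slower than simulation):

q_module.fhe_circuit.keygen()  # one-off key generation
x_sample, _ = test_dataset[0]
q_module.forward(x_sample.numpy().reshape(1, -1), fhe="execute")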
In [ ]:
test_with_concrete(q_module, test_dataloader, True)

In our example, the encrypted model reaches a test accuracy of 76.38%.

The loss of accuracy due to encryption during inference is negligible!
