FL Server over Secure RPC#

We demonstrate how to launch a gRPC server as a federated learning server with authentication. We consider only one client so that the server and a single client (launched from another notebook) can run together.

[12]:
num_clients = 1

Import dependencies#

We put all the imports here. Our framework appfl is built on torch and its neural network module torch.nn. We also import torchvision to download the MNIST dataset. Most importantly, we need to import the appfl.run_grpc_server module.

[13]:
import numpy as np
import math
import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import ToTensor

from appfl.config import Config
from appfl.misc.data import Dataset
import appfl.run_grpc_server as grpc_server
from omegaconf import OmegaConf, DictConfig

Test dataset#

The server can also hold test data to check the performance of the global model; the test data needs to be wrapped in a Dataset object. Note that the server does not need any training data.

[14]:
test_data_raw = torchvision.datasets.MNIST(
    "./_data", train=False, download=True, transform=ToTensor()
)
test_data_input = []
test_data_label = []
for idx in range(len(test_data_raw)):
    test_data_input.append(test_data_raw[idx][0].tolist())
    test_data_label.append(test_data_raw[idx][1])

test_dataset = Dataset(
    torch.FloatTensor(test_data_input), torch.tensor(test_data_label)
)
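
Looping element by element through Python lists is slow for the 10,000 test images. As a sketch (not part of the original notebook, but equivalent in result), the same Dataset can be built directly from stacked tensors:

[ ]:
test_inputs = torch.stack([test_data_raw[i][0] for i in range(len(test_data_raw))])
test_labels = torch.tensor([test_data_raw[i][1] for i in range(len(test_data_raw))])
test_dataset = Dataset(test_inputs, test_labels)  # same (10000, 1, 28, 28) inputs as above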

Model#

Users can define their own models by deriving from torch.nn.Module. For example, in this demonstration we define the following convolutional neural network.

[15]:
class CNN(nn.Module):
    def __init__(self, num_channel=1, num_classes=10, num_pixel=28):
        super().__init__()
        self.conv1 = nn.Conv2d(
            num_channel, 32, kernel_size=5, padding=0, stride=1, bias=True
        )
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=0, stride=1, bias=True)
        self.maxpool = nn.MaxPool2d(kernel_size=(2, 2))
        self.act = nn.ReLU(inplace=True)

        # Compute the spatial size X of the feature map after each stage:
        # a conv (kernel 5, stride 1, padding 0) shrinks X by 4; each max-pool halves it.
        X = num_pixel
        X = math.floor(1 + (X + 2 * 0 - 1 * (5 - 1) - 1) / 1)  # after conv1
        X = X / 2  # after first max-pool
        X = math.floor(1 + (X + 2 * 0 - 1 * (5 - 1) - 1) / 1)  # after conv2
        X = X / 2  # after second max-pool
        X = int(X)  # X = 4 when num_pixel = 28

        self.fc1 = nn.Linear(64 * X * X, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.maxpool(x)
        x = self.act(self.conv2(x))
        x = self.maxpool(x)
        x = torch.flatten(x, 1)
        x = self.act(self.fc1(x))
        x = self.fc2(x)
        return x

model = CNN()
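
As an optional sanity check (not in the original notebook), one forward pass on a dummy batch confirms that the size arithmetic in __init__ matches what forward actually produces:

[ ]:
with torch.no_grad():
    out = model(torch.zeros(1, 1, 28, 28))  # one dummy 28x28 grayscale image
print(out.shape)  # expected: torch.Size([1, 10])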

Loss and metric#

We first define the loss function.

[16]:
loss_fn = torch.nn.CrossEntropyLoss()
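
Note that CrossEntropyLoss expects raw (unnormalized) logits and integer class labels; a quick illustration we add here:

[ ]:
logits = torch.tensor([[2.0, 0.5, -1.0]])  # raw model outputs for one sample
label = torch.tensor([0])                  # the correct class index
print(loss_fn(logits, label))              # small loss, since class 0 scores highest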

We also define the validation metric used during training.

[ ]:
def accuracy(y_true, y_pred):
    """
    y_true and y_pred are both of type np.ndarray
    y_true (N, d) where N is the size of the validation set, and d is the dimension of the label
    y_pred (N, D) where N is the size of the validation set, and D is the output dimension of the ML model
    """
    if len(y_pred.shape) == 1:
        y_pred = np.round(y_pred)
    else:
        y_pred = y_pred.argmax(axis=1)
    return 100 * np.sum(y_pred == y_true) / y_pred.shape[0]
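
A quick check of the metric on dummy data (an illustration we add here, assuming the server passes raw model outputs as y_pred):

[ ]:
y_true = np.array([0, 1, 2])
y_pred = np.array([[0.9, 0.05, 0.05],
                   [0.1, 0.8, 0.1],
                   [0.2, 0.3, 0.5]])
print(accuracy(y_true, y_pred))  # 100.0: the argmax of each row matches the label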

Configurations#

We run the appfl training with the data and model defined above. A number of parameters can be set simply by changing the configuration values. We read the default configurations from the appfl.config.Config class as a DictConfig object.

[ ]:
cfg: DictConfig = OmegaConf.structured(Config)
# print(OmegaConf.to_yaml(cfg))
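
Since cfg is a plain DictConfig, any default can be inspected or overridden before launching the server. For example (a sketch using standard OmegaConf calls):

[ ]:
# Inspect just the server-related defaults.
print(OmegaConf.to_yaml(cfg.server))
# Overrides can also be merged in from a dotlist (e.g., built from CLI arguments).
cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(["num_epochs=5"]))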

Create secure SSL server and authenticator#

A secure SSL server requires both a public certificate and a private key for data encryption. We have provided an example certificate and key pair for demonstration. Note that in practice you should never share your private key with others; keep it secret.

To use the provided certificate and key, we need to set the following. To use your own certificate and key instead, change the corresponding fields to their file paths.

[18]:
cfg.server.server_certificate = "default"
cfg.server.server_certificate_key = "default"
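
For example, with your own files the same fields would point at their paths (the paths below are placeholders, not files shipped with appfl):

[ ]:
# Hypothetical paths -- replace with the actual locations of your files.
# cfg.server.server_certificate = "/path/to/server.crt"
# cfg.server.server_certificate_key = "/path/to/server.key"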

Then, to use the NaiveAuthenticator, we set the following; the NaiveAuthenticator does not take any arguments.

[ ]:
cfg.server.authenticator = "Naive"
cfg.server.authenticator_kwargs = {}
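
For an authenticator that does take arguments, its keyword arguments would go in the same dict. A purely hypothetical sketch (the name and kwargs below are placeholders, not real appfl options):

[ ]:
# cfg.server.authenticator = "SomeTokenAuthenticator"             # placeholder name
# cfg.server.authenticator_kwargs = {"token": "<shared-secret>"}  # placeholder kwargs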

Run with configurations#

For the server, we set the number of global epochs to 5 and start the secure FL experiment.

[19]:
cfg.num_epochs = 5
grpc_server.run_server(cfg, model, loss_fn, num_clients, test_dataset, accuracy)
[Round:  001] Finished; all clients have sent their results.
[Round:  001] Updating model weights
[Round:  001] Test set: Average loss: 0.3082, Accuracy: 90.95%, Best Accuracy: 90.95%
[Round:  002] Finished; all clients have sent their results.
[Round:  002] Updating model weights
[Round:  002] Test set: Average loss: 0.1699, Accuracy: 94.94%, Best Accuracy: 94.94%
[Round:  003] Finished; all clients have sent their results.
[Round:  003] Updating model weights
[Round:  003] Test set: Average loss: 0.1106, Accuracy: 96.73%, Best Accuracy: 96.73%
[Round:  004] Finished; all clients have sent their results.
[Round:  004] Updating model weights
[Round:  004] Test set: Average loss: 0.0852, Accuracy: 97.58%, Best Accuracy: 97.58%
[Round:  005] Finished; all clients have sent their results.
[Round:  005] Updating model weights
[Round:  005] Test set: Average loss: 0.0764, Accuracy: 97.77%, Best Accuracy: 97.77%