How to add new algorithms

Suppose that we are adding the configuration for our new algorithm. The new algorithm should be implemented as two classes, one for the server and one for the client. The new classes should be derived from the two base classes BaseServer and BaseClient, respectively.

Example: NewAlgo

Here we give a simple example.

Core algorithm class

We first create classes for the global and local updates in appfl/algorithm:

  • See two classes NewAlgoServer and NewAlgoClient in newalgo.py

  • In NewAlgoServer, the update function conducts a global update by averaging the local model parameters sent from multiple clients

  • In NewAlgoClient, the update function conducts a local update and sends the resulting local model parameters to the server

This is an example code:

Example code for src/appfl/algorithm/newalgo.py
from collections import OrderedDict

from .algorithm import BaseServer, BaseClient

class NewAlgoServer(BaseServer):
    """Server-side implementation of NewAlgo.

    Aggregates the local model parameters sent from the clients into a new
    global model.
    """

    def __init__(self, weights, model, num_clients, device, **kwargs):
        super(NewAlgoServer, self).__init__(weights, model, num_clients, device)
        # Expose algorithm-specific keyword arguments as attributes.
        self.__dict__.update(kwargs)
        # Any additional initialization

    def update(self, local_states: OrderedDict):
        # Implement new server update function, e.g. averaging the local
        # model parameters collected from the clients.
        # (A comment-only body is a SyntaxError; raise until implemented.)
        raise NotImplementedError

class NewAlgoClient(BaseClient):
    """Client-side implementation of NewAlgo.

    Performs a local update and sends the resulting local model parameters
    to the server.
    """

    def __init__(self, id, weight, model, dataloader, device, **kwargs):
        super(NewAlgoClient, self).__init__(id, weight, model, dataloader, device)
        # Expose algorithm-specific keyword arguments as attributes.
        self.__dict__.update(kwargs)
        # Any additional initialization

    def update(self):
        # Implement new client update function.
        # (A comment-only body is a SyntaxError; raise until implemented.)
        raise NotImplementedError

Configuration dataclass

The new algorithm also needs to set up some configurations. This can be done by adding a new dataclass under appfl.config.fed. Let’s say we add the file src/appfl/config/fed/newalgo.py to implement the dataclass as follows:

Example code for src/appfl/config/fed/newalgo.py
from dataclasses import dataclass
from omegaconf import DictConfig, OmegaConf

@dataclass
class NewAlgo:
    """Configuration for the NewAlgo federated algorithm."""

    # Algorithm identifier used to select this configuration.
    type: str = "newalgo"
    # Names of the server/client classes implemented in
    # src/appfl/algorithm/newalgo.py.
    servername: str = "NewAlgoServer"
    clientname: str = "NewAlgoClient"
    # Algorithm-specific arguments.
    # NOTE(review): this DictConfig default is created once at class
    # definition time and is shared; confirm OmegaConf/Hydra handles the
    # shared default as intended.
    args: DictConfig = OmegaConf.create(
        {
            # add new arguments
        }
    )

Then, we need to add the following line to the main configuration file config.py.

from .fed.newalgo import *

This is the main configuration class in src/appfl/config/config.py. Each algorithm, specified in Config.fed, can be configured in the dataclasses at appfl.config.fed.*.

The main configuration class
  1from dataclasses import dataclass, field
  2from typing import Any, List, Dict, Optional
  3from omegaconf import DictConfig, OmegaConf
  4
  5
  6from .fed.federated import *
  7from .fed.fedasync import *
  8from .fed.iceadmm import *  ## TODO: combine iceadmm and iiadmm under the name of ADMM.
  9from .fed.iiadmm import *
 10
 11
 12@dataclass
 13class Config:
 14    fed: Any = field(default_factory=Federated)
 15
 16    # Compute device
 17    device: str = "cpu"
 18    device_server: str = "cpu"
 19
 20    # Number of training epochs
 21    num_clients: int = 1
 22
 23    # Number of training epochs
 24    num_epochs: int = 2
 25
 26    # Number of workers in DataLoader
 27    num_workers: int = 0
 28
 29    # Train data batch info
 30    batch_training: bool = True  ## TODO: revisit
 31    train_data_batch_size: int = 64
 32    train_data_shuffle: bool = True
 33
 34    # Indication of whether to validate or not using testing data
 35    validation: bool = True
 36    test_data_batch_size: int = 64
 37    test_data_shuffle: bool = False
 38
 39    # Checking data sanity
 40    data_sanity: bool = False
 41
 42    # Reproducibility
 43    reproduce: bool = True
 44
 45    # PCA on Trajectory
 46    pca_dir: str = ""
 47    params_start: int = 0
 48    params_end: int = 49
 49    ncomponents: int = 40
 50
 51    # Tensorboard
 52    use_tensorboard: bool = False
 53
 54    # Loading models
 55    load_model: bool = False
 56    load_model_dirname: str = ""
 57    load_model_filename: str = ""
 58
 59    # Saving models (server)
 60    save_model: bool = False
 61    save_model_dirname: str = ""
 62    save_model_filename: str = ""
 63    checkpoints_interval: int = 2
 64
 65    # Saving state_dict (clients)
 66    save_model_state_dict: bool = False
 67    send_final_model: bool = False
 68
 69    # Logging and recording outputs
 70    output_dirname: str = "output"
 71    output_filename: str = "result"
 72
 73    logginginfo: DictConfig = OmegaConf.create({})
 74    summary_file: str = ""
 75    
 76    # Personalization options
 77    personalization: bool = False
 78    p_layers: List[str] = field(default_factory=lambda: [])
 79    config_name: str = ""
 80
 81    #
 82    # gRPC configutations
 83    #
 84
 85    # 100 MB for gRPC maximum message size
 86    max_message_size: int = 104857600
 87
 88    operator: DictConfig = OmegaConf.create({"id": 1})
 89    server: DictConfig = OmegaConf.create(
 90        {"id": 1, "host": "localhost", "port": 50051, "use_tls": False, "api_key": None}
 91    )
 92    client: DictConfig = OmegaConf.create({"id": 1})
 93
 94@dataclass 
 95class GlobusComputeServerConfig:
 96    device      : str = "cpu"
 97    output_dir  : str = "./"
 98    data_dir    : str = "./"
 99    s3_bucket   : Any = None
100    s3_creds    : str = ""
101
@dataclass
class GlobusComputeClientConfig:
    """Per-client settings for a Globus Compute federated run."""

    name: str = ""
    # Globus Compute endpoint identifier for this client.
    endpoint_id: str = ""
    device: str = "cpu"
    output_dir: str = "./"
    data_dir: str = "./"
    get_data: DictConfig = OmegaConf.create({})
    data_pipeline: DictConfig = OmegaConf.create({})
111
@dataclass
class ExecutableFunc:
    """Reference to a user-supplied executable function."""

    module: str = ""
    call: str = ""
    script_file: str = ""
    source: str = ""
118
@dataclass
class ClientTask:
    """Bookkeeping record for a single task dispatched to a client."""

    task_id: str = ""
    task_name: str = ""
    # Index of the client the task is assigned to; -1 means unassigned.
    # (The original default was "", which contradicts the int annotation.)
    client_idx: int = -1
    pending: bool = True
    success: bool = False
    # Wall-clock timestamps; -1 until the task starts/finishes.
    start_time: float = -1
    end_time: float = -1
    log: Optional[Dict] = field(default_factory=dict)
129
@dataclass
class GlobusComputeConfig(Config):
    """Configuration for running APPFL over Globus Compute endpoints."""

    # User-supplied executable functions.
    get_data: ExecutableFunc = field(default_factory=ExecutableFunc)
    get_model: ExecutableFunc = field(default_factory=ExecutableFunc)
    get_loss: ExecutableFunc = field(default_factory=ExecutableFunc)
    val_metric: ExecutableFunc = field(default_factory=ExecutableFunc)
    clients: List[GlobusComputeClientConfig] = field(default_factory=list)
    dataset: str = ""
    loss: str = "CrossEntropy"
    model_kwargs: Dict = field(default_factory=dict)
    # A dataclass field without a default may not follow fields with
    # defaults (TypeError at class creation); use a default_factory.
    server: GlobusComputeServerConfig = field(default_factory=GlobusComputeServerConfig)
    logging_tasks: List = field(default_factory=list)
    hf_model_arc: str = ""
    hf_model_weights: str = ""

    # Testing and validation params
    client_do_validation: bool = True
    client_do_testing: bool = True
    server_do_validation: bool = True
    server_do_testing: bool = True

    # Testing and validation frequency
    client_validation_step: int = 1
    server_validation_step: int = 1

    # Cloud storage
    use_cloud_transfer: bool = True