Adding new algorithms#

Suppose that we are adding the configuration for our new algorithm. The new algorithm should be implemented as two classes, one for the server and one for the client.

Base classes#

Implementation of the new classes should be derived from the following two base classes:

Example: NewAlgo#

Here we give some simple example.

Core algorithm class#

We first create classes for the global and local updates in appfl/algorithm:

  • Create two classes NewAlgoServer and NewAlgoClient in newalgo.py

  • In NewAlgoServer, the update function conducts a global update by averaging the local model parameters sent from multiple clients

  • In NewAlgoClient, the update function conducts a local update and sends the resulting local model parameters to the server

This is an example code:

Example code for src/appfl/algorithm/newalgo.py#
from collections import OrderedDict

from .algorithm import BaseServer, BaseClient

class NewAlgoServer(BaseServer):
    """Server-side class of the new algorithm.

    Aggregates the local model parameters sent from multiple clients into
    a global model update.
    """

    def __init__(self, weights, model, num_clients, device, **kwargs):
        super(NewAlgoServer, self).__init__(weights, model, num_clients, device)
        # Expose any algorithm-specific keyword arguments as attributes,
        # mirroring how the other algorithm classes consume their config.
        self.__dict__.update(kwargs)
        # Any additional initialization

    def update(self, local_states: OrderedDict):
        """Conduct a global update by averaging the local model parameters.

        :param local_states: local model parameters sent from the clients
        """
        # Implement new server update function
        raise NotImplementedError

class NewAlgoClient(BaseClient):
    """Client-side class of the new algorithm.

    Performs a local update and sends the resulting local model
    parameters to the server.
    """

    def __init__(self, id, weight, model, dataloader, device, **kwargs):
        super(NewAlgoClient, self).__init__(id, weight, model, dataloader, device)
        # Expose any algorithm-specific keyword arguments as attributes,
        # mirroring how the other algorithm classes consume their config.
        self.__dict__.update(kwargs)
        # Any additional initialization

    def update(self):
        """Conduct a local update on this client's data."""
        # Implement new client update function
        raise NotImplementedError

Configuration dataclass#

The new algorithm also needs to set up some configurations. This can be done by adding a new dataclass under appfl.config.fed. Let’s say we add the src/appfl/config/fed/newalgo.py file to implement the dataclass as follows:

Example code for src/appfl/config/fed/newalgo.py#
from dataclasses import dataclass
from omegaconf import DictConfig, OmegaConf

@dataclass
class NewAlgo:
    # Identifier used to select this algorithm via Config.fed.
    type: str = "newalgo"
    # Names of the server/client classes implemented in src/appfl/algorithm/newalgo.py.
    servername: str = "NewAlgoServer"
    clientname: str = "NewAlgoClient"
    # Algorithm-specific arguments; an OmegaConf DictConfig so that values
    # can be overridden from the main configuration at runtime.
    args: DictConfig = OmegaConf.create(
        {
            # add new arguments
        }
    )

Then, we need to add the following line to the main configuration file config.py.

from .fed.newalgo import *

This is the main configuration class in src/appfl/config/config.py. Each algorithm, specified in Config.fed, can be configured in the dataclasses at appfl.config.fed.*.

The main configuration class#
  1from dataclasses import dataclass, field
  2from typing import Any, List, Dict, Optional
  3from omegaconf import DictConfig, OmegaConf
  4import os
  5import sys
  6
  7from .fed.federated import *
  8from .fed.fedasync import *
  9from .fed.iceadmm import * 
 10from .fed.iiadmm import *
 11
 12@dataclass
 13class Config:
 14    fed: Any = field(default_factory=Federated)
 15
 16    # Compute device
 17    device: str = "cpu"
 18    device_server: str = "cpu"
 19
 20    # Number of clients
 21    num_clients: int = 1
 22
 23    # Number of training epochs
 24    num_epochs: int = 2
 25
 26    # Number of workers in DataLoader
 27    num_workers: int = 0
 28
 29    # Train data batch info
 30    batch_training: bool = True  ## TODO: revisit
 31    train_data_batch_size: int = 64
 32    train_data_shuffle: bool = True
 33
 34    # Indication of whether to validate or not using testing data
 35    validation: bool = True
 36    test_data_batch_size: int = 64
 37    test_data_shuffle: bool = False
 38
 39    # Checking data sanity
 40    data_sanity: bool = False
 41
 42    # Reproducibility
 43    reproduce: bool = True
 44
 45    # PCA on Trajectory
 46    pca_dir: str = ""
 47    params_start: int = 0
 48    params_end: int = 49
 49    ncomponents: int = 40
 50
 51    # Tensorboard
 52    use_tensorboard: bool = False
 53
 54    # Loading models
 55    load_model: bool = False
 56    load_model_dirname: str = ""
 57    load_model_filename: str = ""
 58
 59    # Saving models (server)
 60    save_model: bool = False
 61    save_model_dirname: str = ""
 62    save_model_filename: str = ""
 63    checkpoints_interval: int = 2
 64
 65    # Saving state_dict (clients)
 66    save_model_state_dict: bool = False
 67    send_final_model: bool = False
 68
 69    # Logging and recording outputs
 70    output_dirname: str = "output"
 71    output_filename: str = "result"
 72
 73    logginginfo: DictConfig = OmegaConf.create({})
 74    summary_file: str = ""
 75
 76    # Personalization options
 77    personalization: bool = False
 78    p_layers: List[str] = field(default_factory=lambda: [])
 79    config_name: str = ""
 80
 81    ## gRPC configurations ##
 82
 83    # 10 MB for gRPC maximum message size
 84    max_message_size: int = 10485760
 85    use_ssl: bool = False
 86    use_authenticator: bool = False
 87    authenticator: str = "Globus" # "Globus", "Naive"
 88    uri: str = "localhost:50051"
 89
 90    operator: DictConfig = OmegaConf.create({"id": 1})
 91    server: DictConfig = OmegaConf.create({
 92        "id": 1, 
 93        "authenticator_kwargs": {
 94            "is_fl_server": True,
 95            "globus_group_id": "77c1c74b-a33b-11ed-8951-7b5a369c0a53",
 96        },
 97        "server_certificate_key": "default",
 98        "server_certificate": "default",
 99        "max_workers": 10,
100    })
101    client: DictConfig = OmegaConf.create({
102        "id": 1,
103        "root_certificates": "default",
104        "authenticator_kwargs": {
105            "is_fl_server": False,
106        },
107    })
108
109    # Lossy compression enabling
110    enable_compression: bool = False
111    lossy_compressor: str = "SZ2"
112    lossless_compressor: str = "blosc"
113
114    # Lossy compression path configuration
115    ext = ".dylib" if sys.platform.startswith("darwin") else ".so"
116    current_dir = os.path.dirname(os.path.realpath(__file__))
117    base_dir = os.path.abspath(os.path.join(current_dir, os.pardir, os.pardir, os.pardir))
118    compressor_sz2_path: str = os.path.join(base_dir, ".compressor/SZ/build/sz/libSZ" + ext)
119    compressor_sz3_path: str = os.path.join(base_dir, ".compressor/SZ3/build/tools/sz3c/libSZ3c" + ext)
120    compressor_szx_path: str = os.path.join(base_dir, ".compressor/SZx-main/build/lib/libSZx" + ext)
121
122    # Compressor parameters
123    error_bounding_mode: str = ""
124    error_bound: float = 0.0
125
126    # Default data type
127    flat_model_dtype: str = "np.float32"
128    param_cutoff: int = 1024
129
130    # Data readiness
131    dr_metrics: Optional[List[str]] = field(default_factory=lambda: [])
132
133
134@dataclass
135class GlobusComputeServerConfig:
136    device: str = "cpu"
137    output_dir: str = "./"
138    data_dir: str = "./"
139    s3_bucket: Any = None
140    s3_creds: str = ""
141
142
143@dataclass
144class GlobusComputeClientConfig:
145    name        : str = ""
146    endpoint_id : str = ""
147    device      : str = "cpu"
148    output_dir  : str = "./output"
149    data_dir    : str = "./datasets"
150    get_data    :  DictConfig = OmegaConf.create({})
151    data_pipeline: DictConfig = OmegaConf.create({})
152
153
154@dataclass
155class ExecutableFunc:
156    module: str = ""
157    call: str = ""
158    script_file: str = ""
159    source: str = ""
160
161
162@dataclass
163class ClientTask:
164    task_id: str = ""
165    task_name: str = ""
166    client_idx: int = ""
167    pending: bool = True
168    success: bool = False
169    start_time: float = -1
170    end_time: float = -1
171    log: Optional[Dict] = field(default_factory=dict)
172
173
174@dataclass
175class GlobusComputeConfig(Config):
176    get_data: ExecutableFunc = field(default_factory=ExecutableFunc)
177    get_model: ExecutableFunc = field(default_factory=ExecutableFunc)
178    get_loss: ExecutableFunc = field(default_factory=ExecutableFunc)
179    val_metric: ExecutableFunc = field(default_factory=ExecutableFunc)
180    clients: List[GlobusComputeClientConfig] = field(default_factory=list)
181    dataset: str = ""
182    loss: str = "CrossEntropy"
183    model_kwargs: Dict = field(default_factory=dict)
184    server: GlobusComputeServerConfig
185    logging_tasks: List = field(default_factory=list)
186    hf_model_arc: str = ""
187    hf_model_weights: str = ""
188
189    # Testing and validation params
190    client_do_validation: bool = True
191    client_do_testing: bool = True
192    server_do_validation: bool = True
193    server_do_testing: bool = True
194
195    # Testing and validation frequency
196    client_validation_step: int = 1
197    server_validation_step: int = 1
198
199    # Cloud storage
200    use_cloud_transfer: bool = True