Adding new algorithms#
Suppose that we are adding the configuration for our new algorithm. The new algorithm should be implemented as two classes, one for the server and one for the client.
Base classes#
Implementation of the new classes should be derived from the following two base classes:
Example: NewAlgo#
Here we give a simple example.
Core algorithm class#
We first create classes for the global and local updates in appfl/algorithm:
Create two classes, NewAlgoServer and NewAlgoClient, in newalgo.py.
In NewAlgoServer, the update function conducts a global update by averaging the local model parameters sent from multiple clients.
In NewAlgoClient, the update function conducts a local update and sends the resulting local model parameters to the server.
This is an example code:
from collections import OrderedDict

from .algorithm import BaseServer, BaseClient

class NewAlgoServer(BaseServer):
    def __init__(self, weights, model, num_clients, device, **kwargs):
        super(NewAlgoServer, self).__init__(weights, model, num_clients, device)
        self.__dict__.update(kwargs)
        # Any additional initialization

    def update(self, local_states: OrderedDict):
        # Implement new server update function
        pass

class NewAlgoClient(BaseClient):
    def __init__(self, id, weight, model, dataloader, device, **kwargs):
        super(NewAlgoClient, self).__init__(id, weight, model, dataloader, device)
        self.__dict__.update(kwargs)
        # Any additional initialization

    def update(self):
        # Implement new client update function
        pass
Configuration dataclass#
The new algorithm also needs to set up some configurations. This can be done by adding a new dataclass under appfl.config.fed.
Let’s say we add the file src/appfl/config/fed/newalgo.py to implement the dataclass as follows:
from dataclasses import dataclass
from omegaconf import DictConfig, OmegaConf

@dataclass
class NewAlgo:
    type: str = "newalgo"
    servername: str = "NewAlgoServer"
    clientname: str = "NewAlgoClient"
    args: DictConfig = OmegaConf.create(
        {
            # add new arguments
        }
    )
Then, we need to add the following line to the main configuration file config.py:
from .fed.newalgo import *
This is the main configuration class in src/appfl/config/config.py.
Each algorithm, specified in Config.fed, can be configured in the dataclasses at appfl.config.fed.*.
1from dataclasses import dataclass, field
2from typing import Any, List, Dict, Optional
3from omegaconf import DictConfig, OmegaConf
4import os
5import sys
6
7from .fed.federated import *
8from .fed.fedasync import *
9from .fed.iceadmm import *
10from .fed.iiadmm import *
11
12@dataclass
13class Config:
14 fed: Any = field(default_factory=Federated)
15
16 # Compute device
17 device: str = "cpu"
18 device_server: str = "cpu"
19
20    # Number of clients
21    num_clients: int = 1
22
23 # Number of training epochs
24 num_epochs: int = 2
25
26 # Number of workers in DataLoader
27 num_workers: int = 0
28
29 # Train data batch info
30 batch_training: bool = True ## TODO: revisit
31 train_data_batch_size: int = 64
32 train_data_shuffle: bool = True
33
34 # Indication of whether to validate or not using testing data
35 validation: bool = True
36 test_data_batch_size: int = 64
37 test_data_shuffle: bool = False
38
39 # Checking data sanity
40 data_sanity: bool = False
41
42 # Reproducibility
43 reproduce: bool = True
44
45 # PCA on Trajectory
46 pca_dir: str = ""
47 params_start: int = 0
48 params_end: int = 49
49 ncomponents: int = 40
50
51 # Tensorboard
52 use_tensorboard: bool = False
53
54 # Loading models
55 load_model: bool = False
56 load_model_dirname: str = ""
57 load_model_filename: str = ""
58
59 # Saving models (server)
60 save_model: bool = False
61 save_model_dirname: str = ""
62 save_model_filename: str = ""
63 checkpoints_interval: int = 2
64
65 # Saving state_dict (clients)
66 save_model_state_dict: bool = False
67 send_final_model: bool = False
68
69 # Logging and recording outputs
70 output_dirname: str = "output"
71 output_filename: str = "result"
72
73 logginginfo: DictConfig = OmegaConf.create({})
74 summary_file: str = ""
75
76 # Personalization options
77 personalization: bool = False
78 p_layers: List[str] = field(default_factory=lambda: [])
79 config_name: str = ""
80
81    ## gRPC configurations ##
82
83    # 10 MB for gRPC maximum message size (10485760 = 10 * 1024 * 1024)
84 max_message_size: int = 10485760
85 use_ssl: bool = False
86 use_authenticator: bool = False
87 authenticator: str = "Globus" # "Globus", "Naive"
88 uri: str = "localhost:50051"
89
90 operator: DictConfig = OmegaConf.create({"id": 1})
91 server: DictConfig = OmegaConf.create({
92 "id": 1,
93 "authenticator_kwargs": {
94 "is_fl_server": True,
95 "globus_group_id": "77c1c74b-a33b-11ed-8951-7b5a369c0a53",
96 },
97 "server_certificate_key": "default",
98 "server_certificate": "default",
99 "max_workers": 10,
100 })
101 client: DictConfig = OmegaConf.create({
102 "id": 1,
103 "root_certificates": "default",
104 "authenticator_kwargs": {
105 "is_fl_server": False,
106 },
107 })
108
109 # Lossy compression enabling
110 enable_compression: bool = False
111 lossy_compressor: str = "SZ2"
112 lossless_compressor: str = "blosc"
113
114 # Lossy compression path configuration
115 ext = ".dylib" if sys.platform.startswith("darwin") else ".so"
116 current_dir = os.path.dirname(os.path.realpath(__file__))
117 base_dir = os.path.abspath(os.path.join(current_dir, os.pardir, os.pardir, os.pardir))
118 compressor_sz2_path: str = os.path.join(base_dir, ".compressor/SZ/build/sz/libSZ" + ext)
119 compressor_sz3_path: str = os.path.join(base_dir, ".compressor/SZ3/build/tools/sz3c/libSZ3c" + ext)
120 compressor_szx_path: str = os.path.join(base_dir, ".compressor/SZx-main/build/lib/libSZx" + ext)
121
122 # Compressor parameters
123 error_bounding_mode: str = ""
124 error_bound: float = 0.0
125
126 # Default data type
127 flat_model_dtype: str = "np.float32"
128 param_cutoff: int = 1024
129
130 # Data readiness
131 dr_metrics: Optional[List[str]] = field(default_factory=lambda: [])
132
133
134@dataclass
135class GlobusComputeServerConfig:
136 device: str = "cpu"
137 output_dir: str = "./"
138 data_dir: str = "./"
139 s3_bucket: Any = None
140 s3_creds: str = ""
141
142
143@dataclass
144class GlobusComputeClientConfig:
145 name : str = ""
146 endpoint_id : str = ""
147 device : str = "cpu"
148 output_dir : str = "./output"
149 data_dir : str = "./datasets"
150 get_data : DictConfig = OmegaConf.create({})
151 data_pipeline: DictConfig = OmegaConf.create({})
152
153
154@dataclass
155class ExecutableFunc:
156 module: str = ""
157 call: str = ""
158 script_file: str = ""
159 source: str = ""
160
161
162@dataclass
163class ClientTask:
164 task_id: str = ""
165 task_name: str = ""
166 client_idx: int = ""
167 pending: bool = True
168 success: bool = False
169 start_time: float = -1
170 end_time: float = -1
171 log: Optional[Dict] = field(default_factory=dict)
172
173
174@dataclass
175class GlobusComputeConfig(Config):
176 get_data: ExecutableFunc = field(default_factory=ExecutableFunc)
177 get_model: ExecutableFunc = field(default_factory=ExecutableFunc)
178 get_loss: ExecutableFunc = field(default_factory=ExecutableFunc)
179 val_metric: ExecutableFunc = field(default_factory=ExecutableFunc)
180 clients: List[GlobusComputeClientConfig] = field(default_factory=list)
181 dataset: str = ""
182 loss: str = "CrossEntropy"
183 model_kwargs: Dict = field(default_factory=dict)
184 server: GlobusComputeServerConfig
185 logging_tasks: List = field(default_factory=list)
186 hf_model_arc: str = ""
187 hf_model_weights: str = ""
188
189 # Testing and validation params
190 client_do_validation: bool = True
191 client_do_testing: bool = True
192 server_do_validation: bool = True
193 server_do_testing: bool = True
194
195 # Testing and validation frequency
196 client_validation_step: int = 1
197 server_validation_step: int = 1
198
199 # Cloud storage
200 use_cloud_transfer: bool = True