How to add new algorithms
Suppose that we are adding a new algorithm together with its configuration. A new algorithm should be implemented as two classes, one for the server and one for the client. Both classes should be derived from the following two base classes, BaseServer and BaseClient:
Example: NewAlgo
Here we give a simple example.
Core algorithm class
We first create classes for the global and local updates in appfl/algorithm:
See the two classes NewAlgoServer and NewAlgoClient in newalgo.py.
In NewAlgoServer, the update function conducts a global update by averaging the local model parameters sent from multiple clients.
In NewAlgoClient, the update function conducts a local update and sends the resulting local model parameters to the server.
The following is example code:
src/appfl/algorithm/newalgo.py
from collections import OrderedDict

from .algorithm import BaseServer, BaseClient
class NewAlgoServer(BaseServer):
    """Server-side template for the example ``NewAlgo`` algorithm.

    Extra keyword arguments are stored as instance attributes so that
    algorithm-specific settings are reachable as ``self.<name>``.
    """

    def __init__(self, weights, model, num_clients, device, **kwargs):
        super().__init__(weights, model, num_clients, device)
        # Expose every additional keyword argument as an attribute.
        self.__dict__.update(kwargs)
        # Any additional initialization

    def update(self, local_states: OrderedDict):
        """Conduct the global update from the clients' local states.

        In this example the update is meant to average the local model
        parameters sent from multiple clients.
        """
        # Implement new server update function
        # A comment-only body is a syntax error, so fail loudly until the
        # template has been filled in.
        raise NotImplementedError("NewAlgoServer.update is not implemented")
class NewAlgoClient(BaseClient):
    """Client-side template for the example ``NewAlgo`` algorithm.

    Extra keyword arguments are stored as instance attributes so that
    algorithm-specific settings are reachable as ``self.<name>``.
    """

    def __init__(self, id, weight, model, dataloader, device, **kwargs):
        super().__init__(id, weight, model, dataloader, device)
        # Expose every additional keyword argument as an attribute.
        self.__dict__.update(kwargs)
        # Any additional initialization

    def update(self):
        """Conduct the local update.

        In this example the update is meant to produce the local model
        parameters that are then sent to the server.
        """
        # Implement new client update function
        # A comment-only body is a syntax error, so fail loudly until the
        # template has been filled in.
        raise NotImplementedError("NewAlgoClient.update is not implemented")
Configuration dataclass
The new algorithm also needs to set up some configurations. This can be done by adding a new dataclass under appfl.config.fed.
Let’s say we add the file src/appfl/config/fed/newalgo.py to implement the dataclass as follows:
src/appfl/config/fed/newalgo.py
from dataclasses import dataclass
from omegaconf import DictConfig, OmegaConf
@dataclass
class NewAlgo:
    """Configuration dataclass for the example ``NewAlgo`` algorithm."""

    # Identifier string of this algorithm configuration.
    type: str = "newalgo"

    # Class names of the server-side and client-side implementations.
    servername: str = "NewAlgoServer"
    clientname: str = "NewAlgoClient"

    # Algorithm-specific arguments; empty until new arguments are added.
    args: DictConfig = OmegaConf.create(
        {
            # add new arguments
        }
    )
Then, we need to add the following line to the main configuration file config.py:
from .fed.newalgo import *
This is the main configuration class in src/appfl/config/config.py.
Each algorithm, specified in Config.fed, can be configured in the dataclasses at appfl.config.fed.*.
from dataclasses import dataclass, field
from typing import Any, List, Dict, Optional
from omegaconf import DictConfig, OmegaConf


from .fed.federated import *
from .fed.fedasync import *
from .fed.iceadmm import *  ## TODO: combine iceadmm and iiadmm under the name of ADMM.
from .fed.iiadmm import *


@dataclass
class Config:
    """Main configuration dataclass.

    ``fed`` holds the algorithm-specific configuration; the remaining fields
    are general run settings grouped by the comments below.
    """

    # Algorithm configuration; a fresh ``Federated`` instance by default.
    fed: Any = field(default_factory=Federated)

    # Compute device
    device: str = "cpu"
    device_server: str = "cpu"

    # Number of clients
    # (the original comment read "Number of training epochs", apparently a
    # copy-paste slip from the field below)
    num_clients: int = 1

    # Number of training epochs
    num_epochs: int = 2

    # Number of workers in DataLoader
    num_workers: int = 0

    # Train data batch info
    batch_training: bool = True  ## TODO: revisit
    train_data_batch_size: int = 64
    train_data_shuffle: bool = True

    # Indication of whether to validate or not using testing data
    validation: bool = True
    test_data_batch_size: int = 64
    test_data_shuffle: bool = False

    # Checking data sanity
    data_sanity: bool = False

    # Reproducibility
    reproduce: bool = True

    # PCA on Trajectory
    pca_dir: str = ""
    params_start: int = 0
    params_end: int = 49
    ncomponents: int = 40

    # Tensorboard
    use_tensorboard: bool = False

    # Loading models
    load_model: bool = False
    load_model_dirname: str = ""
    load_model_filename: str = ""

    # Saving models (server)
    save_model: bool = False
    save_model_dirname: str = ""
    save_model_filename: str = ""
    checkpoints_interval: int = 2

    # Saving state_dict (clients)
    save_model_state_dict: bool = False
    send_final_model: bool = False

    # Logging and recording outputs
    output_dirname: str = "output"
    output_filename: str = "result"

    logginginfo: DictConfig = OmegaConf.create({})
    summary_file: str = ""

    # Personalization options
    personalization: bool = False
    p_layers: List[str] = field(default_factory=list)
    config_name: str = ""

    #
    # gRPC configurations
    #

    # 100 MB for gRPC maximum message size
    max_message_size: int = 100 * 1024 * 1024

    operator: DictConfig = OmegaConf.create({"id": 1})
    server: DictConfig = OmegaConf.create(
        {"id": 1, "host": "localhost", "port": 50051, "use_tls": False, "api_key": None}
    )
    client: DictConfig = OmegaConf.create({"id": 1})

@dataclass
class GlobusComputeServerConfig:
    """Server-side settings for a Globus Compute run."""

    # Compute device name; defaults to the CPU.
    device: str = "cpu"
    # Output and data directories; both default to the current directory.
    output_dir: str = "./"
    data_dir: str = "./"
    # S3 settings; unset by default.
    # NOTE(review): the expected shape of s3_bucket / s3_creds is not visible
    # here - confirm against the code that consumes them.
    s3_bucket: Any = None
    s3_creds: str = ""

@dataclass
class GlobusComputeClientConfig:
    """Per-client settings for a Globus Compute run."""

    # Client name and endpoint identifier; both empty by default.
    name: str = ""
    endpoint_id: str = ""
    # Compute device name; defaults to the CPU.
    device: str = "cpu"
    # Output and data directories; both default to the current directory.
    output_dir: str = "./"
    data_dir: str = "./"
    # Empty OmegaConf containers by default.
    # NOTE(review): the expected keys of get_data / data_pipeline are not
    # visible here - confirm against the code that consumes them.
    get_data: DictConfig = OmegaConf.create({})
    data_pipeline: DictConfig = OmegaConf.create({})

@dataclass
class ExecutableFunc:
    """Four string fields describing an executable function.

    NOTE(review): how ``module``/``call`` relate to ``script_file``/``source``
    is not visible here - confirm against the code that loads these.
    """

    # All fields default to the empty string.
    module: str = ""
    call: str = ""
    script_file: str = ""
    source: str = ""

@dataclass
class ClientTask:
    """Record of one client task: identifiers, status flags, timing and log."""

    # Task identifier and name; empty until assigned.
    task_id: str = ""
    task_name: str = ""
    # Index of the client. The default used to be the string "" on this int
    # field; -1 is used as the "not assigned" sentinel instead, matching the
    # -1 sentinels of start_time/end_time below.
    client_idx: int = -1
    # Status flags: a new task starts as pending and not successful.
    pending: bool = True
    success: bool = False
    # Start and end times; -1 until set.
    start_time: float = -1
    end_time: float = -1
    # Per-task log; every instance gets its own fresh dict.
    log: Optional[Dict] = field(default_factory=dict)

@dataclass
class GlobusComputeConfig(Config):
    """``Config`` extended with the settings of a Globus Compute run."""

    # Executable descriptors; each defaults to an empty ExecutableFunc.
    get_data: ExecutableFunc = field(default_factory=ExecutableFunc)
    get_model: ExecutableFunc = field(default_factory=ExecutableFunc)
    get_loss: ExecutableFunc = field(default_factory=ExecutableFunc)
    val_metric: ExecutableFunc = field(default_factory=ExecutableFunc)
    # Per-client configurations; empty list by default.
    clients: List[GlobusComputeClientConfig] = field(default_factory=list)
    dataset: str = ""
    loss: str = "CrossEntropy"
    model_kwargs: Dict = field(default_factory=dict)
    # NOTE(review): re-declares ``Config.server`` with a new type but no
    # default of its own, so the default inherited from ``Config`` (a
    # DictConfig) appears to stay in effect - confirm this is intended.
    server: GlobusComputeServerConfig
    logging_tasks: List = field(default_factory=list)
    hf_model_arc: str = ""
    hf_model_weights: str = ""

    # Testing and validation params
    client_do_validation: bool = True
    client_do_testing: bool = True
    server_do_validation: bool = True
    server_do_testing: bool = True

    # Testing and validation frequency
    client_validation_step: int = 1
    server_validation_step: int = 1

    # Cloud storage
    use_cloud_transfer: bool = True