-
-
Notifications
You must be signed in to change notification settings - Fork 89
/
Copy pathfedavg.py
335 lines (294 loc) · 12.8 KB
/
fedavg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
from collections import OrderedDict
from copy import deepcopy
from typing import Any
import torch
from omegaconf import DictConfig
from torch.utils.data import DataLoader, Subset
from data.utils.datasets import BaseDataset
from src.utils.functional import evaluate_model, get_optimal_cuda_device
from src.utils.metrics import Metrics
from src.utils.models import DecoupledModel
class FedAvgClient:
    """Client-side worker for FedAvg local training and evaluation.

    A single instance can act on behalf of different clients: every call to
    `train()` / `test()` receives a server package whose `client_id` selects
    which entry of `data_indices` backs the train/val/test subsets.
    """

    def __init__(
        self,
        model: DecoupledModel,
        optimizer_cls: type[torch.optim.Optimizer],
        lr_scheduler_cls: type[torch.optim.lr_scheduler.LRScheduler],
        args: DictConfig,
        dataset: BaseDataset,
        data_indices: list,
        device: torch.device | None,
        return_diff: bool,
    ):
        """Build the client, its optimizer/scheduler and its data loaders.

        Args:
            model: Model trained by this client; it is moved to the chosen device.
            optimizer_cls: Optimizer factory. It is called with `params` only,
                so every other hyperparameter must already be bound to it.
            lr_scheduler_cls: Scheduler factory called with `optimizer` only,
                or None to train without a scheduler.
            args: Experiment configuration (the `common.*` entries are read).
            dataset: Full dataset shared by all clients. It must provide
                `train()` and `eval()` mode switches.
            data_indices: Per-client index dicts,
                `[{"train": [...], "val": [...], "test": [...]}, ...]`.
            device: Device to run on. None delegates the choice to
                `get_optimal_cuda_device`.
            return_diff: If True, `package()` returns `global - trained`
                parameter differences instead of the trained parameters.
        """
        # Assigned by `set_parameters()` once a server package arrives.
        self.client_id: int | None = None
        self.args = args
        if device is None:
            self.device = get_optimal_cuda_device(use_cuda=self.args.common.use_cuda)
        else:
            self.device = device
        self.dataset = dataset
        self.model = model.to(self.device)

        # Snapshot of the received global parameters; only populated by
        # `set_parameters()` when `return_diff` is True.
        self.regular_model_params: OrderedDict[str, torch.Tensor]
        self.personal_params_name: list[str] = []
        self.regular_params_name = [key for key, _ in self.model.named_parameters()]
        if self.args.common.buffers == "local":
            # Buffers stay on the client, so they are reported as personal.
            self.personal_params_name.extend(
                [name for name, _ in self.model.named_buffers()]
            )
        elif self.args.common.buffers == "drop":
            # Keep the pristine buffers so they can be restored for every package.
            self.init_buffers = deepcopy(OrderedDict(self.model.named_buffers()))

        self.optimizer = optimizer_cls(params=self.model.parameters())
        self.init_optimizer_state = deepcopy(self.optimizer.state_dict())

        self.lr_scheduler: torch.optim.lr_scheduler.LRScheduler | None = None
        self.init_lr_scheduler_state: dict | None = None
        self.lr_scheduler_cls = None
        if lr_scheduler_cls is not None:
            self.lr_scheduler_cls = lr_scheduler_cls
            self.lr_scheduler = self.lr_scheduler_cls(optimizer=self.optimizer)
            self.init_lr_scheduler_state = deepcopy(self.lr_scheduler.state_dict())

        # [{"train": [...], "val": [...], "test": [...]}, ...]
        self.data_indices = data_indices

        # The placeholder index [0] only exists because building a
        # `DataLoader(shuffle=True)` over an empty Subset raises an error.
        # The real indices are installed by `load_data_indices()`.
        self.trainset = Subset(self.dataset, indices=[0])
        self.valset = Subset(self.dataset, indices=[])
        self.testset = Subset(self.dataset, indices=[])
        self.trainloader = DataLoader(
            self.trainset, batch_size=self.args.common.batch_size, shuffle=True
        )
        self.valloader = DataLoader(self.valset, batch_size=self.args.common.batch_size)
        self.testloader = DataLoader(
            self.testset, batch_size=self.args.common.batch_size
        )

        self.testing = False
        self.local_epoch = self.args.common.local_epoch
        self.criterion = torch.nn.CrossEntropyLoss().to(self.device)
        self.eval_results = {}
        self.return_diff = return_diff

    @staticmethod
    def _empty_metrics() -> dict[str, Metrics]:
        """Return fresh, empty metrics for the train/val/test splits."""
        return {"train": Metrics(), "val": Metrics(), "test": Metrics()}

    def load_data_indices(self):
        """Point the three subsets at the data of client `self.client_id`."""
        client_indices = self.data_indices[self.client_id]
        self.trainset.indices = client_indices["train"]
        self.valset.indices = client_indices["val"]
        self.testset.indices = client_indices["test"]

    def train_with_eval(self):
        """Run `fit()` between two `evaluate()` calls and store the results.

        The outcome is saved to `self.eval_results` as a dict: {
            `before`: metrics prior to local training.
            `after`: metrics after local training (left empty when
                `self.local_epoch` is 0, because no training happens).
            `message`: list of formatted log lines, one per reported split.
        }
        """
        eval_results = {"before": self.evaluate(), "after": self._empty_metrics()}
        if self.local_epoch > 0:
            self.fit()
            eval_results["after"] = self.evaluate()

        eval_msg = []
        for split, color, flag, subset in [
            ["train", "yellow", self.args.common.test.client.train, self.trainset],
            ["val", "green", self.args.common.test.client.val, self.valset],
            ["test", "cyan", self.args.common.test.client.test, self.testset],
        ]:
            if len(subset) > 0 and flag:
                before = eval_results["before"][split]
                after = eval_results["after"][split]
                eval_msg.append(
                    f"client [{self.client_id}]\t"
                    f"[{color}]({split}set)[/{color}]\t"
                    f"[red]loss: {before.loss:.4f} -> "
                    f"{after.loss:.4f}\t[/red]"
                    f"[blue]accuracy: {before.accuracy:.2f}% -> "
                    f"{after.accuracy:.2f}%[/blue]"
                )

        eval_results["message"] = eval_msg
        self.eval_results = eval_results

    def set_parameters(self, package: dict[str, Any]):
        """Load everything the server sent into this client.

        Args:
            package: Server package. The keys read here are `client_id`,
                `local_epoch`, `optimizer_state`, `lr_scheduler_state`,
                `regular_model_params` and `personal_model_params`.
        """
        self.client_id = package["client_id"]
        self.local_epoch = package["local_epoch"]
        self.load_data_indices()

        # An empty state, or an explicit reset request, falls back to the
        # state captured right after construction.
        if (
            package["optimizer_state"]
            and not self.args.common.reset_optimizer_on_global_epoch
        ):
            self.optimizer.load_state_dict(package["optimizer_state"])
        else:
            self.optimizer.load_state_dict(self.init_optimizer_state)

        if self.lr_scheduler is not None:
            if package["lr_scheduler_state"]:
                self.lr_scheduler.load_state_dict(package["lr_scheduler_state"])
            else:
                self.lr_scheduler.load_state_dict(self.init_lr_scheduler_state)

        # strict=False because each dict holds only a part of the state dict.
        self.model.load_state_dict(package["regular_model_params"], strict=False)
        self.model.load_state_dict(package["personal_model_params"], strict=False)
        if self.args.common.buffers == "drop":
            self.model.load_state_dict(self.init_buffers, strict=False)

        if self.return_diff:
            # Remember the received parameters so `package()` can report
            # `global - trained` afterwards.
            model_params = self.model.state_dict()
            self.regular_model_params = OrderedDict(
                (key, model_params[key].clone().cpu())
                for key in self.regular_params_name
            )

    def train(self, server_package: dict[str, Any]) -> dict:
        """Load the server package, train locally and return the client package.

        Args:
            server_package: Parameter package, see `set_parameters()`.

        Returns:
            The dict built by `package()`.
        """
        self.set_parameters(server_package)
        self.train_with_eval()
        return self.package()

    def package(self):
        """Package data that client needs to transmit to the server. You can
        override this function and add more parameters.

        Returns:
            A dict: {
                `weight`: Client weight. Defaults to the size of client training set.
                `regular_model_params`: Client model parameters that will join parameter aggregation. Absent when `return_diff` is True.
                `model_params_diff`: The parameter difference between the client trained and the global. `diff = global - trained`. Present only when `return_diff` is True.
                `eval_results`: Client model evaluation results.
                `personal_model_params`: Client model parameters that are absent from parameter aggregation.
                `optimizer_state`: Client optimizer's state dict.
                `lr_scheduler_state`: Client learning rate scheduler's state dict.
            }
        """
        model_params = self.model.state_dict()
        client_package = dict(
            weight=len(self.trainset),
            eval_results=self.eval_results,
            regular_model_params={
                key: model_params[key].clone().cpu() for key in self.regular_params_name
            },
            personal_model_params={
                key: model_params[key].clone().cpu()
                for key in self.personal_params_name
            },
            optimizer_state=deepcopy(self.optimizer.state_dict()),
            lr_scheduler_state=(
                {}
                if self.lr_scheduler is None
                else deepcopy(self.lr_scheduler.state_dict())
            ),
        )
        if self.return_diff:
            # Match old and new tensors by name rather than by position, so a
            # mismatch fails loudly instead of silently pairing wrong tensors.
            trained_params = client_package.pop("regular_model_params")
            client_package["model_params_diff"] = {
                key: self.regular_model_params[key] - param_new
                for key, param_new in trained_params.items()
            }
        return client_package

    def _train_epochs(self, num_epochs: int, step_lr_scheduler: bool) -> None:
        """Run `num_epochs` passes of optimizer updates over the train loader.

        Args:
            num_epochs: Number of passes over `self.trainloader`.
            step_lr_scheduler: Whether to step `self.lr_scheduler` (if any)
                once at the end of every epoch.
        """
        self.model.train()
        self.dataset.train()
        for _ in range(num_epochs):
            for x, y in self.trainloader:
                # When the current batch size is 1, the BatchNorm2d modules in
                # the model would raise an error, so such batches are skipped.
                if len(x) <= 1:
                    continue

                x, y = x.to(self.device), y.to(self.device)
                logit = self.model(x)
                loss = self.criterion(logit, y)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            if step_lr_scheduler and self.lr_scheduler is not None:
                self.lr_scheduler.step()

    def fit(self):
        """Train the local model for `self.local_epoch` epochs."""
        self._train_epochs(self.local_epoch, step_lr_scheduler=True)

    @torch.no_grad()
    def evaluate(self, model: torch.nn.Module | None = None) -> dict[str, Metrics]:
        """Evaluating client model.

        A split is evaluated only when its subset is non-empty, when the
        client is testing or `common.client_side_evaluation` is on, and when
        the matching `common.test.client.<split>` flag is on. Skipped splits
        keep an empty `Metrics()`.

        Args:
            model: Used model. Defaults to None, which will fallback to `self.model`.

        Returns:
            An evaluation results dict: {
                `train`: results on client training set.
                `val`: results on client validation set.
                `test`: results on client test set.
            }
        """
        target_model = self.model if model is None else model
        target_model.eval()
        self.dataset.eval()
        criterion = torch.nn.CrossEntropyLoss(reduction="sum")

        results = self._empty_metrics()
        for split, subset, dataloader in (
            ("test", self.testset, self.testloader),
            ("val", self.valset, self.valloader),
            ("train", self.trainset, self.trainloader),
        ):
            if (
                len(subset) > 0
                and (self.testing or self.args.common.client_side_evaluation)
                and getattr(self.args.common.test.client, split)
            ):
                results[split] = evaluate_model(
                    model=target_model,
                    dataloader=dataloader,
                    criterion=criterion,
                    device=self.device,
                )
        return results

    def test(self, server_package: dict[str, Any]) -> dict[str, dict[str, Metrics]]:
        """Test client model. If `finetune_epoch > 0`, `finetune()` will be
        activated.

        Args:
            server_package: Parameter package.

        Returns:
            A model evaluation results dict : {
                `before`: metrics prior to finetuning.
                `after`: metrics after finetuning (left empty when
                    `finetune_epoch` is 0).
            }
        """
        self.testing = True
        try:
            self.set_parameters(server_package)
            results = {"before": self.evaluate(), "after": self._empty_metrics()}
            if self.args.common.test.client.finetune_epoch > 0:
                # Finetuning must not leak into the stored model: snapshot the
                # weights and put them back once the evaluation is done.
                frz_params_dict = deepcopy(self.model.state_dict())
                self.finetune()
                results["after"] = self.evaluate()
                self.model.load_state_dict(frz_params_dict)
            return results
        finally:
            # Always leave testing mode, even when a step above raised, so a
            # reused client does not keep evaluating as if it were testing.
            self.testing = False

    def finetune(self):
        """Client model finetuning.

        This function will only be activated in `test()`. The learning rate
        scheduler is not stepped here.
        """
        self._train_epochs(
            self.args.common.test.client.finetune_epoch, step_lr_scheduler=False
        )