optim.py
""" Wrapper of optimizers in torch.optim for computation of exponential moving average of parameters"""
import torch
def build_ema_optimizer(optimizer_cls):
class Optimizer(optimizer_cls):
def __init__(self, *args, polyak=0.0, **kwargs):
if not 0.0 <= polyak <= 1.0:
raise ValueError("Invalid polyak decay rate: {}".format(polyak))
super().__init__(*args, **kwargs)
self.defaults['polyak'] = polyak
def step(self, closure=None):
super().step(closure)
# update exponential moving average after gradient update to parameters
for group in self.param_groups:
for p in group['params']:
state = self.state[p]
# state initialization
if 'ema' not in state:
state['ema'] = torch.zeros_like(p.data)
# ema update
state['ema'] -= (1 - self.defaults['polyak']) * (state['ema'] - p.data)
def swap_ema(self):
""" substitute exponential moving average values into parameter values """
for group in self.param_groups:
for p in group['params']:
data = p.data
state = self.state[p]
p.data = state['ema']
state['ema'] = data
def __repr__(self):
s = super().__repr__()
return self.__class__.__mro__[1].__name__ + ' (\npolyak: {}\n'.format(self.defaults['polyak']) + s.partition('\n')[2]
return Optimizer
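# Any optimizer in torch.optim can be wrapped the same way; for example, an
# EMA-tracking SGD could be built with:
#   SGD = build_ema_optimizer(torch.optim.SGD)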
Adam = build_ema_optimizer(torch.optim.Adam)
RMSprop = build_ema_optimizer(torch.optim.RMSprop)
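
# A minimal usage sketch (the model, data loader, and evaluate function below are
# hypothetical, not part of this module): train with the wrapped optimizer as
# usual, then call swap_ema() around evaluation so the model is scored with the
# averaged weights, and call it again to restore the raw parameters.
#
#   model = torch.nn.Linear(10, 1)
#   optimizer = Adam(model.parameters(), lr=1e-3, polyak=0.995)
#   for x, y in train_loader:                      # hypothetical data loader
#       optimizer.zero_grad()
#       torch.nn.functional.mse_loss(model(x), y).backward()
#       optimizer.step()
#   optimizer.swap_ema()    # parameters <- EMA values
#   evaluate(model)         # hypothetical evaluation function
#   optimizer.swap_ema()    # swap back to the raw parameters before resuming training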

if __name__ == '__main__':
    import copy

    torch.manual_seed(0)

    x = torch.randn(2, 2)
    y = torch.rand(2, 2)
    polyak = 0.9
    _m = torch.nn.Linear(2, 2)

    for optim in [Adam, RMSprop]:
        m = copy.deepcopy(_m)
        o = optim(m.parameters(), lr=0.1, polyak=polyak)
        print('Testing: ', optim.__name__)
        print(o)
        print('init loss {:.3f}'.format(torch.mean((m(x) - y)**2).item()))

        # track the ema manually to check the optimizer's internal ema state
        p = torch.zeros_like(m.weight)
        for i in range(5):
            loss = torch.mean((m(x) - y)**2)
            print('step {}: loss {:.3f}'.format(i, loss.item()))
            o.zero_grad()
            loss.backward()
            o.step()
            # manually compute the ema
            p -= (1 - polyak) * (p - m.weight.data)

        print('loss: {:.3f}'.format(torch.mean((m(x) - y)**2).item()))
        print('swapping ema values for params.')
        o.swap_ema()
        assert torch.allclose(p, m.weight)
        print('loss: {:.3f}'.format(torch.mean((m(x) - y)**2).item()))
        print('swapping params for ema values.')
        o.swap_ema()
        print('loss: {:.3f}'.format(torch.mean((m(x) - y)**2).item()))
        print()