-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathvanilla_gan.py
108 lines (87 loc) · 3.43 KB
/
vanilla_gan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# coding: utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import numpy as np
import torch
from torch.autograd import Variable, grad
from torch.nn.init import xavier_normal
from torchvision import datasets, transforms
import torchvision.utils as vutils
def xavier_init(model):
for param in model.parameters():
if len(param.size()) == 2:
xavier_normal(param)
if __name__ == '__main__':
batch_size = 128
z_dim = 100
h_dim = 128
y_dim = 784
max_epochs = 1000
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor()
])),
batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, transform=transforms.Compose([
transforms.ToTensor()
])),
batch_size=batch_size, shuffle=False, drop_last=True)
generator = torch.nn.Sequential(torch.nn.Linear(z_dim, h_dim),
torch.nn.Sigmoid(),
torch.nn.Linear(h_dim, y_dim),
torch.nn.Sigmoid())
discriminator = torch.nn.Sequential(torch.nn.Linear(y_dim, h_dim),
torch.nn.Sigmoid(),
torch.nn.Linear(h_dim, 1),
torch.nn.Sigmoid())
# Init weight matrices (xavier_normal)
xavier_init(generator)
xavier_init(discriminator)
opt_g = torch.optim.Adam(generator.parameters())
opt_d = torch.optim.Adam(discriminator.parameters())
criterion = torch.nn.BCELoss()
X = Variable(torch.FloatTensor(batch_size, y_dim))
z = Variable(torch.FloatTensor(batch_size, z_dim))
labels = Variable(torch.FloatTensor(batch_size))
# Train
for epoch in range(max_epochs):
for batch_idx, (data, target) in enumerate(train_loader):
X.data.copy_(data)
# Update discriminator
# train with real
discriminator.zero_grad()
pred_real = discriminator(X)
labels.data.fill_(1.0)
loss_d_real = criterion(pred_real, labels)
loss_d_real.backward()
# train with fake
z.data.normal_(0, 1)
fake = generator.forward(z).detach()
pred_fake = discriminator(fake)
labels.data.fill_(0.0)
loss_d_fake = criterion(pred_fake, labels)
loss_d_fake.backward()
loss_d = loss_d_real + loss_d_fake
opt_d.step()
# Update generator
generator.zero_grad()
z.data.normal_(0, 1)
gen = generator(z)
pred_gen = discriminator(gen)
labels.data.fill_(1.0)
loss_g = criterion(pred_gen, labels)
loss_g.backward()
opt_g.step()
print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f'
% (epoch, max_epochs, batch_idx, len(train_loader),
loss_d.data[0], loss_g.data[0]))
if batch_idx % 100 == 0:
vutils.save_image(data,
'samples_vanilla/real_samples.png')
fake = generator(z)
vutils.save_image(gen.data.view(batch_size, 1, 28, 28),
'samples_vanilla/fake_samples_epoch_%03d.png' % epoch)