Skip to content

Commit 8c4dd99

Browse files
committed
vanilla gan added'
1 parent 2fe796b commit 8c4dd99

File tree

2 files changed

+60
-112
lines changed

2 files changed

+60
-112
lines changed

tutorials/10 - Generative Adversarial Network/main-gpu.py

Lines changed: 29 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,51 @@
11
import torch
22
import torchvision
33
import torch.nn as nn
4+
import torch.nn.functional as F
45
import torchvision.datasets as dsets
56
import torchvision.transforms as transforms
67
from torch.autograd import Variable
78

89
# Image Preprocessing
910
transform = transforms.Compose([
10-
transforms.Scale(36),
11-
transforms.RandomCrop(32),
1211
transforms.ToTensor(),
1312
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
1413

15-
# CIFAR-10 Dataset
16-
train_dataset = dsets.CIFAR10(root='../data/',
17-
train=True,
18-
transform=transform,
19-
download=True)
14+
# MNIST Dataset
15+
train_dataset = dsets.MNIST(root='../data/',
16+
train=True,
17+
transform=transform,
18+
download=True)
2019

2120
# Data Loader (Input Pipeline)
2221
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
2322
batch_size=100,
2423
shuffle=True)
2524

26-
# 4x4 Convolution
27-
def conv4x4(in_channels, out_channels, stride):
28-
return nn.Conv2d(in_channels, out_channels, kernel_size=4,
29-
stride=stride, padding=1, bias=False)
30-
3125
# Discriminator Model
3226
class Discriminator(nn.Module):
3327
def __init__(self):
3428
super(Discriminator, self).__init__()
35-
self.model = nn.Sequential(
36-
conv4x4(3, 16, 2),
37-
nn.LeakyReLU(0.2, inplace=True),
38-
conv4x4(16, 32, 2),
39-
nn.BatchNorm2d(32),
40-
nn.LeakyReLU(0.2, inplace=True),
41-
conv4x4(32, 64, 2),
42-
nn.BatchNorm2d(64),
43-
nn.LeakyReLU(0.2, inplace=True),
44-
nn.Conv2d(64, 1, kernel_size=4),
45-
nn.Sigmoid())
46-
29+
self.fc1 = nn.Linear(784, 256)
30+
self.fc2 = nn.Linear(256, 1)
31+
4732
def forward(self, x):
48-
out = self.model(x)
49-
out = out.view(out.size(0), -1)
33+
h = F.relu(self.fc1(x))
34+
out = F.sigmoid(self.fc2(h))
5035
return out
5136

52-
# 4x4 Transpose convolution
53-
def conv_transpose4x4(in_channels, out_channels, stride=1, padding=1, bias=False):
54-
return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4,
55-
stride=stride, padding=padding, bias=bias)
56-
5737
# Generator Model
5838
class Generator(nn.Module):
5939
def __init__(self):
6040
super(Generator, self).__init__()
61-
self.model = nn.Sequential(
62-
conv_transpose4x4(128, 64, padding=0),
63-
nn.BatchNorm2d(64),
64-
nn.ReLU(inplace=True),
65-
conv_transpose4x4(64, 32, 2),
66-
nn.BatchNorm2d(32),
67-
nn.ReLU(inplace=True),
68-
conv_transpose4x4(32, 16, 2),
69-
nn.BatchNorm2d(16),
70-
nn.ReLU(inplace=True),
71-
conv_transpose4x4(16, 3, 2, bias=True),
72-
nn.Tanh())
73-
41+
self.fc1 = nn.Linear(128, 256)
42+
self.fc2 = nn.Linear(256, 512)
43+
self.fc3 = nn.Linear(512, 784)
44+
7445
def forward(self, x):
75-
x = x.view(x.size(0), 128, 1, 1)
76-
out = self.model(x)
46+
h = F.leaky_relu(self.fc1(x))
47+
h = F.leaky_relu(self.fc2(h))
48+
out = F.tanh(self.fc3(h))
7749
return out
7850

7951
discriminator = Discriminator()
@@ -83,13 +55,14 @@ def forward(self, x):
8355

8456
# Loss and Optimizer
8557
criterion = nn.BCELoss()
86-
lr = 0.002
87-
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
88-
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
58+
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0005)
59+
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0005)
8960

9061
# Training
91-
for epoch in range(50):
62+
for epoch in range(200):
9263
for i, (images, _) in enumerate(train_loader):
64+
# Build mini-batch dataset
65+
images = images.view(images.size(0), -1)
9366
images = Variable(images.cuda())
9467
real_labels = Variable(torch.ones(images.size(0))).cuda()
9568
fake_labels = Variable(torch.zeros(images.size(0))).cuda()
@@ -119,15 +92,16 @@ def forward(self, x):
11992
g_loss.backward()
12093
g_optimizer.step()
12194

122-
if (i+1) % 100 == 0:
95+
if (i+1) % 300 == 0:
12396
print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
12497
'D(x): %.2f, D(G(z)): %.2f'
125-
%(epoch, 50, i+1, 500, d_loss.data[0], g_loss.data[0],
98+
%(epoch, 200, i+1, 600, d_loss.data[0], g_loss.data[0],
12699
real_score.cpu().data.mean(), fake_score.cpu().data.mean()))
127100

128-
# Save the sampled images
129-
torchvision.utils.save_image(fake_images.data,
130-
'./data/fake_samples_%d_%d.png' %(epoch+1, i+1))
101+
# Save the sampled images
102+
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
103+
torchvision.utils.save_image(fake_images.data,
104+
'./data2/fake_samples_%d.png' %epoch+1)
131105

132106
# Save the Models
133107
torch.save(generator.state_dict(), './generator.pkl')
Lines changed: 31 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,51 @@
11
import torch
22
import torchvision
33
import torch.nn as nn
4+
import torch.nn.functional as F
45
import torchvision.datasets as dsets
56
import torchvision.transforms as transforms
67
from torch.autograd import Variable
78

89
# Image Preprocessing
910
transform = transforms.Compose([
10-
transforms.Scale(36),
11-
transforms.RandomCrop(32),
1211
transforms.ToTensor(),
1312
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
1413

15-
# CIFAR-10 Dataset
16-
train_dataset = dsets.CIFAR10(root='../data/',
17-
train=True,
18-
transform=transform,
19-
download=True)
14+
# MNIST Dataset
15+
train_dataset = dsets.MNIST(root='../data/',
16+
train=True,
17+
transform=transform,
18+
download=True)
2019

2120
# Data Loader (Input Pipeline)
2221
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
2322
batch_size=100,
2423
shuffle=True)
2524

26-
# 4x4 Convolution
27-
def conv4x4(in_channels, out_channels, stride):
28-
return nn.Conv2d(in_channels, out_channels, kernel_size=4,
29-
stride=stride, padding=1, bias=False)
30-
3125
# Discriminator Model
3226
class Discriminator(nn.Module):
3327
def __init__(self):
3428
super(Discriminator, self).__init__()
35-
self.model = nn.Sequential(
36-
conv4x4(3, 16, 2),
37-
nn.LeakyReLU(0.2, inplace=True),
38-
conv4x4(16, 32, 2),
39-
nn.BatchNorm2d(32),
40-
nn.LeakyReLU(0.2, inplace=True),
41-
conv4x4(32, 64, 2),
42-
nn.BatchNorm2d(64),
43-
nn.LeakyReLU(0.2, inplace=True),
44-
nn.Conv2d(64, 1, kernel_size=4),
45-
nn.Sigmoid())
46-
29+
self.fc1 = nn.Linear(784, 256)
30+
self.fc2 = nn.Linear(256, 1)
31+
4732
def forward(self, x):
48-
out = self.model(x)
49-
out = out.view(out.size(0), -1)
33+
h = F.relu(self.fc1(x))
34+
out = F.sigmoid(self.fc2(h))
5035
return out
5136

52-
# 4x4 Transpose convolution
53-
def conv_transpose4x4(in_channels, out_channels, stride=1, padding=1, bias=False):
54-
return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4,
55-
stride=stride, padding=padding, bias=bias)
56-
5737
# Generator Model
5838
class Generator(nn.Module):
5939
def __init__(self):
6040
super(Generator, self).__init__()
61-
self.model = nn.Sequential(
62-
conv_transpose4x4(128, 64, padding=0),
63-
nn.BatchNorm2d(64),
64-
nn.ReLU(inplace=True),
65-
conv_transpose4x4(64, 32, 2),
66-
nn.BatchNorm2d(32),
67-
nn.ReLU(inplace=True),
68-
conv_transpose4x4(32, 16, 2),
69-
nn.BatchNorm2d(16),
70-
nn.ReLU(inplace=True),
71-
conv_transpose4x4(16, 3, 2, bias=True),
72-
nn.Tanh())
73-
41+
self.fc1 = nn.Linear(128, 256)
42+
self.fc2 = nn.Linear(256, 512)
43+
self.fc3 = nn.Linear(512, 784)
44+
7445
def forward(self, x):
75-
x = x.view(x.size(0), 128, 1, 1)
76-
out = self.model(x)
46+
h = F.leaky_relu(self.fc1(x))
47+
h = F.leaky_relu(self.fc2(h))
48+
out = F.tanh(self.fc3(h))
7749
return out
7850

7951
discriminator = Discriminator()
@@ -83,13 +55,14 @@ def forward(self, x):
8355

8456
# Loss and Optimizer
8557
criterion = nn.BCELoss()
86-
lr = 0.0002
87-
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
88-
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
58+
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0005)
59+
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0005)
8960

9061
# Training
91-
for epoch in range(50):
62+
for epoch in range(200):
9263
for i, (images, _) in enumerate(train_loader):
64+
# Build mini-batch dataset
65+
images = images.view(images.size(0), -1)
9366
images = Variable(images)
9467
real_labels = Variable(torch.ones(images.size(0)))
9568
fake_labels = Variable(torch.zeros(images.size(0)))
@@ -119,16 +92,17 @@ def forward(self, x):
11992
g_loss.backward()
12093
g_optimizer.step()
12194

122-
if (i+1) % 100 == 0:
95+
if (i+1) % 300 == 0:
12396
print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
12497
'D(x): %.2f, D(G(z)): %.2f'
125-
%(epoch, 50, i+1, 500, d_loss.data[0], g_loss.data[0],
126-
real_score.data.mean(), fake_score.data.mean()))
98+
%(epoch, 200, i+1, 600, d_loss.data[0], g_loss.data[0],
99+
real_score.data.mean(), fake_score.cpu().data.mean()))
127100

128-
# Save the sampled images
129-
torchvision.utils.save_image(fake_images.data,
130-
'./data/fake_samples_%d_%d.png' %(epoch+1, i+1))
101+
# Save the sampled images
102+
fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
103+
torchvision.utils.save_image(fake_images.data,
104+
'./data2/fake_samples_%d.png' %epoch+1)
131105

132-
# Save the Models
106+
# Save the Models
133107
torch.save(generator.state_dict(), './generator.pkl')
134108
torch.save(discriminator.state_dict(), './discriminator.pkl')

0 commit comments

Comments
 (0)