Commit f3868f4

name changed: gan to dcgan

1 parent 8c4dd99

2 files changed: +268 -0 lines changed

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
import torch
import torchvision
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable

# Image Preprocessing
transform = transforms.Compose([
    transforms.Scale(36),
    transforms.RandomCrop(32),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])

# CIFAR-10 Dataset
train_dataset = dsets.CIFAR10(root='../data/',
                              train=True,
                              transform=transform,
                              download=True)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=100,
                                           shuffle=True)

# 4x4 Convolution
def conv4x4(in_channels, out_channels, stride):
    return nn.Conv2d(in_channels, out_channels, kernel_size=4,
                     stride=stride, padding=1, bias=False)

# Discriminator Model
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            conv4x4(3, 16, 2),
            nn.LeakyReLU(0.2, inplace=True),
            conv4x4(16, 32, 2),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            conv4x4(32, 64, 2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 1, kernel_size=4),
            nn.Sigmoid())

    def forward(self, x):
        out = self.model(x)
        out = out.view(out.size(0), -1)
        return out

# 4x4 Transpose convolution
def conv_transpose4x4(in_channels, out_channels, stride=1, padding=1, bias=False):
    return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4,
                              stride=stride, padding=padding, bias=bias)

# Generator Model
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            conv_transpose4x4(128, 64, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            conv_transpose4x4(64, 32, 2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            conv_transpose4x4(32, 16, 2),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            conv_transpose4x4(16, 3, 2, bias=True),
            nn.Tanh())

    def forward(self, x):
        x = x.view(x.size(0), 128, 1, 1)
        out = self.model(x)
        return out

discriminator = Discriminator()
generator = Generator()
discriminator.cuda()
generator.cuda()

# Loss and Optimizer
criterion = nn.BCELoss()
lr = 0.002
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)

# Training
for epoch in range(50):
    for i, (images, _) in enumerate(train_loader):
        images = Variable(images.cuda())
        real_labels = Variable(torch.ones(images.size(0))).cuda()
        fake_labels = Variable(torch.zeros(images.size(0))).cuda()

        # Train the discriminator
        discriminator.zero_grad()
        outputs = discriminator(images)
        real_loss = criterion(outputs, real_labels)
        real_score = outputs

        noise = Variable(torch.randn(images.size(0), 128)).cuda()
        fake_images = generator(noise)
        outputs = discriminator(fake_images)
        fake_loss = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = real_loss + fake_loss
        d_loss.backward()
        d_optimizer.step()

        # Train the generator
        generator.zero_grad()
        noise = Variable(torch.randn(images.size(0), 128)).cuda()
        fake_images = generator(noise)
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        g_optimizer.step()

        if (i+1) % 100 == 0:
            print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
                  'D(x): %.2f, D(G(z)): %.2f'
                  %(epoch, 50, i+1, 500, d_loss.data[0], g_loss.data[0],
                    real_score.cpu().data.mean(), fake_score.cpu().data.mean()))

    # Save the sampled images
    torchvision.utils.save_image(fake_images.data,
                                 './data/fake_samples_%d_%d.png' %(epoch+1, i+1))

# Save the Models
torch.save(generator.state_dict(), './generator.pkl')
torch.save(discriminator.state_dict(), './discriminator.pkl')
Lines changed: 134 additions & 0 deletions

(The second file is the same training script, but without the .cuda() calls and with lr = 0.0002 instead of 0.002.)

@@ -0,0 +1,134 @@
import torch
import torchvision
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable

# Image Preprocessing
transform = transforms.Compose([
    transforms.Scale(36),
    transforms.RandomCrop(32),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])

# CIFAR-10 Dataset
train_dataset = dsets.CIFAR10(root='../data/',
                              train=True,
                              transform=transform,
                              download=True)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=100,
                                           shuffle=True)

# 4x4 Convolution
def conv4x4(in_channels, out_channels, stride):
    return nn.Conv2d(in_channels, out_channels, kernel_size=4,
                     stride=stride, padding=1, bias=False)

# Discriminator Model
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            conv4x4(3, 16, 2),
            nn.LeakyReLU(0.2, inplace=True),
            conv4x4(16, 32, 2),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            conv4x4(32, 64, 2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 1, kernel_size=4),
            nn.Sigmoid())

    def forward(self, x):
        out = self.model(x)
        out = out.view(out.size(0), -1)
        return out

# 4x4 Transpose convolution
def conv_transpose4x4(in_channels, out_channels, stride=1, padding=1, bias=False):
    return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4,
                              stride=stride, padding=padding, bias=bias)

# Generator Model
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            conv_transpose4x4(128, 64, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            conv_transpose4x4(64, 32, 2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            conv_transpose4x4(32, 16, 2),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            conv_transpose4x4(16, 3, 2, bias=True),
            nn.Tanh())

    def forward(self, x):
        x = x.view(x.size(0), 128, 1, 1)
        out = self.model(x)
        return out

discriminator = Discriminator()
generator = Generator()

# Loss and Optimizer
criterion = nn.BCELoss()
lr = 0.0002
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)

# Training
for epoch in range(50):
    for i, (images, _) in enumerate(train_loader):
        images = Variable(images)
        real_labels = Variable(torch.ones(images.size(0)))
        fake_labels = Variable(torch.zeros(images.size(0)))

        # Train the discriminator
        discriminator.zero_grad()
        outputs = discriminator(images)
        real_loss = criterion(outputs, real_labels)
        real_score = outputs

        noise = Variable(torch.randn(images.size(0), 128))
        fake_images = generator(noise)
        outputs = discriminator(fake_images)
        fake_loss = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = real_loss + fake_loss
        d_loss.backward()
        d_optimizer.step()

        # Train the generator
        generator.zero_grad()
        noise = Variable(torch.randn(images.size(0), 128))
        fake_images = generator(noise)
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        g_optimizer.step()

        if (i+1) % 100 == 0:
            print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '
                  'D(x): %.2f, D(G(z)): %.2f'
                  %(epoch, 50, i+1, 500, d_loss.data[0], g_loss.data[0],
                    real_score.data.mean(), fake_score.data.mean()))

    # Save the sampled images
    torchvision.utils.save_image(fake_images.data,
                                 './data/fake_samples_%d_%d.png' %(epoch+1, i+1))

# Save the Models
torch.save(generator.state_dict(), './generator.pkl')
torch.save(discriminator.state_dict(), './discriminator.pkl')
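
Usage note (not part of the commit): a minimal sketch of how the saved generator.pkl could be reloaded to sample images after training, assuming the Generator class above is importable and the same Variable-era PyTorch API as in the scripts; the output filename generated_samples.png is only illustrative. Adding 1 and dividing by 2 undoes the Normalize(mean=0.5, std=0.5) preprocessing, since the generator's Tanh output lies in [-1, 1].

import torch
import torchvision
from torch.autograd import Variable

# Rebuild the generator architecture and load the trained weights.
generator = Generator()
generator.load_state_dict(torch.load('./generator.pkl'))
generator.eval()  # use BatchNorm running statistics at inference time

# Draw 64 latent vectors; forward() reshapes them to (N, 128, 1, 1) internally.
noise = Variable(torch.randn(64, 128))
fake_images = generator(noise)

# Tanh output is in [-1, 1]; map back to [0, 1] before saving.
fake_images = (fake_images.data + 1) / 2
torchvision.utils.save_image(fake_images, './generated_samples.png')  # illustrative path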
