Commit bf919f6

Add files via upload
1 parent 6c014e4 commit bf919f6

File tree

1 file changed: +200, -0 lines changed


examples/gan/DCGAN.py

Lines changed: 200 additions & 0 deletions
@@ -0,0 +1,200 @@
"""
A DCGAN (deep convolutional GAN) for MNIST, implemented with TensorFlow.
2017/01/09
"""
import numpy as np
import tensorflow as tf
from keras.datasets import mnist
from PIL import Image

# Batch normalization
def batch_norm(inpt, epsilon=1e-05, decay=0.9, is_training=True, name="batch_norm"):
    """
    Implements batch normalization.
    The input is a 4-D tensor.
    """
    bn = tf.contrib.layers.batch_norm(inpt, decay=decay, updates_collections=None,
                                      epsilon=epsilon, scale=True, is_training=is_training,
                                      scope=name)
    return bn

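# Note: updates_collections=None makes tf.contrib.layers.batch_norm apply the
# moving-average updates in place at train time instead of collecting them in
# tf.GraphKeys.UPDATE_OPS; to sample from a trained model one would build the
# graph with is_training=False (this script trains and samples with the
# default is_training=True).
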
# Convolution 2-D
def conv2d(inpt, nb_filter, filter_size=5, strides=2, bias=True, stddev=0.02, padding="SAME",
           name="conv2d"):
    in_channels = inpt.get_shape().as_list()[-1]
    with tf.variable_scope(name):
        w = tf.get_variable("w", shape=[filter_size, filter_size, in_channels, nb_filter],
                            initializer=tf.truncated_normal_initializer(mean=0.0, stddev=stddev))
        conv = tf.nn.conv2d(inpt, w, strides=[1, strides, strides, 1], padding=padding)
        if bias:
            b = tf.get_variable("b", shape=[nb_filter], initializer=tf.constant_initializer(0.0))
            conv = tf.nn.bias_add(conv, b)
    return conv

# Convolution 2-D transpose
def deconv2d(inpt, output_shape, filter_size=5, strides=2, bias=True, stddev=0.02,
             padding="SAME", name="deconv2d"):
    in_channels = inpt.get_shape().as_list()[-1]
    with tf.variable_scope(name):
        # Note: the filter has shape [height, width, output_channels, in_channels]
        w = tf.get_variable("w", shape=[filter_size, filter_size, output_shape[-1], in_channels],
                            initializer=tf.truncated_normal_initializer(mean=0.0, stddev=stddev))
        deconv = tf.nn.conv2d_transpose(inpt, w, output_shape=output_shape,
                                        strides=[1, strides, strides, 1], padding=padding)
        if bias:
            b = tf.get_variable("b", shape=[output_shape[-1]], initializer=tf.constant_initializer(0.0))
            deconv = tf.nn.bias_add(deconv, b)
    return deconv

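# With padding="SAME" and stride s, conv2d maps a spatial size n to ceil(n/s),
# and conv2d_transpose inverts that mapping. A quick shape check (illustrative
# sketch only, not used by the model):
#
#   x = tf.zeros([1, 28, 28, 1])
#   conv2d(x, nb_filter=64).get_shape()       # -> (1, 14, 14, 64)
#   y = tf.zeros([1, 14, 14, 64])
#   deconv2d(y, [1, 28, 28, 1]).get_shape()   # -> (1, 28, 28, 1)
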
# Leaky ReLU
def lrelu(x, leak=0.2, name="lrelu"):
    return tf.maximum(x, x * leak)

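# For 0 < leak < 1, max(x, leak*x) equals x when x >= 0 and leak*x otherwise,
# which is exactly the usual leaky-ReLU definition.
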
# Linear (fully-connected) layer
def linear(x, output_dim, stddev=0.02, name="linear"):
    input_dim = x.get_shape().as_list()[-1]
    with tf.variable_scope(name):
        w = tf.get_variable("w", shape=[input_dim, output_dim],
                            initializer=tf.random_normal_initializer(stddev=stddev))
        b = tf.get_variable("b", shape=[output_dim], initializer=tf.constant_initializer(0.0))
        return tf.nn.xw_plus_b(x, w, b)


class DCGAN(object):
    """A DCGAN model."""

    def __init__(self, z_dim=100, output_dim=28, batch_size=100, c_dim=1, df_dim=64, gf_dim=64,
                 gfc_dim=1024, dfc_dim=1024, n_conv=2, n_deconv=2):
        """
        :param z_dim: int, the dimension of z (the noise input of the generator)
        :param output_dim: int, the height/width of the generated images
        :param batch_size: int, the batch size
        :param c_dim: int, the number of image channels
        :param df_dim: int, the number of filters of the first discriminator conv layer
        :param gf_dim: int, the number of filters of the last hidden generator deconv layer
        :param gfc_dim: int, unused here
        :param dfc_dim: int, the dimension of the discriminator's fully-connected layer
        :param n_conv: int, the number of batch-normed conv layers after the first one
        :param n_deconv: int, the number of deconv layers in the generator
        """
        self.z_dim = z_dim
        self.output_dim = output_dim
        self.c_dim = c_dim
        self.df_dim = df_dim
        self.gf_dim = gf_dim
        self.dfc_dim = dfc_dim
        self.n_conv = n_conv
        self.n_deconv = n_deconv
        self.batch_size = batch_size

        self._build_model()

    def _build_model(self):
        # Inputs: noise for G and real images for D
        self.z = tf.placeholder(tf.float32, shape=[self.batch_size, self.z_dim])
        self.x = tf.placeholder(tf.float32, shape=[self.batch_size, self.output_dim,
                                                   self.output_dim, self.c_dim])

        # G
        self.G = self._generator(self.z)
        # D, applied to real images and to generated ones with shared weights
        self.D1, d1_logits = self._discriminator(self.x, reuse=False)
        self.D2, d2_logits = self._discriminator(self.G, reuse=True)

        self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d2_logits, labels=tf.ones_like(self.D2)))
        real_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=d1_logits,
                                                            labels=tf.ones_like(self.D1))
        fake_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=d2_logits,
                                                            labels=tf.zeros_like(self.D2))
        self.d_loss = tf.reduce_mean(real_loss + fake_loss)

        # Variables live under the "D/..." and "G/..." scopes
        t_vars = tf.trainable_variables()
        self.d_vars = [v for v in t_vars if "D" in v.name]
        self.g_vars = [v for v in t_vars if "G" in v.name]

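    # The discriminator is pushed towards D(x) = 1 on real images and
    # D(G(z)) = 0 on samples, while the generator uses the non-saturating loss:
    # minimizing cross-entropy against all-ones labels on d2_logits amounts to
    # maximizing log D(G(z)) rather than minimizing log(1 - D(G(z))).
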
    def _discriminator(self, inpt, reuse=False):
        with tf.variable_scope("D", reuse=reuse):
            h = lrelu(conv2d(inpt, nb_filter=self.df_dim, name="d_conv0"))
            for i in range(1, self.n_conv + 1):
                conv = conv2d(h, nb_filter=self.df_dim * (2 ** i), name="d_conv{0}".format(i))
                h = lrelu(batch_norm(conv, name="d_bn{0}".format(i)))
            h = linear(tf.reshape(h, shape=[self.batch_size, -1]), self.dfc_dim, name="d_lin0")
            h = linear(tf.nn.tanh(h), 1, name="d_lin1")
            return tf.nn.sigmoid(h), h

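    # With the defaults (n_conv=2, df_dim=64) the discriminator pipeline is
    # 28x28x1 -> 14x14x64 -> 7x7x128 -> 4x4x256 -> fc 1024 -> fc 1, returning
    # both the sigmoid probability and the raw logit.
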
    def _generator(self, inpt):
        with tf.variable_scope("G"):
            nb_filters = [self.gf_dim]
            s = [self.output_dim]
            for i in range(1, self.n_deconv):
                nb_filters.append(nb_filters[-1] * 2)
                s.append(s[-1] // 2)
            s.append(s[-1] // 2)
            h = linear(inpt, nb_filters[-1] * s[-1] * s[-1], name="g_lin0")
            h = tf.nn.relu(batch_norm(tf.reshape(h, shape=[-1, s[-1], s[-1], nb_filters[-1]]),
                                      name="g_bn0"))
            for i in range(1, self.n_deconv):
                h = deconv2d(h, [self.batch_size, s[-i-1], s[-i-1], nb_filters[-i-1]],
                             name="g_deconv{0}".format(i-1))
                h = tf.nn.relu(batch_norm(h, name="g_bn{0}".format(i)))

            h = deconv2d(h, [self.batch_size, s[0], s[0], self.c_dim],
                         name="g_deconv{0}".format(self.n_deconv - 1))
            return tf.nn.tanh(h)

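    # With the defaults (n_deconv=2, gf_dim=64, output_dim=28) the generator is
    # z (100) -> fc 7*7*128 -> 14x14x64 -> 28x28x1, and tanh squashes the
    # output into [-1, 1] to match the rescaled MNIST images below.
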
def combine_images(images):
    """Tiles a batch of images with shape [num, h, w, 1] into a single grid."""
    num = images.shape[0]
    width = int(np.sqrt(num))
    height = int(np.ceil(float(num) / width))
    h, w = images.shape[1:-1]
    img = np.zeros((height * h, width * w), dtype=images.dtype)
    for index, m in enumerate(images):
        i = index // width
        j = index % width
        img[i*h:(i+1)*h, j*w:(j+1)*w] = m[:, :, 0]
    return img

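# For a batch of 128 images this yields an 11-wide, 12-tall grid (the last
# row is only partially filled).
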

if __name__ == "__main__":
    # Rescale MNIST from [0, 255] to [-1, 1] to match the generator's tanh output
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    X_train = (np.asarray(X_train, dtype=np.float32) - 127.5) / 127.5
    X_train = np.reshape(X_train, [-1, 28, 28, 1])

    z_dim = 100
    batch_size = 128
    lr = 0.0002
    n_epochs = 10

    sess = tf.Session()
    dcgan = DCGAN(z_dim=z_dim, output_dim=28, batch_size=batch_size, c_dim=1)

    # Each optimizer only updates its own network's variables
    d_train_op = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(dcgan.d_loss,
                                                                var_list=dcgan.d_vars)
    g_train_op = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(dcgan.g_loss,
                                                                var_list=dcgan.g_vars)
    sess.run(tf.global_variables_initializer())

    num_batches = int(len(X_train) / batch_size)

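    # Standard alternating GAN training: one D step on a real batch plus a
    # fake batch, then one G step on fresh noise, per iteration.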
    for epoch in range(n_epochs):
        print("Epoch", epoch)
        d_losses = 0
        g_losses = 0
        for idx in range(num_batches):
            # Train D on a real batch and a generated batch
            z = np.random.uniform(-1, 1, size=[batch_size, z_dim])
            x = X_train[idx*batch_size:(idx+1)*batch_size]
            _, d_loss = sess.run([d_train_op, dcgan.d_loss],
                                 feed_dict={dcgan.z: z, dcgan.x: x})
            d_losses += d_loss / num_batches
            # Train G on fresh noise
            z = np.random.uniform(-1, 1, size=[batch_size, z_dim])
            _, g_loss = sess.run([g_train_op, dcgan.g_loss], feed_dict={dcgan.z: z})
            g_losses += g_loss / num_batches

        print("\td_loss {0}, g_loss {1}".format(d_losses, g_losses))
        # Generate and save a grid of sample images after each epoch
        z = np.random.uniform(-1, 1, size=[batch_size, z_dim])
        images = sess.run(dcgan.G, feed_dict={dcgan.z: z})
        img = combine_images(images)
        img = img * 127.5 + 127.5
        Image.fromarray(img.astype(np.uint8)).save("epoch{0}_g_images.png".format(epoch))
