1+ """
2+ 2017/01/09
3+ """
4+ import sys
5+ import numpy as np
6+ import tensorflow as tf
7+ from keras .datasets import mnist
8+ from PIL import Image
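# Note: this script targets TensorFlow 1.x (tf.contrib, tf.placeholder,
# tf.Session); it will not run on TensorFlow 2.x without tf.compat.v1.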

# Batch normalization
def batch_norm(inpt, epsilon=1e-05, decay=0.9, is_training=True, name="batch_norm"):
    """
    Implements batch normalization.
    The input is a 4-D tensor.
    """
    bn = tf.contrib.layers.batch_norm(inpt, decay=decay, updates_collections=None,
                                      epsilon=epsilon, scale=True, is_training=is_training, scope=name)
    return bn
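# updates_collections=None makes tf.contrib.layers.batch_norm apply the moving
# mean/variance updates in place (as control dependencies), so the caller does
# not have to run the UPDATE_OPS collection separately.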

# Convolution 2-D
def conv2d(inpt, nb_filter, filter_size=5, strides=2, bias=True, stddev=0.02, padding="SAME",
           name="conv2d"):
    in_channels = inpt.get_shape().as_list()[-1]
    with tf.variable_scope(name):
        w = tf.get_variable("w", shape=[filter_size, filter_size, in_channels, nb_filter],
                            initializer=tf.truncated_normal_initializer(mean=0.0, stddev=stddev))
        conv = tf.nn.conv2d(inpt, w, strides=[1, strides, strides, 1], padding=padding)
        if bias:
            b = tf.get_variable("b", shape=[nb_filter,], initializer=tf.constant_initializer(0.0))
            conv = tf.nn.bias_add(conv, b)
        return conv
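# With padding="SAME" and strides=2, tf.nn.conv2d produces an output of spatial
# size ceil(input_size / 2), e.g. 28 -> 14 -> 7 -> 4 for MNIST.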

# Convolution 2-D transpose
def deconv2d(inpt, output_shape, filter_size=5, strides=2, bias=True, stddev=0.02,
             padding="SAME", name="deconv2d"):
    in_channels = inpt.get_shape().as_list()[-1]
    with tf.variable_scope(name):
        # Note: the filter has shape [height, width, output_channels, in_channels]
        w = tf.get_variable("w", shape=[filter_size, filter_size, output_shape[-1], in_channels],
                            initializer=tf.truncated_normal_initializer(mean=0.0, stddev=stddev))
        deconv = tf.nn.conv2d_transpose(inpt, w, output_shape=output_shape, strides=[1, strides, strides, 1],
                                        padding=padding)
        if bias:
            b = tf.get_variable("b", shape=[output_shape[-1]], initializer=tf.constant_initializer(0.0))
            deconv = tf.nn.bias_add(deconv, b)
        return deconv

# Leaky ReLU
def lrelu(x, leak=0.2, name="lrelu"):
    return tf.maximum(x, x * leak)
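# For 0 < leak < 1, max(x, leak * x) equals x for x >= 0 and leak * x otherwise,
# e.g. lrelu(-1.0) == -0.2 with the default leak of 0.2.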

# Linear
def linear(x, output_dim, stddev=0.02, name="linear"):
    input_dim = x.get_shape().as_list()[-1]
    with tf.variable_scope(name):
        w = tf.get_variable("w", shape=[input_dim, output_dim], initializer=tf.random_normal_initializer(stddev=stddev))
        b = tf.get_variable("b", shape=[output_dim,], initializer=tf.constant_initializer(0.0))
        return tf.nn.xw_plus_b(x, w, b)

class DCGAN(object):
    """A class of DCGAN model"""
    def __init__(self, z_dim=100, output_dim=28, batch_size=100, c_dim=1, df_dim=64, gf_dim=64, gfc_dim=1024,
                 dfc_dim=1024, n_conv=2, n_deconv=2):
        """
        :param z_dim: int, dimension of the noise input z of the generator
        :param output_dim: int, height/width of the (square) output image
        :param batch_size: int, number of samples per batch
        :param c_dim: int, number of image channels (1 for grayscale MNIST)
        :param df_dim: int, number of filters in the first conv layer of the discriminator
        :param gf_dim: int, number of filters in the last deconv layer of the generator
        :param gfc_dim: int, size of the fully connected layer of the generator (currently unused)
        :param dfc_dim: int, size of the fully connected layer of the discriminator
        :param n_conv: int, number of conv layers after the first one in the discriminator
        :param n_deconv: int, total number of deconv layers in the generator
        """
        self.z_dim = z_dim
        self.output_dim = output_dim
        self.c_dim = c_dim
        self.df_dim = df_dim
        self.gf_dim = gf_dim
        self.gfc_dim = gfc_dim
        self.dfc_dim = dfc_dim
        self.n_conv = n_conv
        self.n_deconv = n_deconv
        self.batch_size = batch_size

        self._build_model()

    def _build_model(self):
        # Inputs: the noise vector z and the real images x
        self.z = tf.placeholder(tf.float32, shape=[self.batch_size, self.z_dim])
        self.x = tf.placeholder(tf.float32, shape=[self.batch_size, self.output_dim,
                                                   self.output_dim, self.c_dim])

        # Generator
        self.G = self._generator(self.z)
        # Discriminator, applied to real images and to generated images
        self.D1, d1_logits = self._discriminator(self.x, reuse=False)
        self.D2, d2_logits = self._discriminator(self.G, reuse=True)

        self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d2_logits, labels=tf.ones_like(self.D2)))
        real_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=d1_logits, labels=tf.ones_like(self.D1))
        fake_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=d2_logits, labels=tf.zeros_like(self.D2))
        self.d_loss = tf.reduce_mean(real_loss + fake_loss)
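        # These are the standard (non-saturating) GAN losses: the discriminator
        # learns to assign 1 to real and 0 to fake images, while the generator
        # maximizes log D(G(z)) instead of minimizing log(1 - D(G(z))).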

        # Split the trainable variables by scope name ("D/..." vs "G/...")
        t_vars = tf.trainable_variables()
        self.d_vars = [v for v in t_vars if "D" in v.name]
        self.g_vars = [v for v in t_vars if "G" in v.name]

    def _discriminator(self, x, reuse=False):
        with tf.variable_scope("D", reuse=reuse):
            h = lrelu(conv2d(x, nb_filter=self.df_dim, name="d_conv0"))
            for i in range(1, self.n_conv + 1):
                conv = conv2d(h, nb_filter=self.df_dim * (2 ** i), name="d_conv{0}".format(i))
                h = lrelu(batch_norm(conv, name="d_bn{0}".format(i)))
            h = linear(tf.reshape(h, shape=[self.batch_size, -1]), self.dfc_dim, name="d_lin0")
            h = linear(tf.nn.tanh(h), 1, name="d_lin1")
            # Return both the probability and the raw logits; the losses above
            # take the logits for numerical stability.
            return tf.nn.sigmoid(h), h

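    # The generator follows the usual DCGAN layout: project z with a linear
    # layer, reshape to a small feature map, then repeatedly upsample with
    # strided transposed convolutions until the target resolution is reached.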
    def _generator(self, z):
        with tf.variable_scope("G"):
            # Work out the filter counts and spatial sizes of each layer, from
            # the output resolution backwards (e.g. 28 -> 14 -> 7 for n_deconv=2)
            nb_filters = [self.gf_dim]
            s = [self.output_dim]
            for i in range(1, self.n_deconv):
                nb_filters.append(nb_filters[-1] * 2)
                s.append(s[-1] // 2)
            s.append(s[-1] // 2)
            h = linear(z, nb_filters[-1] * s[-1] * s[-1], name="g_lin0")
            h = tf.nn.relu(batch_norm(tf.reshape(h, shape=[-1, s[-1], s[-1], nb_filters[-1]]), name="g_bn0"))
            for i in range(1, self.n_deconv):
                h = deconv2d(h, [self.batch_size, s[-i - 1], s[-i - 1], nb_filters[-i - 1]],
                             name="g_deconv{0}".format(i - 1))
                h = tf.nn.relu(batch_norm(h, name="g_bn{0}".format(i)))

            h = deconv2d(h, [self.batch_size, s[0], s[0], self.c_dim], name="g_deconv{0}".format(self.n_deconv - 1))
            # tanh keeps the output in [-1, 1], the same range the training
            # images are scaled to below
            return tf.nn.tanh(h)

def combine_images(images):
    num = images.shape[0]
    width = int(np.sqrt(num))
    height = int(np.ceil(float(num) / width))
    h, w = images.shape[1:-1]
    img = np.zeros((height * h, width * w), dtype=images.dtype)
    for index, m in enumerate(images):
        i = index // width
        j = index % width
        img[i * h:(i + 1) * h, j * w:(j + 1) * w] = m[:, :, 0]
    return img
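# For example, with 128 generated 28x28 images the grid is 11 images wide and
# 12 images tall (the last row is partially empty), giving a 336x308 array.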


if __name__ == "__main__":
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    X_train = (np.asarray(X_train, dtype=np.float32) - 127.5) / 127.5
    X_train = np.reshape(X_train, [-1, 28, 28, 1])
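    # Pixels are now in [-1, 1], matching the tanh output range of the generator.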

    z_dim = 100
    batch_size = 128
    lr = 0.0002
    n_epochs = 10

    sess = tf.Session()
    dcgan = DCGAN(z_dim=z_dim, output_dim=28, batch_size=batch_size, c_dim=1)

    d_train_op = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(dcgan.d_loss,
                                                                var_list=dcgan.d_vars)
    g_train_op = tf.train.AdamOptimizer(lr, beta1=0.5).minimize(dcgan.g_loss,
                                                                var_list=dcgan.g_vars)
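    # lr=0.0002 and beta1=0.5 are the Adam settings recommended in the DCGAN
    # paper (Radford et al., 2016); the two optimizers update only the
    # discriminator or only the generator variables, respectively.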
    sess.run(tf.global_variables_initializer())

    num_batches = len(X_train) // batch_size

    for epoch in range(n_epochs):
        print("Epoch", epoch)
        d_losses = 0
        g_losses = 0
        for idx in range(num_batches):
            # Train D on one batch of real images and one batch of noise
            z = np.random.uniform(-1, 1, size=[batch_size, z_dim])
            x = X_train[idx * batch_size:(idx + 1) * batch_size]
            _, d_loss = sess.run([d_train_op, dcgan.d_loss], feed_dict={dcgan.z: z,
                                                                        dcgan.x: x})
            d_losses += d_loss / num_batches
            # Train G on a fresh batch of noise
            z = np.random.uniform(-1, 1, size=[batch_size, z_dim])
            _, g_loss = sess.run([g_train_op, dcgan.g_loss], feed_dict={dcgan.z: z})
            g_losses += g_loss / num_batches

        print("\t d_loss {0}, g_loss {1}".format(d_losses, g_losses))
        # Generate sample images, rescale from [-1, 1] back to [0, 255], and save
        z = np.random.uniform(-1, 1, size=[batch_size, z_dim])
        images = sess.run(dcgan.G, feed_dict={dcgan.z: z})
        img = combine_images(images)
        img = img * 127.5 + 127.5
        Image.fromarray(img.astype(np.uint8)).save("epoch{0}_g_images.png".format(epoch))