
Commit 401c8a2

Add files via upload

1 parent cac37ee commit 401c8a2
File tree

2 files changed: +180 −0 lines changed

examples/VAE/img_epoch20.jpg (193 KB)
examples/VAE/vae_mnist.py (+180 −0)

@@ -0,0 +1,180 @@
"""
Variational Autoencoder for MNIST data
reference: https://jmetzen.github.io/2015-11-27/vae.html
2017/01/17
"""
import sys
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from input_data import read_data_sets

# Random seeds for reproducibility
np.random.seed(2017)
tf.set_random_seed(2017)


class VAE(object):
    """A simple variational autoencoder"""
    def __init__(self, input_dim=784, z_dim=50, batch_size=100, encoder_hidden_size=[500, 500],
                 decoder_hidden_size=[500, 500], act_fn=tf.nn.softplus):
        """
        :param input_dim: int, the dimension of the input
        :param z_dim: int, the dimension of the latent space
        :param batch_size: int, batch size
        :param encoder_hidden_size: list or tuple, the number of hidden units in each encoder layer
        :param decoder_hidden_size: list or tuple, the number of hidden units in each decoder layer
        :param act_fn: the activation function
        """
        self.input_dim = input_dim
        self.z_dim = z_dim
        self.batch_size = batch_size
        self.encoder_hidden_size = encoder_hidden_size
        self.decoder_hidden_size = decoder_hidden_size
        self.act_fn = act_fn

        self._build_model()

    def _build_model(self):
        """The inner function to build the model"""
        # Input placeholder
        self.x = tf.placeholder(tf.float32, shape=[self.batch_size, self.input_dim])
        # The encoder: determine the mean and (log) variance of the Gaussian posterior
        self.z_mean, self.z_log_sigma_sq = self._encoder(self.x)
        # Sample from the posterior with the reparameterization trick:
        # z = mean + sigma * epsilon, epsilon ~ N(0, I)
        eps = tf.random_normal([self.batch_size, self.z_dim], mean=0.0, stddev=1.0)
        self.z = tf.add(self.z_mean, tf.multiply(tf.sqrt(tf.exp(self.z_log_sigma_sq)), eps))

        # The decoder: determine the mean of the Bernoulli distribution of the reconstructed input
        self.x_reconstr_mean = self._decoder(self.z)

        # Compute the loss
        with tf.name_scope("loss"):
            # The reconstruction loss: cross entropy (the 1e-10 offset avoids log(0))
            reconstr_loss = -tf.reduce_sum(self.x * tf.log(1e-10 + self.x_reconstr_mean) +
                                           (1.0 - self.x) * tf.log(1e-10 + 1.0 - self.x_reconstr_mean), axis=1)
            # The latent loss: KL divergence between the posterior and the N(0, I) prior
            latent_loss = -0.5 * tf.reduce_sum(1.0 + self.z_log_sigma_sq - tf.square(self.z_mean) -
                                               tf.exp(self.z_log_sigma_sq), axis=1)
            # Average over the batch
            self.cost = tf.reduce_mean(reconstr_loss + latent_loss)

        # The optimizer
        self.lr = tf.Variable(0.001, trainable=False)
        train_vars = tf.trainable_variables()
        self.train_op = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(self.cost, var_list=train_vars)

    def _encoder(self, x, name="encoder"):
        """Encoder"""
        with tf.variable_scope(name):
            n_in = self.input_dim
            h = x
            for i, s in enumerate(self.encoder_hidden_size):
                w, b = self._get_vars(n_in, s, name="h{0}".format(i))
                h = self.act_fn(tf.nn.xw_plus_b(h, w, b))
                n_in = s
            # Two linear output heads: the mean and the log variance of the posterior
            w, b = self._get_vars(n_in, self.z_dim, name="out_mean")
            z_mean = tf.nn.xw_plus_b(h, w, b)
            w, b = self._get_vars(n_in, self.z_dim, name="out_log_sigma")
            z_log_sigma_sq = tf.nn.xw_plus_b(h, w, b)
            return z_mean, z_log_sigma_sq

    def _decoder(self, z, name="decoder"):
        """Decoder"""
        with tf.variable_scope(name):
            n_in = self.z_dim
            h = z
            for i, s in enumerate(self.decoder_hidden_size):
                w, b = self._get_vars(n_in, s, name="h{0}".format(i))
                h = self.act_fn(tf.nn.xw_plus_b(h, w, b))
                n_in = s
            # Use a sigmoid output for the mean of the Bernoulli distribution
            w, b = self._get_vars(n_in, self.input_dim, name="out_mean")
            x_reconstr_mean = tf.nn.sigmoid(tf.nn.xw_plus_b(h, w, b))
            return x_reconstr_mean

    def _get_vars(self, n_in, n_out, name=""):
        """Create weight and bias variables"""
        with tf.variable_scope(name):
            w = tf.get_variable("w", [n_in, n_out], initializer=tf.contrib.layers.xavier_initializer())
            b = tf.get_variable("b", [n_out], initializer=tf.constant_initializer(0.1))
            return w, b


if __name__ == "__main__":
    n_epochs = 30
    lr = 0.001
    batch_size = 100
    display_every = 1

    path = sys.path[0]
    mnist = read_data_sets("MNIST_data/", one_hot=True)
    with tf.Session() as sess:
        vae = VAE(input_dim=784, z_dim=2, batch_size=batch_size, encoder_hidden_size=[500, 500],
                  decoder_hidden_size=[500, 500], act_fn=tf.nn.softplus)
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        # saver.restore(sess, save_path=path+"/model/model.ckpt")
        # Start training
        print("Start training...")
        total_batch = int(mnist.train.num_examples/batch_size)
        for epoch in range(n_epochs):
            avg_cost = 0.0
            # For each batch
            for i in range(total_batch):
                batch_xs, _ = mnist.train.next_batch(batch_size)
                c, _ = sess.run([vae.cost, vae.train_op], feed_dict={vae.x: batch_xs})
                avg_cost += c/total_batch
            if epoch % display_every == 0:
                # Assumes a "model/" directory exists next to the script
                save_path = saver.save(sess, path+"/model/model.ckpt")
                # print("\tModel saved in file: {0}".format(save_path))
                print("\tEpoch {0}, cost {1}".format(epoch, avg_cost))

        # Reconstruction: compare test inputs with their reconstructions
        x_sample, _ = mnist.test.next_batch(batch_size)
        x_reconstr = sess.run(vae.x_reconstr_mean, feed_dict={vae.x: x_sample})
        plt.figure(figsize=(8, 12))
        for i in range(5):
            plt.subplot(5, 2, 2*i + 1)
            plt.imshow(np.reshape(x_sample[i], (28, 28)), vmin=0, vmax=1, cmap="gray")
            plt.title("Test input")
            plt.colorbar()
            plt.subplot(5, 2, 2*i + 2)
            plt.imshow(np.reshape(x_reconstr[i], (28, 28)), vmin=0, vmax=1, cmap="gray")
            plt.title("Reconstruction")
            plt.colorbar()
        plt.tight_layout()
        # Assumes a "results/" directory exists next to the script
        plt.savefig(path+"/results/img_epoch{0}.jpg".format(n_epochs))
        plt.show()

        # Generate digits by decoding a grid of points in the 2-D latent space
        nx, ny = 20, 20
        xs = np.linspace(-3, 3, nx)
        ys = np.linspace(-3, 3, ny)
        xs, ys = np.meshgrid(xs, ys)
        xs = np.reshape(xs, [-1, 1])
        ys = np.reshape(ys, [-1, 1])
        zs = np.concatenate((xs, ys), axis=1)

        canvas = np.zeros((28*ny, 28*nx))
        # The 400 grid points are decoded in 4 batches of batch_size=100,
        # i.e. 5 grid rows of nx digits per batch
        for i in range(4):
            z_mu = zs[batch_size*i:batch_size*(i+1), :]
            x_mean = sess.run(vae.x_reconstr_mean, feed_dict={vae.z: z_mu})
            # Tile the batch into 5 rows of images; reverse the row order
            # so that the y coordinate increases from bottom to top
            block = x_mean.reshape(5, nx, 28, 28).transpose(0, 2, 1, 3)[::-1].reshape(-1, 28*nx)
            canvas[(ny-(i+1)*5)*28:(ny-i*5)*28] = block

        plt.figure(figsize=(8, 10))
        plt.imshow(canvas, origin="upper", vmin=0, vmax=1, interpolation="none", cmap="gray")
        plt.tight_layout()
        plt.savefig(path+"/results/rand_img_epoch{0}.jpg".format(n_epochs))
        plt.show()
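
For reference, the self.cost tensor above is the negative evidence lower bound (ELBO), averaged over the batch: a Bernoulli cross-entropy reconstruction term plus the closed-form KL divergence between the diagonal Gaussian posterior and the standard normal prior. In LaTeX:

\mathcal{L}(x) = -\sum_{j=1}^{784} \left[ x_j \log \hat{x}_j + (1 - x_j) \log(1 - \hat{x}_j) \right]
                 - \frac{1}{2} \sum_{k=1}^{d_z} \left( 1 + \log \sigma_k^2 - \mu_k^2 - \sigma_k^2 \right)

where \hat{x} is x_reconstr_mean, (\mu, \log \sigma^2) are z_mean and z_log_sigma_sq, and the sample fed to the decoder is z = \mu + \sigma \odot \varepsilon with \varepsilon \sim \mathcal{N}(0, I) (the reparameterization trick implemented in _build_model).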
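A note on running the script: it expects the MNIST input_data helper (presumably the one from the TensorFlow tutorials) next to it, downloads the data into MNIST_data/ on first use, and writes checkpoints and figures to model/ and results/, which must exist beforehand. The output filenames are formatted with n_epochs, so the uploaded img_epoch20.jpg appears to come from a run with n_epochs = 20 rather than the 30 set here.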

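As a complement to the decoder-side latent grid in the script, the sketch below shows the encoder side: restoring the saved checkpoint and projecting test digits into the 2-D latent space. This is a minimal sketch, not part of the commit: the VAE class, input_data helper, and checkpoint path come from the script above, while the scatter-plot details are illustrative assumptions.

import sys
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from input_data import read_data_sets
from vae_mnist import VAE

batch_size = 100
path = sys.path[0]
mnist = read_data_sets("MNIST_data/", one_hot=True)

with tf.Session() as sess:
    # Rebuild the same graph as in training (z_dim=2, default hidden sizes and activation)
    vae = VAE(input_dim=784, z_dim=2, batch_size=batch_size)
    saver = tf.train.Saver()
    # Assumes the checkpoint written by vae_mnist.py exists
    saver.restore(sess, save_path=path + "/model/model.ckpt")

    # Encode ten test batches to their posterior means
    points, labels = [], []
    for _ in range(10):
        batch_xs, batch_ys = mnist.test.next_batch(batch_size)
        z_mu = sess.run(vae.z_mean, feed_dict={vae.x: batch_xs})
        points.append(z_mu)
        labels.append(np.argmax(batch_ys, axis=1))
    points = np.concatenate(points)
    labels = np.concatenate(labels)

    # Scatter the 2-D latent means, colored by digit class
    plt.figure(figsize=(8, 6))
    plt.scatter(points[:, 0], points[:, 1], c=labels, cmap="jet", s=5)
    plt.colorbar()
    plt.show()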
0 commit comments