"""
Variational Autoencoder for MNIST data
reference: https://jmetzen.github.io/2015-11-27/vae.html
2017/01/17
"""
import sys
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from input_data import read_data_sets

# Random seeds for reproducibility
np.random.seed(2017)
tf.set_random_seed(2017)

class VAE(object):
    """A simple variational autoencoder."""
    def __init__(self, input_dim=784, z_dim=50, batch_size=100, encoder_hidden_size=[500, 500],
                 decoder_hidden_size=[500, 500], act_fn=tf.nn.softplus):
        """
        :param input_dim: int, the dimension of the input
        :param z_dim: int, the dimension of the latent space
        :param batch_size: int, the batch size
        :param encoder_hidden_size: list or tuple, the number of hidden units in each encoder layer
        :param decoder_hidden_size: list or tuple, the number of hidden units in each decoder layer
        :param act_fn: the activation function
        """
        self.input_dim = input_dim
        self.z_dim = z_dim
        self.batch_size = batch_size
        self.encoder_hidden_size = encoder_hidden_size
        self.decoder_hidden_size = decoder_hidden_size
        self.act_fn = act_fn

        self._build_model()

    def _build_model(self):
        """Build the computation graph of the model."""
        # Input placeholder
        self.x = tf.placeholder(tf.float32, shape=[self.batch_size, self.input_dim])
        # Encoder: the mean and (log) variance of the Gaussian posterior over z
        self.z_mean, self.z_log_sigma_sq = self._encoder(self.x)
        # Sample from the Gaussian distribution
        eps = tf.random_normal([self.batch_size, self.z_dim], mean=0.0, stddev=1.0)
        # Reparameterization trick: z = mean + sigma * epsilon
        self.z = tf.add(self.z_mean, tf.multiply(tf.sqrt(tf.exp(self.z_log_sigma_sq)), eps))
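        # Expressing the sample as a deterministic function of (z_mean,
        # z_log_sigma_sq) and the external noise eps is what makes the
        # reparameterization trick work: gradients can flow through the
        # sampling step back into the encoder parameters.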

        # Decoder: the mean of the Bernoulli distribution of the reconstructed input
        self.x_reconstr_mean = self._decoder(self.z)

        # Compute the loss
        with tf.name_scope("loss"):
            # The reconstruction loss: cross entropy
            reconstr_loss = -tf.reduce_sum(self.x * tf.log(1e-10 + self.x_reconstr_mean) + \
                            (1.0 - self.x) * tf.log(1e-10 + 1.0 - self.x_reconstr_mean), axis=1)
            # The latent loss: KL divergence
            latent_loss = -0.5 * tf.reduce_sum(1.0 + self.z_log_sigma_sq - tf.square(self.z_mean) - \
                                    tf.exp(self.z_log_sigma_sq), axis=1)
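            # For reference, the closed-form KL term used above is
            # KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2),
            # where log(sigma^2) corresponds to z_log_sigma_sq.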
            # Average over the batch
            self.cost = tf.reduce_mean(reconstr_loss + latent_loss)

        # The optimizer
        self.lr = tf.Variable(0.001, trainable=False)
        train_vars = tf.trainable_variables()
        self.train_op = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(self.cost, var_list=train_vars)

    def _encoder(self, x, name="encoder"):
        """Encoder: map the input to the mean and log variance of q(z|x)."""
        with tf.variable_scope(name):
            n_in = self.input_dim
            for i, s in enumerate(self.encoder_hidden_size):
                w, b = self._get_vars(n_in, s, name="h{0}".format(i))
                if i == 0:
                    h = self.act_fn(tf.nn.xw_plus_b(x, w, b))
                else:
                    h = self.act_fn(tf.nn.xw_plus_b(h, w, b))
                n_in = s
            w, b = self._get_vars(n_in, self.z_dim, name="out_mean")
            z_mean = tf.nn.xw_plus_b(h, w, b)
            w, b = self._get_vars(n_in, self.z_dim, name="out_log_sigma")
            z_log_sigma_sq = tf.nn.xw_plus_b(h, w, b)
            return z_mean, z_log_sigma_sq

    def _decoder(self, z, name="decoder"):
        """Decoder: map a latent vector to the Bernoulli mean of the reconstruction."""
        with tf.variable_scope(name):
            n_in = self.z_dim
            for i, s in enumerate(self.decoder_hidden_size):
                w, b = self._get_vars(n_in, s, name="h{0}".format(i))
                if i == 0:
                    h = self.act_fn(tf.nn.xw_plus_b(z, w, b))
                else:
                    h = self.act_fn(tf.nn.xw_plus_b(h, w, b))
                n_in = s
            # Use a sigmoid for the Bernoulli mean
            w, b = self._get_vars(n_in, self.input_dim, name="out_mean")
            x_reconstr_mean = tf.nn.sigmoid(tf.nn.xw_plus_b(h, w, b))
            return x_reconstr_mean

    def _get_vars(self, n_in, n_out, name=""):
        """Create weight and bias variables."""
        with tf.variable_scope(name):
            w = tf.get_variable("w", [n_in, n_out], initializer=tf.contrib.layers.xavier_initializer())
            b = tf.get_variable("b", [n_out,], initializer=tf.constant_initializer(0.1))
            return w, b

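# A possible convenience layer, not part of the original script: thin wrappers
# around sess.run that mirror how the main block below uses the model. The
# names `reconstruct` and `generate` are this sketch's own.
def reconstruct(sess, vae, x):
    """Map a batch of inputs to their reconstructed (Bernoulli mean) outputs."""
    return sess.run(vae.x_reconstr_mean, feed_dict={vae.x: x})

def generate(sess, vae, z):
    """Decode a batch of latent vectors into Bernoulli means."""
    return sess.run(vae.x_reconstr_mean, feed_dict={vae.z: z})
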
if __name__ == "__main__":
    n_epochs = 30
    lr = 0.001
    batch_size = 100
    display_every = 1

    path = sys.path[0]
    mnist = read_data_sets("MNIST_data/", one_hot=True)
    with tf.Session() as sess:
        vae = VAE(input_dim=784, z_dim=2, batch_size=batch_size, encoder_hidden_size=[500, 500],
                  decoder_hidden_size=[500, 500], act_fn=tf.nn.softplus)
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
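        # NOTE: the model/ and results/ directories are assumed to exist next
        # to this script; they are not created here.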
        #saver.restore(sess, save_path=path+"/model/model.ckpt")
        # Start training
        print("Start training...")
        total_batch = int(mnist.train.num_examples/batch_size)
        for epoch in range(n_epochs):
            avg_cost = 0.0
            # For each batch
            for i in range(total_batch):
                batch_xs, _ = mnist.train.next_batch(batch_size)
                c, _ = sess.run([vae.cost, vae.train_op], feed_dict={vae.x: batch_xs})
                avg_cost += c/total_batch
            if epoch % display_every == 0:
                save_path = saver.save(sess, path+"/model/model.ckpt")
                #print("\tModel saved in file: {0}".format(save_path))
                print("\tEpoch {0}, cost {1}".format(epoch, avg_cost))

        # Reconstruction of test inputs
        x_sample, _ = mnist.test.next_batch(batch_size)
        x_reconstr = sess.run(vae.x_reconstr_mean, feed_dict={vae.x: x_sample})
        plt.figure(figsize=(8, 12))
        for i in range(5):
            plt.subplot(5, 2, 2*i + 1)
            plt.imshow(np.reshape(x_sample[i], (28, 28)), vmin=0, vmax=1, cmap="gray")
            plt.title("Test input")
            plt.colorbar()
            plt.subplot(5, 2, 2*i + 2)
            plt.imshow(np.reshape(x_reconstr[i], (28, 28)), vmin=0, vmax=1, cmap="gray")
            plt.title("Reconstruction")
            plt.colorbar()
        plt.tight_layout()
        plt.savefig(path+"/results/img_epoch{0}.jpg".format(n_epochs))
        plt.show()

        # Sample a regular grid of points over the 2-D latent space
        nx, ny = 20, 20
        xs = np.linspace(-3, 3, nx)
        ys = np.linspace(-3, 3, ny)
        xs, ys = np.meshgrid(xs, ys)
        xs = np.reshape(xs, [-1, 1])
        ys = np.reshape(ys, [-1, 1])
        zs = np.concatenate((xs, ys), axis=1)
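        # The grid visualization below assumes z_dim == 2 (as set above), so
        # each grid point is a complete latent vector that can be fed directly
        # into vae.z, bypassing the encoder.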

        canvas = np.zeros((28*ny, 28*nx))
        rows_per_batch = batch_size // nx   # each batch of 100 covers 5 grid rows
        for i in range(nx*ny // batch_size):
            z_mu = zs[batch_size*i:batch_size*(i+1), :]
            x_mean = sess.run(vae.x_reconstr_mean, feed_dict={vae.z: z_mu})
            # Tile the batch into a (rows_per_batch*28, nx*28) strip of digits,
            # placing larger y values nearer the top of the canvas
            strip = x_mean.reshape(rows_per_batch, nx, 28, 28)[::-1] \
                          .transpose(0, 2, 1, 3) \
                          .reshape(rows_per_batch*28, nx*28)
            canvas[(ny-(i+1)*rows_per_batch)*28:(ny-i*rows_per_batch)*28] = strip

        plt.figure(figsize=(8, 10))
        plt.imshow(canvas, origin="upper", vmin=0, vmax=1, interpolation='none', cmap='gray')
        plt.tight_layout()
        plt.savefig(path+"/results/rand_img_epoch{0}.jpg".format(n_epochs))
        plt.show()