|
1 | | -# -*- coding: utf-8 -*- |
2 | | - |
3 | 1 | """ Auto Encoder Example. |
4 | | -Using an auto encoder on MNIST handwritten digits. |
| 2 | +
|
|    | 3 | +Build a 2-layer auto-encoder with TensorFlow to compress images to a |
| 4 | +lower latent space and then reconstruct them. |
| 5 | +
|
5 | 6 | References: |
6 | 7 | Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. "Gradient-based |
7 | 8 | learning applied to document recognition." Proceedings of the IEEE, |
8 | 9 | 86(11):2278-2324, November 1998. |
| 10 | +
|
9 | 11 | Links: |
10 | 12 | [MNIST Dataset] http://yann.lecun.com/exdb/mnist/ |
| 13 | +
|
| 14 | +Author: Aymeric Damien |
| 15 | +Project: https://github.com/aymericdamien/TensorFlow-Examples/ |
11 | 16 | """ |
12 | 17 | from __future__ import division, print_function, absolute_import |
13 | 18 |
|
|
17 | 22 |
|
18 | 23 | # Import MNIST data |
19 | 24 | from tensorflow.examples.tutorials.mnist import input_data |
# Dataset is downloaded to /tmp/data/ on first use. one_hot labels are
# requested but never read below -- the auto-encoder trains on images only.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# Training Parameters
learning_rate = 0.01  # RMSProp step size
num_steps = 30000     # total number of mini-batch updates
batch_size = 256      # images per training step

display_step = 1000   # log the loss every this many steps
# NOTE(review): examples_to_show is not referenced by the visualization code
# below (it tiles a fixed 4x4 grid) -- confirm whether it can be removed.
examples_to_show = 10
28 | 34 |
|
# Network Parameters
num_hidden_1 = 256  # 1st layer num features
num_hidden_2 = 128  # 2nd layer num features (the latent dim)
num_input = 784     # MNIST data input (img shape: 28*28, flattened)

# tf Graph input (only pictures): a batch of flattened 28x28 images.
# No label placeholder is needed -- reconstruction targets are the inputs.
X = tf.placeholder("float", [None, num_input])
|
def _normal_var(shape):
    # Trainable variable initialized from a unit normal distribution.
    return tf.Variable(tf.random_normal(shape))

# Encoder/decoder parameters. The decoder mirrors the encoder's layer
# sizes in reverse. Key names are shared with encoder()/decoder() below,
# so they must stay exactly as written.
weights = {
    'encoder_h1': _normal_var([num_input, num_hidden_1]),
    'encoder_h2': _normal_var([num_hidden_1, num_hidden_2]),
    'decoder_h1': _normal_var([num_hidden_2, num_hidden_1]),
    'decoder_h2': _normal_var([num_hidden_1, num_input]),
}
biases = {
    'encoder_b1': _normal_var([num_hidden_1]),
    'encoder_b2': _normal_var([num_hidden_2]),
    'decoder_b1': _normal_var([num_hidden_1]),
    'decoder_b2': _normal_var([num_input]),
}
49 | 55 |
|
50 | | - |
51 | 56 | # Building the encoder |
52 | 57 | def encoder(x): |
53 | 58 | # Encoder Hidden layer with sigmoid activation #1 |
@@ -79,38 +84,59 @@ def decoder(x): |
# Reconstruction targets are the input images themselves.
y_true = X

# Define loss and optimizer, minimize the squared error
# (pixel-wise mean squared error between input and reconstruction).
loss = tf.reduce_mean(tf.pow(y_true - y_pred, 2))
optimizer = tf.train.RMSPropOptimizer(learning_rate).minimize(loss)

# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()
87 | 92 |
|
# Start Training
# Launch a session and run the graph.
with tf.Session() as sess:

    # Assign every variable its initial value.
    sess.run(init)

    # Training loop: one mini-batch of images per step (labels discarded).
    for step in range(1, num_steps + 1):
        images, _ = mnist.train.next_batch(batch_size)

        # One backprop update; also fetch the loss value for logging.
        _, batch_loss = sess.run([optimizer, loss], feed_dict={X: images})
        if step == 1 or step % display_step == 0:
            print('Step %i: Minibatch Loss: %f' % (step, batch_loss))

    # Testing: push test-set digits through the full auto-encoder and
    # tile originals vs. reconstructions into two square image grids.
    grid_size = 4
    canvas_orig = np.empty((28 * grid_size, 28 * grid_size))
    canvas_recon = np.empty((28 * grid_size, 28 * grid_size))
    for row in range(grid_size):
        # One row of test digits and their reconstructions.
        images, _ = mnist.test.next_batch(grid_size)
        reconstructed = sess.run(decoder_op, feed_dict={X: images})

        top, bottom = row * 28, (row + 1) * 28
        for col in range(grid_size):
            left, right = col * 28, (col + 1) * 28
            canvas_orig[top:bottom, left:right] = \
                images[col].reshape([28, 28])
            canvas_recon[top:bottom, left:right] = \
                reconstructed[col].reshape([28, 28])

    print("Original Images")
    plt.figure(figsize=(grid_size, grid_size))
    plt.imshow(canvas_orig, origin="upper", cmap="gray")
    plt.show()

    print("Reconstructed Images")
    plt.figure(figsize=(grid_size, grid_size))
    plt.imshow(canvas_recon, origin="upper", cmap="gray")
    plt.show()
0 commit comments