33author: Ye Hu
442016/12/18
55"""
6+ import os
67import timeit
78import numpy as np
89import tensorflow as tf
10+ from PIL import Image
11+ from utils import tile_raster_images
12+ import input_data
913from rbm import RBM
1014
1115
@@ -56,39 +60,97 @@ def free_energy(self, v_sample):
5660 hidden_term = tf .reduce_sum (tf .log (1.0 + tf .exp (wx_b )), axis = 1 )
5761 return - hidden_term + vbias_term
5862
63+ def get_reconstruction_cost (self ):
64+ """Compute the mse of the original input and the reconstruction"""
65+ activation_h = self .propup (self .input )
66+ activation_v = self .propdown (activation_h )
67+ mse = tf .reduce_mean (tf .reduce_sum (tf .square (self .input - activation_v ), axis = 1 ))
68+ return mse
69+
70+
5971
6072if __name__ == "__main__" :
61- data = np .random .randn (1000 , 6 )
62- x = tf .placeholder (tf .float32 , shape = [None , 6 ])
63-
64- gbrbm = GBRBM (x , n_visiable = 6 , n_hidden = 5 )
65-
66- learning_rate = 0.1
67- k = 1
68- batch_size = 20
69- n_epochs = 10
70-
71- cost = gbrbm .get_reconstruction_cost ()
73+ # mnist examples
74+ mnist = input_data .read_data_sets ("MNIST_data/" , one_hot = True )
75+ # define input
76+ x = tf .placeholder (tf .float32 , shape = [None , 784 ])
77+ # set random_seed
78+ tf .set_random_seed (seed = 99999 )
79+ np .random .seed (123 )
80+ # the rbm model
81+ n_visiable , n_hidden = 784 , 500
82+ rbm = GBRBM (x , n_visiable = n_visiable , n_hidden = n_hidden )
83+
84+ learning_rate = 0.01
85+ batch_size = 50
86+ cost = rbm .get_reconstruction_cost ()
7287 # Create the persistent variable
7388 #persistent_chain = tf.Variable(tf.zeros([batch_size, n_hidden]), dtype=tf.float32)
7489 persistent_chain = None
75- train_ops = gbrbm .get_train_ops (learning_rate = learning_rate , k = 1 , persistent = persistent_chain )
90+ train_ops = rbm .get_train_ops (learning_rate = learning_rate , k = 1 , persistent = persistent_chain )
7691 init = tf .global_variables_initializer ()
7792
78- sess = tf .Session ()
79- sess .run (init )
80- for epoch in range (n_epochs ):
81- avg_cost = 0.0
82- for i in range (len (data )// batch_size ):
83- sess .run (train_ops , feed_dict = {x : data [i * batch_size :(i + 1 )* batch_size ]})
84- avg_cost += sess .run (cost , feed_dict = {x : data [i * batch_size :(i + 1 )* batch_size ]})/ batch_size
85- print (avg_cost )
86-
87- # test
88- v = np .random .randn (10 , 6 )
89- print (v )
93+ output_folder = "rbm_plots"
94+ if not os .path .isdir (output_folder ):
95+ os .makedirs (output_folder )
96+ os .chdir (output_folder )
9097
91- preds = sess .run (gbrbm .reconstruct (x ), feed_dict = {x : v })
92- print (preds )
98+ training_epochs = 15
99+ display_step = 1
100+ print ("Start training..." )
101+
102+ with tf .Session () as sess :
103+ start_time = timeit .default_timer ()
104+ sess .run (init )
105+ for epoch in range (training_epochs ):
106+ avg_cost = 0.0
107+ batch_num = int (mnist .train .num_examples / batch_size )
108+ for i in range (batch_num ):
109+ x_batch , _ = mnist .train .next_batch (batch_size )
110+ # 训练
111+ sess .run (train_ops , feed_dict = {x : x_batch })
112+ # 计算cost
113+ avg_cost += sess .run (cost , feed_dict = {x : x_batch ,}) / batch_num
114+ # 输出
115+ if epoch % display_step == 0 :
116+ print ("Epoch {0} cost: {1}" .format (epoch , avg_cost ))
117+ # Construct image from the weight matrix
118+ image = Image .fromarray (
119+ tile_raster_images (
120+ X = sess .run (tf .transpose (rbm .W )),
121+ img_shape = (28 , 28 ),
122+ tile_shape = (10 , 10 ),
123+ tile_spacing = (1 , 1 )))
124+ image .save ("test_filters_at_epoch_{0}.png" .format (epoch ))
93125
94-
126+ end_time = timeit .default_timer ()
127+ training_time = end_time - start_time
128+ print ("Finished!" )
129+ print (" The training ran for {0} minutes." .format (training_time / 60 ,))
130+
131+ # Randomly select the 'n_chains' examples
132+ n_chains = 20
133+ n_batch = 10
134+ n_samples = n_batch * 2
135+ number_test_examples = mnist .test .num_examples
136+ test_indexs = np .random .randint (number_test_examples - n_chains * n_batch )
137+ test_samples = mnist .test .images [test_indexs :test_indexs + n_chains * n_batch ]
138+ image_data = np .zeros ((29 * (n_samples + 1 )+ 1 , 29 * (n_chains )- 1 ),
139+ dtype = "uint8" )
140+ # Add the original images
141+ for i in range (n_batch ):
142+ image_data [2 * i * 29 :2 * i * 29 + 28 ,:] = tile_raster_images (X = test_samples [i * n_batch :(i + 1 )* n_chains ],
143+ img_shape = (28 , 28 ),
144+ tile_shape = (1 , n_chains ),
145+ tile_spacing = (1 , 1 ))
146+ samples = sess .run (rbm .reconstruct (x ), feed_dict = {x :test_samples [i * n_batch :(i + 1 )* n_chains ]})
147+ image_data [(2 * i + 1 )* 29 :(2 * i + 1 )* 29 + 28 ,:] = tile_raster_images (X = samples ,
148+ img_shape = (28 , 28 ),
149+ tile_shape = (1 , n_chains ),
150+ tile_spacing = (1 , 1 ))
151+
152+ image = Image .fromarray (image_data )
153+ image .save ("original_and_reconstruct.png" )
154+
155+
156+
0 commit comments