@@ -32,10 +32,14 @@ def plot_his(inputs, inputs_norm):
         for i, input in enumerate(all_inputs):
             plt.subplot(2, len(all_inputs), j*len(all_inputs)+(i+1))
             plt.cla()
-            plt.hist(input.ravel(), bins=15, range=(-1, 1), color='#FF5733')
+            if i == 0:
+                the_range = (-7, 10)
+            else:
+                the_range = (-1, 1)
+            plt.hist(input.ravel(), bins=15, range=the_range, color='#FF5733')
             plt.yticks(())
             if j == 1:
-                plt.xticks((-1, 0, 1))
+                plt.xticks(the_range)
             else:
                 plt.xticks(())
             ax = plt.gca()
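For context on this first hunk: the first histogram column shows the raw input data, which presumably spans roughly (-7, 10) in this tutorial, while the other columns show layer activations near (-1, 1), so a single fixed range clipped the input panel. A minimal matplotlib sketch of the same range-switch idea, with hypothetical stand-in arrays:

import numpy as np
import matplotlib.pyplot as plt

# assumed stand-ins: a wide raw-input sample and a tanh-activated sample
data = [np.random.uniform(-7, 10, 500), np.tanh(np.random.randn(500))]
for i, d in enumerate(data):
    plt.subplot(1, len(data), i + 1)
    the_range = (-7, 10) if i == 0 else (-1, 1)  # wider range only for the raw input
    plt.hist(d.ravel(), bins=15, range=the_range, color='#FF5733')
    plt.yticks(())
plt.show()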
@@ -81,6 +85,18 @@ def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):
     return outputs
 
 fix_seed(1)
+
+if norm:
+    # BN for the first input
+    fc_mean, fc_var = tf.nn.moments(
+        xs,
+        axes=[0],
+    )
+    scale = tf.Variable(tf.ones([1]))
+    shift = tf.Variable(tf.zeros([1]))
+    epsilon = 0.001
+    xs = tf.nn.batch_normalization(xs, fc_mean, fc_var, shift, scale, epsilon)
+
 # record inputs for every layer
 layers_inputs = [xs]
 
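For reference, a minimal standalone sketch of what the added block computes, assuming a placeholder xs of shape [None, 1] as in the tutorial: tf.nn.moments gives the batch mean and variance used to standardize the input, and the learnable shift/scale (beta/gamma) let the network undo the normalization if that helps.

import tensorflow as tf

xs = tf.placeholder(tf.float32, [None, 1])     # assumed input placeholder
fc_mean, fc_var = tf.nn.moments(xs, axes=[0])  # per-feature batch statistics
scale = tf.Variable(tf.ones([1]))              # learnable gamma
shift = tf.Variable(tf.zeros([1]))             # learnable beta
epsilon = 0.001                                # keeps the denominator non-zero
xs_bn = tf.nn.batch_normalization(xs, fc_mean, fc_var, shift, scale, epsilon)
# equivalent to: shift + scale * (xs - fc_mean) / tf.sqrt(fc_var + epsilon)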
@@ -137,8 +153,8 @@ def add_layer(inputs, in_size, out_size, activation_function=None, norm=False):
         all_inputs, all_inputs_norm = sess.run([layers_inputs, layers_inputs_norm], feed_dict={xs: x_data, ys: y_data})
         plot_his(all_inputs, all_inputs_norm)
 
-    sess.run(train_op, feed_dict={xs: x_data, ys: y_data})
-    sess.run(train_op_norm, feed_dict={xs: x_data, ys: y_data})
+    sess.run([train_op, train_op_norm], feed_dict={xs: x_data, ys: y_data})
+
     if i % record_step == 0:
         # record cost
         cost_his.append(sess.run(cost, feed_dict={xs: x_data, ys: y_data}))
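Note on the merged training step: fetching both optimizers from a single sess.run() call runs the graph once per iteration with one shared feed, instead of paying the session-call overhead twice. A small sketch with hypothetical counters standing in for the two train ops:

import tensorflow as tf

a = tf.Variable(0)
b = tf.Variable(0)
inc_a = tf.assign_add(a, 1)  # stands in for train_op
inc_b = tf.assign_add(b, 1)  # stands in for train_op_norm

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run([inc_a, inc_b])  # one graph execution applies both updates
    print(sess.run([a, b]))   # [1, 1]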