Skip to content

Commit 815f303

Browse files
authored
Update vae_mnist.py
1 parent 6d6beb0 commit 815f303

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

examples/VAE/vae_mnist.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,17 @@ def _get_vars(self, n_in, n_out, name=""):
163163
zs = np.concatenate((xs, ys), axis=1)
164164

165165
canvas = np.zeros((28*ny, 28*nx))
166+
xs_recon = np.zeros((batch_size*4, 28*28))
166167
for i in range(4):
167168
z_mu = zs[batch_size*i:batch_size*(i+1), :]
168169
x_mean = sess.run(vae.x_reconstr_mean, feed_dict={vae.z: z_mu})
169-
canvas[(ny-(i+1)*5)*28:(ny-i*5)*28] = x_mean.reshape(-1, 28*nx)[::-1]
170+
xs_recon[i*batch_size:(i+1)*batch_size] = x_mean
171+
172+
n = 0
173+
for i in range(nx):
174+
for j in range(ny):
175+
canvas[(ny-i-1)*28:(ny-i)*28, j*28:(j+1)*28] = xs_recon[n].reshape(28, 28)
176+
n = n + 1
170177

171178
plt.figure(figsize=(8, 10))
172179
plt.imshow(canvas, origin="upper", vmin=0, vmax=1, interpolation='none', cmap='gray')

0 commit comments

Comments
 (0)