Skip to content

Commit 79a65c0

Browse files
committed
Merge pull request memoakten#6 from alesaccoia/master
Creates train_dir on the first run, then tries restores just if the f…
2 parents 9e16866 + 787bc56 commit 79a65c0

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

example-mnist/bin/py/mnist_deep.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@
3737

3838
out_path = '../data/model-deep'
3939
out_fname = 'mnist-deep'
40-
train_file ='ckpt-deep/'+out_fname + '.ckpt'
40+
train_dir = './ckpt-deep/'
41+
train_file = train_dir + out_fname + '.ckpt'
4142

4243
# Get data
4344
mnist = input_data.read_data_sets("training_data/", one_hot=True)
@@ -130,8 +131,11 @@ def save(path, fname, sess):
130131
# Can't loads this in C++ (shame), only useful for contining training
131132
saver = tf.train.Saver()
132133

133-
#comment out this line to start training from the beginning
134-
saver.restore(sess, train_file);
134+
if os.path.exists(train_dir):
135+
if os.path.exists(train_file):
136+
saver.restore(sess, train_file);
137+
else:
138+
os.makedirs(train_dir);
135139

136140
for i in range(5000):
137141
batch = mnist.train.next_batch(50)

0 commit comments

Comments
 (0)