Skip to content

Commit b9fa0c6

Browse files
authored
Merge pull request #17 from hainingbaby/master
add evaluate_all_image()
2 parents 9be1473 + fe01370 commit b9fa0c6

File tree

1 file changed

+53
-2
lines changed

1 file changed

+53
-2
lines changed

01 cats vs dogs/training.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,59 @@ def run_training():
161161
# print('This is a cat with possibility %.6f' %prediction[:, 0])
162162
# else:
163163
# print('This is a dog with possibility %.6f' %prediction[:, 1])
164-
165-
164+
#
165+
def evaluate_all_image():
166+
'''
167+
Test all image against the saved models and parameters.
168+
Return global accuracy of test_image_set
169+
##############################################
170+
##Notice that test image must has label to compare the prediction and real
171+
##############################################
172+
'''
173+
# you need to change the directories to yours.
174+
test_dir = '/home/kevin/tensorflow/cats_vs_dogs/data/test/'
175+
N_CLASSES = 2
176+
print('-------------------------')
177+
test, test_label = input_data.get_files(test_dir)
178+
BATCH_SIZE = len(test)
179+
print('There are %d test images totally..' % BATCH_SIZE)
180+
print('-------------------------')
181+
test_batch, test_label_batch = input_data.get_batch(test,
182+
test_label,
183+
IMG_W,
184+
IMG_H,
185+
BATCH_SIZE,
186+
CAPACITY)
187+
188+
logits = model.inference(test_batch, BATCH_SIZE, N_CLASSES)
189+
testloss = model.losses(logits, test_label_batch)
190+
testacc = model.evaluation(logits, test_label_batch)
191+
192+
logs_train_dir = '/home/kevin/tensorflow/cats_vs_dogs/logs/train/'
193+
saver = tf.train.Saver()
194+
195+
with tf.Session() as sess:
196+
print("Reading checkpoints...")
197+
ckpt = tf.train.get_checkpoint_state(logs_train_dir)
198+
if ckpt and ckpt.model_checkpoint_path:
199+
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
200+
saver.restore(sess, ckpt.model_checkpoint_path)
201+
print('Loading success, global_step is %s' % global_step)
202+
else:
203+
print('No checkpoint file found')
204+
print('-------------------------')
205+
coord = tf.train.Coordinator()
206+
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
207+
test_loss,test_acc = sess.run([testloss,testacc])
208+
print('The model\'s loss is %.2f' %test_loss)
209+
correct = int(BATCH_SIZE*test_acc)
210+
print('Correct : %d' % correct)
211+
print('Wrong : %d' % (BATCH_SIZE - correct))
212+
print('The accuracy in test images are %.2f%%' %(test_acc*100.0))
213+
coord.request_stop()
214+
coord.join(threads)
215+
sess.close()
216+
166217
#%%
167218

168219

0 commit comments

Comments
 (0)