@@ -161,8 +161,59 @@ def run_training():
161161#                print('This is a cat with possibility %.6f' %prediction[:, 0]) 
162162#            else: 
163163#                print('This is a dog with possibility %.6f' %prediction[:, 1]) 
164- 
165- 
164+ # 
165+ def  evaluate_all_image ():
166+ 	''' 
167+ 	Test all image against the saved models and parameters. 
168+ 	Return global accuracy of test_image_set 
169+ 	############################################## 
170+ 	##Notice that test image must has label to compare the prediction and real 
171+ 	############################################## 
172+ 	''' 
173+     # you need to change the directories to yours. 
174+ 	test_dir  =  '/home/kevin/tensorflow/cats_vs_dogs/data/test/' 
175+     N_CLASSES  =  2 
176+     print ('-------------------------' )
177+     test , test_label  =  input_data .get_files (test_dir )
178+     BATCH_SIZE  =  len (test )
179+     print ('There are %d test images totally..'  %  BATCH_SIZE )
180+     print ('-------------------------' )
181+     test_batch , test_label_batch  =  input_data .get_batch (test ,
182+                                                           test_label ,
183+                                                           IMG_W ,
184+                                                           IMG_H ,
185+                                                           BATCH_SIZE , 
186+                                                           CAPACITY )
187+ 	
188+ 	logits  =  model .inference (test_batch , BATCH_SIZE , N_CLASSES )
189+     testloss  =  model .losses (logits , test_label_batch ) 
190+     testacc  =  model .evaluation (logits , test_label_batch )
191+ 	
192+ 	logs_train_dir  =  '/home/kevin/tensorflow/cats_vs_dogs/logs/train/'               
193+     saver  =  tf .train .Saver ()
194+         
195+     with  tf .Session () as  sess :
196+         print ("Reading checkpoints..." )
197+         ckpt  =  tf .train .get_checkpoint_state (logs_train_dir )
198+         if  ckpt  and  ckpt .model_checkpoint_path :
199+             global_step  =  ckpt .model_checkpoint_path .split ('/' )[- 1 ].split ('-' )[- 1 ]
200+             saver .restore (sess , ckpt .model_checkpoint_path )
201+             print ('Loading success, global_step is %s'  %  global_step )
202+         else :
203+             print ('No checkpoint file found' )
204+         print ('-------------------------' )        
205+         coord  =  tf .train .Coordinator ()
206+         threads  =  tf .train .start_queue_runners (sess = sess , coord = coord )
207+         test_loss ,test_acc  =  sess .run ([testloss ,testacc ])
208+         print ('The model\' s loss is %.2f'  % test_loss )
209+         correct  =  int (BATCH_SIZE * test_acc )
210+         print ('Correct : %d'  %  correct )
211+         print ('Wrong : %d'  %  (BATCH_SIZE  -  correct ))
212+         print ('The accuracy in test images are %.2f%%'  % (test_acc * 100.0 ))
213+     coord .request_stop ()
214+     coord .join (threads )
215+     sess .close ()
216+  
166217#%% 
167218
168219
0 commit comments