1
1
2
+ import os
2
3
import array
3
- import emlearn_cnn
4
4
import time
5
5
import gc
6
6
7
+ import emlearn_cnn
8
+
7
9
MODEL = 'mnist_cnn.tmdl'
8
- TEST_DATA_DIR = 'data/ '
10
+ TEST_DATA_DIR = 'test_data '
9
11
10
12
def argmax (arr ):
11
13
idx_max = 0
@@ -30,6 +32,27 @@ def print_2d_buffer(arr, rowstride):
30
32
gc .collect ()
31
33
print ('\n ' )
32
34
35
+ def load_images_from_directory (path ):
36
+ sep = '/'
37
+
38
+ for filename in os .listdir (path ):
39
+ # TODO: support standard image formats, like .bmp/.png/.jpeg
40
+ if not filename .endswith ('.bin' ):
41
+ continue
42
+
43
+ # Find the label (if any). The last part, X_label.format
44
+ label = None
45
+ basename = filename .split ('.' )[0 ]
46
+ tok = basename .split ('_' )
47
+ if len (tok ) > 2 :
48
+ label = tok [- 1 ]
49
+
50
+ data_path = path + sep + filename
51
+ with open (data_path , 'rb' ) as f :
52
+ img = array .array ('B' , f .read ())
53
+
54
+ yield img , label
55
+
33
56
def test_cnn_mnist ():
34
57
35
58
# load model
@@ -42,22 +65,28 @@ def test_cnn_mnist():
42
65
probabilities = array .array ('f' , (- 1 for _ in range (out_length )))
43
66
44
67
# run on some test data
45
- for class_no in range (0 , 10 ):
46
- data_path = TEST_DATA_DIR + 'mnist_example_{0:d}.bin' .format (class_no )
47
- #print('open', data_path)
48
- with open (data_path , 'rb' ) as f :
49
- img = array .array ('B' , f .read ())
50
-
51
- print_2d_buffer (img , 28 )
52
-
53
- run_start = time .ticks_us ()
54
- model .run (img , probabilities )
55
- out = argmax (probabilities )
56
- run_duration = time .ticks_diff (time .ticks_us (), run_start ) / 1000.0 # ms
57
-
58
- print ('mnist-example-check' , class_no , out , class_no == out , run_duration )
68
+ n_correct = 0
69
+ n_total = 0
70
+ for img , label in load_images_from_directory (TEST_DATA_DIR ):
71
+ class_no = int (label ) # mnist class labels are digits
72
+
73
+ #print_2d_buffer(img, 28)
74
+
75
+ run_start = time .ticks_us ()
76
+ model .run (img , probabilities )
77
+ out = argmax (probabilities )
78
+ run_duration = time .ticks_diff (time .ticks_us (), run_start ) / 1000.0 # ms
79
+ correct = class_no == out
80
+ n_total += 1
81
+ if correct :
82
+ n_correct += 1
83
+
84
+ print ('mnist-example-check' , class_no , '=' , out , correct , round (run_duration , 3 ))
59
85
60
86
gc .collect ()
61
87
88
+ accuracy = n_correct / n_total
89
+ print ('mnist-example-done' , n_correct , '/' , n_total , round (accuracy * 100 , ), '%' )
90
+
62
91
if __name__ == '__main__' :
63
92
test_cnn_mnist ()
0 commit comments