Skip to content

Commit cc2cb6a

Browse files
committed
mnist_cnn: More test samples, log overall score
1 parent 3d771b7 commit cc2cb6a

File tree

3 files changed

+54
-26
lines changed

3 files changed

+54
-26
lines changed

examples/mnist_cnn/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ mpremote mip install https://emlearn.github.io/emlearn-micropython/builds/master
5555

5656
```console
5757
mpremote cp mnist_cnn.tmdl :
58-
mpremote cp -r data/ :
58+
mpremote cp -r test_data/ :
5959
mpremote run mnist_cnn_run.py
6060
```
6161

examples/mnist_cnn/mnist_cnn_run.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11

2+
import os
23
import array
3-
import emlearn_cnn
44
import time
55
import gc
66

7+
import emlearn_cnn
8+
79
MODEL = 'mnist_cnn.tmdl'
8-
TEST_DATA_DIR = 'data/'
10+
TEST_DATA_DIR = 'test_data'
911

1012
def argmax(arr):
1113
idx_max = 0
@@ -30,6 +32,27 @@ def print_2d_buffer(arr, rowstride):
3032
gc.collect()
3133
print('\n')
3234

35+
def load_images_from_directory(path):
36+
sep = '/'
37+
38+
for filename in os.listdir(path):
39+
# TODO: support standard image formats, like .bmp/.png/.jpeg
40+
if not filename.endswith('.bin'):
41+
continue
42+
43+
# Find the label (if any). The last part, X_label.format
44+
label = None
45+
basename = filename.split('.')[0]
46+
tok = basename.split('_')
47+
if len(tok) > 2:
48+
label = tok[-1]
49+
50+
data_path = path + sep + filename
51+
with open(data_path, 'rb') as f:
52+
img = array.array('B', f.read())
53+
54+
yield img, label
55+
3356
def test_cnn_mnist():
3457

3558
# load model
@@ -42,22 +65,28 @@ def test_cnn_mnist():
4265
probabilities = array.array('f', (-1 for _ in range(out_length)))
4366

4467
# run on some test data
45-
for class_no in range(0, 10):
46-
data_path = TEST_DATA_DIR + 'mnist_example_{0:d}.bin'.format(class_no)
47-
#print('open', data_path)
48-
with open(data_path, 'rb') as f:
49-
img = array.array('B', f.read())
50-
51-
print_2d_buffer(img, 28)
52-
53-
run_start = time.ticks_us()
54-
model.run(img, probabilities)
55-
out = argmax(probabilities)
56-
run_duration = time.ticks_diff(time.ticks_us(), run_start) / 1000.0 # ms
57-
58-
print('mnist-example-check', class_no, out, class_no == out, run_duration)
68+
n_correct = 0
69+
n_total = 0
70+
for img, label in load_images_from_directory(TEST_DATA_DIR):
71+
class_no = int(label) # mnist class labels are digits
72+
73+
#print_2d_buffer(img, 28)
74+
75+
run_start = time.ticks_us()
76+
model.run(img, probabilities)
77+
out = argmax(probabilities)
78+
run_duration = time.ticks_diff(time.ticks_us(), run_start) / 1000.0 # ms
79+
correct = class_no == out
80+
n_total += 1
81+
if correct:
82+
n_correct += 1
83+
84+
print('mnist-example-check', class_no, '=', out, correct, round(run_duration, 3))
5985

6086
gc.collect()
6187

88+
accuracy = n_correct / n_total
89+
print('mnist-example-done', n_correct, '/', n_total, round(accuracy*100, ), '%')
90+
6291
if __name__ == '__main__':
6392
test_cnn_mnist()

examples/mnist_cnn/mnist_train.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ def train_mnist(h5_file, epochs=10):
4040
(x_orig_train, y_orig_train), (x_orig_test, y_orig_test) = mnist.load_data()
4141
num_classes = 10
4242

43-
generate_test_files('test_data', x_orig_test, y_orig_test)
43+
TEST_DATA_DIR = 'test_data'
44+
generate_test_files(TEST_DATA_DIR, x_orig_test, y_orig_test)
45+
print('Wrote test data to', TEST_DATA_DIR)
4446

4547
x_train = x_orig_train
4648
x_test = x_orig_test
@@ -58,7 +60,7 @@ def train_mnist(h5_file, epochs=10):
5860

5961
model.save(h5_file)
6062

61-
def generate_test_files(out_dir, x, y):
63+
def generate_test_files(out_dir, x, y, samples_per_class=5):
6264

6365
if not os.path.exists(out_dir):
6466
os.makedirs(out_dir)
@@ -71,15 +73,12 @@ def generate_test_files(out_dir, x, y):
7173
# select one per class
7274
for class_no in classes:
7375
matches = (Y_classes == class_no)
74-
print('mm', matches.shape)
7576
x_matches = X_series[matches]
7677

77-
selected = x_matches.sample(n=1, random_state=1)
78-
for s in selected:
79-
print('ss', s.shape, s.dtype)
80-
print(s)
81-
out = os.path.join(out_dir, f'mnist_example_{class_no}.bin')
82-
data = s.tobytes(order='C')
78+
selected = x_matches.sample(n=samples_per_class, random_state=1)
79+
for i, sample in enumerate(selected):
80+
out = os.path.join(out_dir, f'mnist_example_{i}_{class_no}.bin')
81+
data = sample.tobytes(order='C')
8382

8483
assert len(data) == expect_bytes, (len(data), expect_bytes)
8584
with open(out, 'wb') as f:

0 commit comments

Comments
 (0)