Skip to content

Commit 05c25cb

Browse files
authored
Merge pull request emlearn#23 from emlearn/cnn-fix-output
cnn: Fix hardcoded outputs size leading to corrupted predictions
2 parents d4682d7 + fa81da7 commit 05c25cb

File tree

3 files changed

+117
-48
lines changed

3 files changed

+117
-48
lines changed

examples/mnist_cnn/mnist_cnn_run.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,16 @@
77
MODEL = 'mnist_cnn.tmdl'
88
TEST_DATA_DIR = 'data/'
99

10+
def argmax(arr):
11+
idx_max = 0
12+
value_max = arr[0]
13+
for i in range(1, len(arr)):
14+
if arr[i] > value_max:
15+
value_max = arr[i]
16+
idx_max = i
17+
18+
return idx_max
19+
1020
def print_2d_buffer(arr, rowstride):
1121

1222
rows = len(arr) // rowstride
@@ -28,6 +38,9 @@ def test_cnn_mnist():
2838
model_data = array.array('B', f.read())
2939
model = emlearn_cnn.new(model_data)
3040

41+
out_length = model.output_dimensions()[0]
42+
probabilities = array.array('f', (-1 for _ in range(out_length)))
43+
3144
# run on some test data
3245
for class_no in range(0, 10):
3346
data_path = TEST_DATA_DIR + 'mnist_example_{0:d}.bin'.format(class_no)
@@ -38,7 +51,8 @@ def test_cnn_mnist():
3851
print_2d_buffer(img, 28)
3952

4053
run_start = time.ticks_us()
41-
out = model.run(img)
54+
model.run(img, probabilities)
55+
out = argmax(probabilities)
4256
run_duration = time.ticks_diff(time.ticks_us(), run_start) / 1000.0 # ms
4357

4458
print('mnist-example-check', class_no, out, class_no == out, run_duration)

src/tinymaix_cnn/mod_cnn.c

Lines changed: 85 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -34,39 +34,31 @@ void NORETURN abort() {
3434
}
3535
#endif
3636

37-
38-
static tm_err_t layer_cb(tm_mdl_t* mdl, tml_head_t* lh)
39-
{
40-
#if 0
41-
//dump middle result
42-
int h = lh->out_dims[1];
43-
int w = lh->out_dims[2];
44-
int ch= lh->out_dims[3];
45-
mtype_t* output = TML_GET_OUTPUT(mdl, lh);
46-
return TM_OK;
47-
TM_PRINTF("Layer %d callback ========\n", mdl->layer_i);
48-
#if 1
49-
for(int y=0; y<h; y++){
50-
TM_PRINTF("[");
51-
for(int x=0; x<w; x++){
52-
TM_PRINTF("[");
53-
for(int c=0; c<ch; c++){
54-
#if TM_MDL_TYPE == TM_MDL_FP32
55-
TM_PRINTF("%.3f,", output[(y*w+x)*ch+c]);
56-
#else
57-
TM_PRINTF("%.3f,", TML_DEQUANT(lh,output[(y*w+x)*ch+c]));
58-
#endif
37+
// get model output shapes
38+
//mdl: model handle; in: input mat; out: output mat
39+
int TM_WEAK tm_get_outputs(tm_mdl_t* mdl, tm_mat_t* out, int out_length)
40+
{
41+
// NOTE: based on tm_run, but without actually executing
42+
int out_idx = 0;
43+
mdl->layer_body = mdl->b->layers_body;
44+
for(mdl->layer_i = 0; mdl->layer_i < mdl->b->layer_cnt; mdl->layer_i++){
45+
tml_head_t* h = (tml_head_t*)(mdl->layer_body);
46+
if(h->is_out) {
47+
if (out_idx < out_length) {
48+
memcpy((void*)(&out[out_idx]), (void*)(&(h->out_dims)), sizeof(uint16_t)*4);
49+
out_idx += 1;
50+
} else {
51+
return -1;
5952
}
60-
TM_PRINTF("],");
6153
}
62-
TM_PRINTF("],\n");
54+
mdl->layer_body += (h->size);
6355
}
64-
TM_PRINTF("\n");
65-
#endif
66-
return TM_OK;
67-
#else
56+
return out_idx;
57+
}
58+
59+
static tm_err_t layer_cb(tm_mdl_t* mdl, tml_head_t* lh)
60+
{
6861
return TM_OK;
69-
#endif
7062
}
7163

7264
#define DEBUG (1)
@@ -79,6 +71,7 @@ typedef struct _mp_obj_mod_cnn_t {
7971
tm_mat_t input;
8072
uint8_t *model_buffer;
8173
uint8_t *data_buffer;
74+
uint16_t out_dims[4];
8275
} mp_obj_mod_cnn_t;
8376

8477
mp_obj_full_type_t mod_cnn_type;
@@ -121,6 +114,25 @@ static mp_obj_t mod_cnn_new(mp_obj_t model_data_obj) {
121114
mp_raise_msg(&mp_type_RuntimeError, MP_ERROR_TEXT("tm_load error"));
122115
}
123116

117+
// find model output shape
118+
o->out_dims[0] = 0;
119+
tm_mat_t outs[1];
120+
const int outputs = tm_get_outputs(model, outs, 1);
121+
if (outputs != 1) {
122+
mp_raise_msg(&mp_type_RuntimeError, MP_ERROR_TEXT("only 1 output supported"));
123+
}
124+
memcpy((void*)(o->out_dims), (void*)(&(outs[0])), sizeof(uint16_t)*4);
125+
126+
if ((o->out_dims[0] != 1)) {
127+
mp_raise_msg(&mp_type_RuntimeError, MP_ERROR_TEXT("output must be 1d"));
128+
}
129+
memcpy((void*)(o->out_dims), (void*)(&(outs[0])), sizeof(uint16_t)*4);
130+
131+
#if DEBUG
132+
mp_printf(&mp_plat_print, "cnn-new-done outs=%d out.dims=(%d,%d,%d,%d) \n",
133+
outputs, o->out_dims[0], o->out_dims[1], o->out_dims[2], o->out_dims[3]);
134+
#endif
135+
124136
return MP_OBJ_FROM_PTR(o);
125137
}
126138
static MP_DEFINE_CONST_FUN_OBJ_1(mod_cnn_new_obj, mod_cnn_new);
@@ -141,15 +153,15 @@ static MP_DEFINE_CONST_FUN_OBJ_1(mod_cnn_del_obj, mod_cnn_del);
141153

142154

143155
// Add a node to the tree
144-
static mp_obj_t mod_cnn_run(mp_obj_t self_obj, mp_obj_t input_obj) {
156+
static mp_obj_t mod_cnn_run(mp_obj_t self_obj, mp_obj_t input_obj, mp_obj_t output_obj) {
145157

146158
mp_obj_mod_cnn_t *o = MP_OBJ_TO_PTR(self_obj);
147159

148160
// Extract input
149161
mp_buffer_info_t bufinfo;
150162
mp_get_buffer_raise(input_obj, &bufinfo, MP_BUFFER_RW);
151163
if (bufinfo.typecode != 'B') {
152-
mp_raise_ValueError(MP_ERROR_TEXT("expecting float array"));
164+
mp_raise_ValueError(MP_ERROR_TEXT("expecting byte array"));
153165
}
154166
uint8_t *input_buffer = bufinfo.buf;
155167
const int input_length = bufinfo.len / sizeof(*input_buffer);
@@ -160,6 +172,21 @@ static mp_obj_t mod_cnn_run(mp_obj_t self_obj, mp_obj_t input_obj) {
160172
mp_raise_ValueError(MP_ERROR_TEXT("wrong input size"));
161173
}
162174

175+
// Extract output
176+
mp_get_buffer_raise(output_obj, &bufinfo, MP_BUFFER_RW);
177+
if (bufinfo.typecode != 'f') {
178+
mp_raise_ValueError(MP_ERROR_TEXT("expecting float array"));
179+
}
180+
float *output_buffer = bufinfo.buf;
181+
const int output_length = bufinfo.len / sizeof(*output_buffer);
182+
183+
184+
// check buffer size wrt input
185+
const int expect_out_length = o->out_dims[1]*o->out_dims[2]*o->out_dims[3];
186+
if (output_length != expect_out_length) {
187+
mp_raise_ValueError(MP_ERROR_TEXT("wrong output size"));
188+
}
189+
163190
// Preprocess data
164191
tm_mat_t in_uint8 = o->input;
165192
in_uint8.data = (mtype_t*)input_buffer;
@@ -181,27 +208,38 @@ static mp_obj_t mod_cnn_run(mp_obj_t self_obj, mp_obj_t input_obj) {
181208
mp_raise_ValueError(MP_ERROR_TEXT("run error"));
182209
}
183210

211+
// Copy output into
184212
tm_mat_t out = outs[0];
185-
float* data = out.dataf;
186-
float maxp = 0;
187-
int maxi = -1;
188-
189-
// TODO: pass the entire output vector out to Python
190-
// FIXME: unhardcode output handling
191-
for(int i=0; i<10; i++){
192-
//printf("%d: %.3f\n", i, data[i]);
193-
if (data[i] > maxp) {
194-
maxi = i;
195-
maxp = data[i];
196-
}
213+
for(int i=0; i<expect_out_length; i++){
214+
output_buffer[i] = out.dataf[i];
215+
}
216+
217+
return mp_const_none;
218+
}
219+
static MP_DEFINE_CONST_FUN_OBJ_3(mod_cnn_run_obj, mod_cnn_run);
220+
221+
222+
// Return the shape of the output
223+
static mp_obj_t mod_cnn_output_dimensions(mp_obj_t self_obj) {
224+
225+
mp_obj_mod_cnn_t *o = MP_OBJ_TO_PTR(self_obj);
226+
const int dimensions = o->out_dims[0];
227+
mp_obj_tuple_t *tuple = MP_OBJ_TO_PTR(mp_obj_new_tuple(dimensions, NULL));
228+
229+
// A regular output should have C channels, and 1 for everything else
230+
// TODO: support other shapes?
231+
//dims==1, 11c
232+
if (!(o->out_dims[0] == 1 && o->out_dims[1] == 1 && o->out_dims[2] == 1)) {
233+
mp_raise_msg(&mp_type_RuntimeError, MP_ERROR_TEXT("wrong output shape"));
197234
}
198235

199-
return mp_obj_new_int(maxi);
236+
tuple->items[0] = mp_obj_new_int(o->out_dims[3]);
237+
return tuple;
200238
}
201-
static MP_DEFINE_CONST_FUN_OBJ_2(mod_cnn_run_obj, mod_cnn_run);
239+
static MP_DEFINE_CONST_FUN_OBJ_1(mod_cnn_output_dimensions_obj, mod_cnn_output_dimensions);
202240

203241

204-
mp_map_elem_t mod_locals_dict_table[2];
242+
mp_map_elem_t mod_locals_dict_table[3];
205243
static MP_DEFINE_CONST_DICT(mod_locals_dict, mod_locals_dict_table);
206244

207245
// This is the entry point and is called when the module is imported
@@ -217,6 +255,7 @@ mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *a
217255
// methods
218256
mod_locals_dict_table[0] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_run), MP_OBJ_FROM_PTR(&mod_cnn_run_obj) };
219257
mod_locals_dict_table[1] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR___del__), MP_OBJ_FROM_PTR(&mod_cnn_del_obj) };
258+
mod_locals_dict_table[2] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_output_dimensions), MP_OBJ_FROM_PTR(&mod_cnn_output_dimensions_obj) };
220259

221260
MP_OBJ_TYPE_SET_SLOT(&mod_cnn_type, locals_dict, (void*)&mod_locals_dict, 2);
222261

tests/test_cnn.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ def test_cnn_create():
1212
model_data = array.array('B', f.read())
1313
model = emlearn_cnn.new(model_data)
1414

15+
out_shape = model.output_dimensions()
16+
assert out_shape == (10,), (out_shape)
17+
1518
# TODO: enable these checks
1619
#wrong_type = array.array('f', [])
1720
#model.run(wrong_type)
@@ -35,13 +38,25 @@ def print_2d_buffer(arr, rowstride):
3538

3639
print('\n')
3740

41+
def argmax(arr):
42+
idx_max = 0
43+
value_max = arr[0]
44+
for i in range(1, len(arr)):
45+
if arr[i] > value_max:
46+
value_max = arr[i]
47+
idx_max = i
48+
49+
return idx_max
50+
3851
def test_cnn_mnist():
3952

4053
model = None
4154
with open(MNIST_MODEL, 'rb') as f:
4255
model_data = array.array('B', f.read())
4356
model = emlearn_cnn.new(model_data)
4457

58+
probabilities = array.array('f', (-1 for _ in range(10)))
59+
4560
correct = 0
4661
for class_no in range(0, 10):
4762
data_path = MNIST_DATA_DIR + 'mnist_example_{0:d}.bin'.format(class_no)
@@ -51,7 +66,8 @@ def test_cnn_mnist():
5166

5267
#print_2d_buffer(img, 28)
5368

54-
out = model.run(img)
69+
model.run(img, probabilities)
70+
out = argmax(probabilities)
5571
# TODO replace with assert
5672
print('mnist-example-check', class_no, out, class_no == out)
5773
if out == class_no:

0 commit comments

Comments
 (0)