Skip to content

Commit ce2e8bd

Browse files
authored
Merge pull request emlearn#26 from emlearn/trees-proba
trees: Switch to output prediction probabilities
2 parents 30c4a4f + c7203ac commit ce2e8bd

File tree

5 files changed

+85
-13
lines changed

5 files changed

+85
-13
lines changed

examples/har_trees/har_live.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,16 @@ def mean(arr):
1515
m = sum(arr) / float(len(arr))
1616
return m
1717

18+
def argmax(arr):
19+
idx_max = 0
20+
value_max = arr[0]
21+
for i in range(1, len(arr)):
22+
if arr[i] > value_max:
23+
value_max = arr[i]
24+
idx_max = i
25+
26+
return idx_max
27+
1828
def copy_array_into(source, target):
1929
assert len(source) == len(target)
2030
for i in range(len(target)):
@@ -63,6 +73,7 @@ def main():
6373
features_typecode = timebased.DATA_TYPECODE
6474
n_features = timebased.N_FEATURES
6575
features = array.array(features_typecode, (0 for _ in range(n_features)))
76+
out = array.array('f', range(model.outputs()))
6677

6778
while True:
6879

@@ -87,7 +98,8 @@ def main():
8798

8899
# Cun classifier
89100
#print(features)
90-
result = model.predict(features)
101+
model.predict(features, out)
102+
result = argmax(out)
91103
activity = class_index_to_name[result]
92104

93105
d = time.ticks_diff(time.ticks_ms(), start)

examples/har_trees/har_run.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,16 @@
66
import emlearn_trees
77
import timebased
88

9+
def argmax(arr):
10+
idx_max = 0
11+
value_max = arr[0]
12+
for i in range(1, len(arr)):
13+
if arr[i] > value_max:
14+
value_max = arr[i]
15+
idx_max = i
16+
17+
return idx_max
18+
919
def har_load_test_data(path,
1020
skip_samples=0, limit_samples=None):
1121

@@ -63,6 +73,7 @@ def main():
6373
with open(model_path, 'r') as f:
6474
emlearn_trees.load_model(model, f)
6575

76+
out = array.array('f', range(model.outputs()))
6677

6778
errors = 0
6879
total = 0
@@ -72,7 +83,8 @@ def main():
7283

7384
assert len(labels) == 1
7485
label = labels[0]
75-
result = model.predict(features)
86+
model.predict(features, out)
87+
result = argmax(out)
7688
if result != label:
7789
errors += 1
7890
total += 1

src/emlearn_trees/trees.c

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,25 @@ static mp_obj_t builder_addleaf(mp_obj_t self_obj, mp_obj_t leaf_obj) {
197197
static MP_DEFINE_CONST_FUN_OBJ_2(builder_addleaf_obj, builder_addleaf);
198198

199199

200+
// Return the shape of the output
201+
static mp_obj_t builder_get_outputs(mp_obj_t self_obj) {
202+
203+
mp_obj_trees_builder_t *o = MP_OBJ_TO_PTR(self_obj);
204+
EmlTreesBuilder *self = &o->builder;
205+
206+
const int n_classes = self->trees.n_classes;
207+
if (n_classes == 0) {
208+
mp_raise_ValueError(MP_ERROR_TEXT("model not loaded"));
209+
}
210+
211+
return mp_obj_new_int(n_classes);
212+
}
213+
static MP_DEFINE_CONST_FUN_OBJ_1(builder_get_outputs_obj, builder_get_outputs);
214+
215+
216+
200217
// Takes a array of input data
201-
static mp_obj_t builder_predict(mp_obj_t self_obj, mp_obj_t features_obj) {
218+
static mp_obj_t builder_predict(mp_obj_t self_obj, mp_obj_t features_obj, mp_obj_t output_obj) {
202219

203220
mp_obj_trees_builder_t *o = MP_OBJ_TO_PTR(self_obj);
204221
EmlTreesBuilder *self = &o->builder;
@@ -212,28 +229,45 @@ static mp_obj_t builder_predict(mp_obj_t self_obj, mp_obj_t features_obj) {
212229

213230
const int16_t *features = bufinfo.buf;
214231
const int n_features = bufinfo.len / sizeof(*features);
232+
const int n_outputs = self->trees.n_classes;
215233

216234
#if EMLEARN_MICROPYTHON_DEBUG
217235
mp_printf(&mp_plat_print,
218-
"emltrees-predict n_features=%d n_classes=%d leaves=%d nodes=%d trees=%d length=%d \n",
236+
"emltrees-predict n_features=%d n_classes=%d leaves=%d nodes=%d trees=%d length=%d outputs=%d \n",
219237
self->trees.n_features, self->trees.n_classes,
220238
self->trees.n_leaves, self->trees.n_nodes, self->trees.n_trees,
221239
n_features
222240
);
223241
#endif
224242

243+
if (n_features == 0 || n_outputs == 0) {
244+
mp_raise_ValueError(MP_ERROR_TEXT("model not loaded"));
245+
}
246+
247+
// Extract output
248+
mp_get_buffer_raise(output_obj, &bufinfo, MP_BUFFER_RW);
249+
if (bufinfo.typecode != 'f') {
250+
mp_raise_ValueError(MP_ERROR_TEXT("expecting float output array"));
251+
}
252+
float *output_buffer = bufinfo.buf;
253+
const int output_length = bufinfo.len / sizeof(*output_buffer);
254+
255+
225256
// call model
226-
const int result = eml_trees_predict(&self->trees, features, n_features);
227-
if (result < 0) {
228-
mp_raise_msg(&mp_type_RuntimeError, MP_ERROR_TEXT("eml_trees_predict error"));
257+
// NOTE: also handles checking of input and output lengths
258+
const EmlError err = \
259+
eml_trees_predict_proba(&self->trees, features, n_features, output_buffer, output_length);
260+
261+
if (err != EmlOk) {
262+
mp_raise_msg(&mp_type_RuntimeError, MP_ERROR_TEXT("eml_trees_predict_proba error"));
229263
}
230264

231-
return mp_obj_new_int(result);
265+
return mp_const_none;
232266
}
233-
static MP_DEFINE_CONST_FUN_OBJ_2(builder_predict_obj, builder_predict);
267+
static MP_DEFINE_CONST_FUN_OBJ_3(builder_predict_obj, builder_predict);
234268

235269

236-
mp_map_elem_t trees_locals_dict_table[6];
270+
mp_map_elem_t trees_locals_dict_table[7];
237271
static MP_DEFINE_CONST_DICT(trees_locals_dict, trees_locals_dict_table);
238272

239273
// This is the entry point and is called when the module is imported
@@ -253,8 +287,9 @@ mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *a
253287
trees_locals_dict_table[3] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_addleaf), MP_OBJ_FROM_PTR(&builder_addleaf_obj) };
254288
trees_locals_dict_table[4] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR___del__), MP_OBJ_FROM_PTR(&builder_del_obj) };
255289
trees_locals_dict_table[5] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_setdata), MP_OBJ_FROM_PTR(&builder_setdata_obj) };
290+
trees_locals_dict_table[6] = (mp_map_elem_t){ MP_OBJ_NEW_QSTR(MP_QSTR_outputs), MP_OBJ_FROM_PTR(&builder_get_outputs_obj) };
256291

257-
MP_OBJ_TYPE_SET_SLOT(&trees_builder_type, locals_dict, (void*)&trees_locals_dict, 6);
292+
MP_OBJ_TYPE_SET_SLOT(&trees_builder_type, locals_dict, (void*)&trees_locals_dict, 7);
258293

259294
// This must be last, it restores the globals dict
260295
MP_DYNRUNTIME_INIT_EXIT

tests/test_cnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def test_cnn_mnist():
7373
if out == class_no:
7474
correct += 1
7575

76-
assert correct >= 6, correct
76+
assert correct >= 9, correct
7777

7878

7979
test_cnn_create()

tests/test_trees.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,16 @@
44
import array
55
import gc
66

7+
def argmax(arr):
8+
idx_max = 0
9+
value_max = arr[0]
10+
for i in range(1, len(arr)):
11+
if arr[i] > value_max:
12+
value_max = arr[i]
13+
idx_max = i
14+
15+
return idx_max
16+
717
def test_trees_del():
818
"""
919
Deleting the model should free all the memory
@@ -45,9 +55,12 @@ def test_trees_xor():
4555
( [1*s, 0], 1 ),
4656
]
4757

58+
out = array.array('f', range(model.outputs()))
59+
4860
for (ex, expect) in examples:
4961
f = array.array('h', ex)
50-
result = model.predict(f)
62+
model.predict(f, out)
63+
result = argmax(out)
5164
assert result == expect, (ex, expect, result)
5265

5366
test_trees_del()

0 commit comments

Comments
 (0)