Skip to content

Commit bb69f1d

Browse files
committed
trees: Use int16 features instead of float
Means that evaluation does not need floating point Should be considerably smaller and faster, especially on platforms without FPU
1 parent 632f3af commit bb69f1d

File tree

5 files changed

+29
-26
lines changed

5 files changed

+29
-26
lines changed

examples/xor_trees/xor_model.csv

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,17 @@ l,1
55
r,0
66
r,5
77
r,9
8-
n,1,0.197349,1,2
9-
n,0,0.466316,-1,-2
10-
n,1,0.256702,1,2
11-
n,0,0.464752,-1,-2
12-
n,0,0.476659,-2,-1
13-
n,1,0.897163,1,-2
14-
n,0,0.504671,1,2
15-
n,1,0.501874,-1,-2
16-
n,1,0.498237,-2,-1
17-
n,0,0.11973,1,2
18-
n,1,0.394697,-1,-2
19-
n,1,0.261566,1,2
20-
n,0,0.473531,-1,-2
21-
n,0,0.421164,-2,-1
8+
n,1,6466.0,1,2
9+
n,0,15279.5,-1,-2
10+
n,1,8411.0,1,2
11+
n,0,15228.0,-1,-2
12+
n,0,15618.5,-2,-1
13+
n,1,29397.0,1,-2
14+
n,0,16536.0,1,2
15+
n,1,16444.5,-1,-2
16+
n,1,16325.5,-2,-1
17+
n,0,3922.5,1,2
18+
n,1,12932.5,-1,-2
19+
n,1,8570.5,1,2
20+
n,0,15516.0,-1,-2
21+
n,0,13800.0,-2,-1

examples/xor_trees/xor_train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# python/host code
22

33
import emlearn
4+
from emlearn.preprocessing import Quantizer
45
import numpy
56
from sklearn.ensemble import RandomForestClassifier
67
from sklearn.metrics import get_scorer
@@ -10,6 +11,7 @@ def make_xor(lower=0.0, upper=1.0, threshold=0.5, samples=100, seed=42):
1011
rng = numpy.random.RandomState(seed)
1112
X = rng.uniform(lower, upper, size=(samples, 2))
1213
y = numpy.logical_xor(X[:, 0] > threshold, X[:, 1] > threshold)
14+
X = Quantizer(max_value=1.0).fit_transform(X) # convert to int16
1315
return X, y
1416

1517
X, y = make_xor()

src/emltrees/trees.c

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ static mp_obj_t builder_addnode(size_t n_args, const mp_obj_t *args) {
125125

126126
const int16_t left = mp_obj_get_int(args[1]);
127127
const int16_t right = mp_obj_get_int(args[2]);
128-
const mp_int_t feature = mp_obj_get_int(args[3]);
129-
const float value = mp_obj_get_float_to_f(args[4]);
128+
const int feature = mp_obj_get_int(args[3]);
129+
const int16_t value = mp_obj_get_int(args[4]);
130130

131131
if (feature > 127 || feature < -1) {
132132
mp_raise_ValueError(MP_ERROR_TEXT("feature out of bounds"));
@@ -191,7 +191,7 @@ static mp_obj_t builder_addleaf(mp_obj_t self_obj, mp_obj_t leaf_obj) {
191191
static MP_DEFINE_CONST_FUN_OBJ_2(builder_addleaf_obj, builder_addleaf);
192192

193193

194-
// Takes a float array
194+
// Takes a array of input data
195195
static mp_obj_t builder_predict(mp_obj_t self_obj, mp_obj_t features_obj) {
196196

197197
mp_obj_trees_builder_t *o = MP_OBJ_TO_PTR(self_obj);
@@ -200,11 +200,11 @@ static mp_obj_t builder_predict(mp_obj_t self_obj, mp_obj_t features_obj) {
200200
// Extract buffer pointer and verify typecode
201201
mp_buffer_info_t bufinfo;
202202
mp_get_buffer_raise(features_obj, &bufinfo, MP_BUFFER_RW);
203-
if (bufinfo.typecode != 'f') {
204-
mp_raise_ValueError(MP_ERROR_TEXT("expecting float array"));
203+
if (bufinfo.typecode != 'h') {
204+
mp_raise_ValueError(MP_ERROR_TEXT("expecting int16 (h) array"));
205205
}
206206

207-
float *features = bufinfo.buf;
207+
const int16_t *features = bufinfo.buf;
208208
const int n_features = bufinfo.len / sizeof(*features);
209209

210210
#if EMLEARN_MICROPYTHON_DEBUG

src/emltrees/trees.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def load_model(builder, f):
1616
builder.addroot(root)
1717
elif kind == 'n':
1818
feature = int(tok[1])
19-
value = float(tok[2])
19+
value = int(float(tok[2]))
2020
left = int(tok[3])
2121
right = int(tok[4])
2222
builder.addnode(left, right, feature, value)

tests/test_trees.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,17 @@ def test_trees_xor():
3636
emltrees.load_model(model, f)
3737

3838
# run it
39+
s = 32767 # max int16
3940
examples = [
4041
# input, expected output
41-
( [0.0, 0.0], 0 ),
42-
( [1.0, 1.0], 0 ),
43-
( [0.0, 1.0], 1 ),
44-
( [1.0, 0.0], 1 ),
42+
( [0, 0], 0 ),
43+
( [1*s, 1*s], 0 ),
44+
( [0, 1*s], 1 ),
45+
( [1*s, 0], 1 ),
4546
]
4647

4748
for (ex, expect) in examples:
48-
f = array.array('f', ex)
49+
f = array.array('h', ex)
4950
result = model.predict(f)
5051
assert result == expect, (ex, expect, result)
5152

0 commit comments

Comments
 (0)