@@ -197,8 +197,25 @@ static mp_obj_t builder_addleaf(mp_obj_t self_obj, mp_obj_t leaf_obj) {
197
197
static MP_DEFINE_CONST_FUN_OBJ_2 (builder_addleaf_obj , builder_addleaf ) ;
198
198
199
199
200
+ // Return the shape of the output
201
+ static mp_obj_t builder_get_outputs (mp_obj_t self_obj ) {
202
+
203
+ mp_obj_trees_builder_t * o = MP_OBJ_TO_PTR (self_obj );
204
+ EmlTreesBuilder * self = & o -> builder ;
205
+
206
+ const int n_classes = self -> trees .n_classes ;
207
+ if (n_classes == 0 ) {
208
+ mp_raise_ValueError (MP_ERROR_TEXT ("model not loaded" ));
209
+ }
210
+
211
+ return mp_obj_new_int (n_classes );
212
+ }
213
+ static MP_DEFINE_CONST_FUN_OBJ_1 (builder_get_outputs_obj , builder_get_outputs ) ;
214
+
215
+
216
+
200
217
// Takes a array of input data
201
- static mp_obj_t builder_predict (mp_obj_t self_obj , mp_obj_t features_obj ) {
218
+ static mp_obj_t builder_predict (mp_obj_t self_obj , mp_obj_t features_obj , mp_obj_t output_obj ) {
202
219
203
220
mp_obj_trees_builder_t * o = MP_OBJ_TO_PTR (self_obj );
204
221
EmlTreesBuilder * self = & o -> builder ;
@@ -212,28 +229,45 @@ static mp_obj_t builder_predict(mp_obj_t self_obj, mp_obj_t features_obj) {
212
229
213
230
const int16_t * features = bufinfo .buf ;
214
231
const int n_features = bufinfo .len / sizeof (* features );
232
+ const int n_outputs = self -> trees .n_classes ;
215
233
216
234
#if EMLEARN_MICROPYTHON_DEBUG
217
235
mp_printf (& mp_plat_print ,
218
- "emltrees-predict n_features=%d n_classes=%d leaves=%d nodes=%d trees=%d length=%d \n" ,
236
+ "emltrees-predict n_features=%d n_classes=%d leaves=%d nodes=%d trees=%d length=%d outputs=%d \n" ,
219
237
self -> trees .n_features , self -> trees .n_classes ,
220
238
self -> trees .n_leaves , self -> trees .n_nodes , self -> trees .n_trees ,
221
239
n_features
222
240
);
223
241
#endif
224
242
243
+ if (n_features == 0 || n_outputs == 0 ) {
244
+ mp_raise_ValueError (MP_ERROR_TEXT ("model not loaded" ));
245
+ }
246
+
247
+ // Extract output
248
+ mp_get_buffer_raise (output_obj , & bufinfo , MP_BUFFER_RW );
249
+ if (bufinfo .typecode != 'f' ) {
250
+ mp_raise_ValueError (MP_ERROR_TEXT ("expecting float output array" ));
251
+ }
252
+ float * output_buffer = bufinfo .buf ;
253
+ const int output_length = bufinfo .len / sizeof (* output_buffer );
254
+
255
+
225
256
// call model
226
- const int result = eml_trees_predict (& self -> trees , features , n_features );
227
- if (result < 0 ) {
228
- mp_raise_msg (& mp_type_RuntimeError , MP_ERROR_TEXT ("eml_trees_predict error" ));
257
+ // NOTE: also handles checking of input and output lengths
258
+ const EmlError err = \
259
+ eml_trees_predict_proba (& self -> trees , features , n_features , output_buffer , output_length );
260
+
261
+ if (err != EmlOk ) {
262
+ mp_raise_msg (& mp_type_RuntimeError , MP_ERROR_TEXT ("eml_trees_predict_proba error" ));
229
263
}
230
264
231
- return mp_obj_new_int ( result ) ;
265
+ return mp_const_none ;
232
266
}
233
- static MP_DEFINE_CONST_FUN_OBJ_2 (builder_predict_obj , builder_predict ) ;
267
+ static MP_DEFINE_CONST_FUN_OBJ_3 (builder_predict_obj , builder_predict ) ;
234
268
235
269
236
- mp_map_elem_t trees_locals_dict_table [6 ];
270
+ mp_map_elem_t trees_locals_dict_table [7 ];
237
271
static MP_DEFINE_CONST_DICT (trees_locals_dict , trees_locals_dict_table ) ;
238
272
239
273
// This is the entry point and is called when the module is imported
@@ -253,8 +287,9 @@ mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *a
253
287
trees_locals_dict_table [3 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_addleaf ), MP_OBJ_FROM_PTR (& builder_addleaf_obj ) };
254
288
trees_locals_dict_table [4 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR___del__ ), MP_OBJ_FROM_PTR (& builder_del_obj ) };
255
289
trees_locals_dict_table [5 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_setdata ), MP_OBJ_FROM_PTR (& builder_setdata_obj ) };
290
+ trees_locals_dict_table [6 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_outputs ), MP_OBJ_FROM_PTR (& builder_get_outputs_obj ) };
256
291
257
- MP_OBJ_TYPE_SET_SLOT (& trees_builder_type , locals_dict , (void * )& trees_locals_dict , 6 );
292
+ MP_OBJ_TYPE_SET_SLOT (& trees_builder_type , locals_dict , (void * )& trees_locals_dict , 7 );
258
293
259
294
// This must be last, it restores the globals dict
260
295
MP_DYNRUNTIME_INIT_EXIT
0 commit comments