Skip to content

Commit b48370a

Browse files
committed
cnn: Verify output dimension from tm_run
1 parent ef7ddd0 commit b48370a

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

src/tinymaix_cnn/mod_cnn.c

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77
#include <tinymaix.h>
88

99
#include "tm_layers.c"
10+
//#include "tm_layers_O1.c"
1011
#include "tm_model.c"
1112
//#include "tm_stat.c"
1213

1314
#include <string.h>
1415

16+
#define DEBUG (1)
17+
1518

1619
// memset is used by some standard C constructs
1720
#if !defined(__linux__)
@@ -58,10 +61,13 @@ int TM_WEAK tm_get_outputs(tm_mdl_t* mdl, tm_mat_t* out, int out_length)
5861

5962
static tm_err_t layer_cb(tm_mdl_t* mdl, tml_head_t* lh)
6063
{
64+
#if DEBUG
65+
mp_printf(&mp_plat_print, "cnn-layer-cb type=%d \n", lh->type);
66+
#endif
67+
6168
return TM_OK;
6269
}
6370

64-
#define DEBUG (1)
6571

6672
// MicroPython type
6773
typedef struct _mp_obj_mod_cnn_t {
@@ -209,12 +215,22 @@ static mp_obj_t mod_cnn_run(mp_obj_t self_obj, mp_obj_t input_obj, mp_obj_t outp
209215
mp_raise_ValueError(MP_ERROR_TEXT("run error"));
210216
}
211217

212-
// Copy output into
218+
// Copy output
213219
tm_mat_t out = outs[0];
220+
221+
#if DEBUG
222+
mp_printf(&mp_plat_print, "cnn-run out.dims=(%d,%d,%d,%d) out.length=%d expect_length=%d \n",
223+
out.dims, out.h, out.w, out.c, expect_out_length
224+
);
225+
#endif
226+
227+
if (!((out.dims == 1) && (out.h == 1) && (out.w == 1) && out.c == expect_out_length)) {
228+
mp_raise_msg(&mp_type_RuntimeError, MP_ERROR_TEXT("unexpected output dims"));
229+
}
230+
214231
for(int i=0; i<expect_out_length; i++){
215232
output_buffer[i] = out.dataf[i];
216233
}
217-
218234
return mp_const_none;
219235
}
220236
static MP_DEFINE_CONST_FUN_OBJ_3(mod_cnn_run_obj, mod_cnn_run);

0 commit comments

Comments
 (0)