@@ -34,39 +34,31 @@ void NORETURN abort() {
34
34
}
35
35
#endif
36
36
37
-
38
- static tm_err_t layer_cb (tm_mdl_t * mdl , tml_head_t * lh )
39
- {
40
- #if 0
41
- //dump middle result
42
- int h = lh -> out_dims [1 ];
43
- int w = lh -> out_dims [2 ];
44
- int ch = lh -> out_dims [3 ];
45
- mtype_t * output = TML_GET_OUTPUT (mdl , lh );
46
- return TM_OK ;
47
- TM_PRINTF ("Layer %d callback ========\n" , mdl -> layer_i );
48
- #if 1
49
- for (int y = 0 ; y < h ; y ++ ){
50
- TM_PRINTF ("[" );
51
- for (int x = 0 ; x < w ; x ++ ){
52
- TM_PRINTF ("[" );
53
- for (int c = 0 ; c < ch ; c ++ ){
54
- #if TM_MDL_TYPE == TM_MDL_FP32
55
- TM_PRINTF ("%.3f," , output [(y * w + x )* ch + c ]);
56
- #else
57
- TM_PRINTF ("%.3f," , TML_DEQUANT (lh ,output [(y * w + x )* ch + c ]));
58
- #endif
37
+ // get model output shapes
38
+ //mdl: model handle; in: input mat; out: output mat
39
+ int TM_WEAK tm_get_outputs (tm_mdl_t * mdl , tm_mat_t * out , int out_length )
40
+ {
41
+ // NOTE: based on tm_run, but without actually executing
42
+ int out_idx = 0 ;
43
+ mdl -> layer_body = mdl -> b -> layers_body ;
44
+ for (mdl -> layer_i = 0 ; mdl -> layer_i < mdl -> b -> layer_cnt ; mdl -> layer_i ++ ){
45
+ tml_head_t * h = (tml_head_t * )(mdl -> layer_body );
46
+ if (h -> is_out ) {
47
+ if (out_idx < out_length ) {
48
+ memcpy ((void * )(& out [out_idx ]), (void * )(& (h -> out_dims )), sizeof (uint16_t )* 4 );
49
+ out_idx += 1 ;
50
+ } else {
51
+ return -1 ;
59
52
}
60
- TM_PRINTF ("]," );
61
53
}
62
- TM_PRINTF ( "],\n" );
54
+ mdl -> layer_body += ( h -> size );
63
55
}
64
- TM_PRINTF ("\n" );
65
- #endif
66
- return TM_OK ;
67
- #else
56
+ return out_idx ;
57
+ }
58
+
59
+ static tm_err_t layer_cb (tm_mdl_t * mdl , tml_head_t * lh )
60
+ {
68
61
return TM_OK ;
69
- #endif
70
62
}
71
63
72
64
#define DEBUG (1)
@@ -79,6 +71,7 @@ typedef struct _mp_obj_mod_cnn_t {
79
71
tm_mat_t input ;
80
72
uint8_t * model_buffer ;
81
73
uint8_t * data_buffer ;
74
+ uint16_t out_dims [4 ];
82
75
} mp_obj_mod_cnn_t ;
83
76
84
77
mp_obj_full_type_t mod_cnn_type ;
@@ -121,6 +114,25 @@ static mp_obj_t mod_cnn_new(mp_obj_t model_data_obj) {
121
114
mp_raise_msg (& mp_type_RuntimeError , MP_ERROR_TEXT ("tm_load error" ));
122
115
}
123
116
117
+ // find model output shape
118
+ o -> out_dims [0 ] = 0 ;
119
+ tm_mat_t outs [1 ];
120
+ const int outputs = tm_get_outputs (model , outs , 1 );
121
+ if (outputs != 1 ) {
122
+ mp_raise_msg (& mp_type_RuntimeError , MP_ERROR_TEXT ("only 1 output supported" ));
123
+ }
124
+ memcpy ((void * )(o -> out_dims ), (void * )(& (outs [0 ])), sizeof (uint16_t )* 4 );
125
+
126
+ if ((o -> out_dims [0 ] != 1 )) {
127
+ mp_raise_msg (& mp_type_RuntimeError , MP_ERROR_TEXT ("output must be 1d" ));
128
+ }
129
+ memcpy ((void * )(o -> out_dims ), (void * )(& (outs [0 ])), sizeof (uint16_t )* 4 );
130
+
131
+ #if DEBUG
132
+ mp_printf (& mp_plat_print , "cnn-new-done outs=%d out.dims=(%d,%d,%d,%d) \n" ,
133
+ outputs , o -> out_dims [0 ], o -> out_dims [1 ], o -> out_dims [2 ], o -> out_dims [3 ]);
134
+ #endif
135
+
124
136
return MP_OBJ_FROM_PTR (o );
125
137
}
126
138
static MP_DEFINE_CONST_FUN_OBJ_1 (mod_cnn_new_obj , mod_cnn_new ) ;
@@ -141,15 +153,15 @@ static MP_DEFINE_CONST_FUN_OBJ_1(mod_cnn_del_obj, mod_cnn_del);
141
153
142
154
143
155
// Run the model on the input buffer and write the result into the output buffer
144
- static mp_obj_t mod_cnn_run (mp_obj_t self_obj , mp_obj_t input_obj ) {
156
+ static mp_obj_t mod_cnn_run (mp_obj_t self_obj , mp_obj_t input_obj , mp_obj_t output_obj ) {
145
157
146
158
mp_obj_mod_cnn_t * o = MP_OBJ_TO_PTR (self_obj );
147
159
148
160
// Extract input
149
161
mp_buffer_info_t bufinfo ;
150
162
mp_get_buffer_raise (input_obj , & bufinfo , MP_BUFFER_RW );
151
163
if (bufinfo .typecode != 'B' ) {
152
- mp_raise_ValueError (MP_ERROR_TEXT ("expecting float array" ));
164
+ mp_raise_ValueError (MP_ERROR_TEXT ("expecting byte array" ));
153
165
}
154
166
uint8_t * input_buffer = bufinfo .buf ;
155
167
const int input_length = bufinfo .len / sizeof (* input_buffer );
@@ -160,6 +172,21 @@ static mp_obj_t mod_cnn_run(mp_obj_t self_obj, mp_obj_t input_obj) {
160
172
mp_raise_ValueError (MP_ERROR_TEXT ("wrong input size" ));
161
173
}
162
174
175
+ // Extract output
176
+ mp_get_buffer_raise (output_obj , & bufinfo , MP_BUFFER_RW );
177
+ if (bufinfo .typecode != 'f' ) {
178
+ mp_raise_ValueError (MP_ERROR_TEXT ("expecting float array" ));
179
+ }
180
+ float * output_buffer = bufinfo .buf ;
181
+ const int output_length = bufinfo .len / sizeof (* output_buffer );
182
+
183
+
184
+ // check output buffer size against the model's output shape
185
+ const int expect_out_length = o -> out_dims [1 ]* o -> out_dims [2 ]* o -> out_dims [3 ];
186
+ if (output_length != expect_out_length ) {
187
+ mp_raise_ValueError (MP_ERROR_TEXT ("wrong output size" ));
188
+ }
189
+
163
190
// Preprocess data
164
191
tm_mat_t in_uint8 = o -> input ;
165
192
in_uint8 .data = (mtype_t * )input_buffer ;
@@ -181,27 +208,38 @@ static mp_obj_t mod_cnn_run(mp_obj_t self_obj, mp_obj_t input_obj) {
181
208
mp_raise_ValueError (MP_ERROR_TEXT ("run error" ));
182
209
}
183
210
211
+ // Copy output into the caller-provided buffer
184
212
tm_mat_t out = outs [0 ];
185
- float * data = out .dataf ;
186
- float maxp = 0 ;
187
- int maxi = -1 ;
188
-
189
- // TODO: pass the entire output vector out to Python
190
- // FIXME: unhardcode output handling
191
- for (int i = 0 ; i < 10 ; i ++ ){
192
- //printf("%d: %.3f\n", i, data[i]);
193
- if (data [i ] > maxp ) {
194
- maxi = i ;
195
- maxp = data [i ];
196
- }
213
+ for (int i = 0 ; i < expect_out_length ; i ++ ){
214
+ output_buffer [i ] = out .dataf [i ];
215
+ }
216
+
217
+ return mp_const_none ;
218
+ }
219
+ static MP_DEFINE_CONST_FUN_OBJ_3 (mod_cnn_run_obj , mod_cnn_run ) ;
220
+
221
+
222
+ // Return the shape of the output
223
+ static mp_obj_t mod_cnn_output_dimensions (mp_obj_t self_obj ) {
224
+
225
+ mp_obj_mod_cnn_t * o = MP_OBJ_TO_PTR (self_obj );
226
+ const int dimensions = o -> out_dims [0 ];
227
+ mp_obj_tuple_t * tuple = MP_OBJ_TO_PTR (mp_obj_new_tuple (dimensions , NULL ));
228
+
229
+ // A regular output should have C channels, and 1 for everything else
230
+ // TODO: support other shapes?
231
+ //dims==1, 11c
232
+ if (!(o -> out_dims [0 ] == 1 && o -> out_dims [1 ] == 1 && o -> out_dims [2 ] == 1 )) {
233
+ mp_raise_msg (& mp_type_RuntimeError , MP_ERROR_TEXT ("wrong output shape" ));
197
234
}
198
235
199
- return mp_obj_new_int (maxi );
236
+ tuple -> items [0 ] = mp_obj_new_int (o -> out_dims [3 ]);
237
+ return tuple ;
200
238
}
201
- static MP_DEFINE_CONST_FUN_OBJ_2 ( mod_cnn_run_obj , mod_cnn_run ) ;
239
+ static MP_DEFINE_CONST_FUN_OBJ_1 ( mod_cnn_output_dimensions_obj , mod_cnn_output_dimensions ) ;
202
240
203
241
204
- mp_map_elem_t mod_locals_dict_table [2 ];
242
+ mp_map_elem_t mod_locals_dict_table [3 ];
205
243
static MP_DEFINE_CONST_DICT (mod_locals_dict , mod_locals_dict_table ) ;
206
244
207
245
// This is the entry point and is called when the module is imported
@@ -217,6 +255,7 @@ mp_obj_t mpy_init(mp_obj_fun_bc_t *self, size_t n_args, size_t n_kw, mp_obj_t *a
217
255
// methods
218
256
mod_locals_dict_table [0 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_run ), MP_OBJ_FROM_PTR (& mod_cnn_run_obj ) };
219
257
mod_locals_dict_table [1 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR___del__ ), MP_OBJ_FROM_PTR (& mod_cnn_del_obj ) };
258
+ mod_locals_dict_table [2 ] = (mp_map_elem_t ){ MP_OBJ_NEW_QSTR (MP_QSTR_output_dimensions ), MP_OBJ_FROM_PTR (& mod_cnn_output_dimensions_obj ) };
220
259
221
260
MP_OBJ_TYPE_SET_SLOT (& mod_cnn_type , locals_dict , (void * )& mod_locals_dict , 2 );
222
261
0 commit comments