@@ -236,8 +236,12 @@ struct sam_state {
236236 // buffer for `ggml_graph_plan.work_data`
237237 std::vector<uint8_t > work_buffer;
238238 // buffers to evaluate the model
239- std::vector<uint8_t > buf_alloc;
240- std::vector<uint8_t > buf_compute;
239+ std::vector<uint8_t > buf_alloc_img_enc;
240+ std::vector<uint8_t > buf_compute_img_enc;
241+
242+ std::vector<uint8_t > buf_alloc_fast;
243+ std::vector<uint8_t > buf_compute_fast;
244+
241245 struct ggml_allocr * allocr = {};
242246};
243247
@@ -1142,11 +1146,10 @@ struct ggml_tensor* sam_layer_norm_2d(
11421146 return layer;
11431147}
11441148
1145- bool sam_encode_image (
1149+ struct ggml_cgraph * sam_encode_image (
11461150 const sam_model & model,
11471151 sam_state & state,
1148- const sam_image_f32 & img,
1149- int n_threads) {
1152+ const sam_image_f32 & img) {
11501153
11511154 const auto & hparams = model.hparams ;
11521155 const auto & enc = model.enc_img ;
@@ -1159,30 +1162,18 @@ bool sam_encode_image(
11591162 const int32_t n_img_size = hparams.n_img_size ();
11601163 const int32_t n_window_size = hparams.n_window_size ();
11611164
1162- const size_t tensor_alignment = 32 ;
1163-
1164- static size_t buf_size = 256u *1024 *1024 ;
1165- static void * buf = malloc (buf_size);
1166-
1167- // use 2 scratch buffers
1168- // TODO: very hacky solution - reimplement in a more elegant way
1169- static size_t scr0_size = 2048u *1024 *1024 ;
1170- static void * scr0 = malloc (scr0_size);
1171-
1172- static size_t scr1_size = 512u *1024 *1024 ;
1173- static void * scr1 = malloc (scr1_size);
1174-
1175- struct ggml_init_params params = {
1176- /* .mem_size =*/ buf_size,
1177- /* .mem_buffer =*/ buf,
1178- /* .no_alloc =*/ false ,
1165+ struct ggml_init_params ggml_params = {
1166+ /* .mem_size =*/ state.buf_compute_img_enc .size (),
1167+ /* .mem_buffer =*/ state.buf_compute_img_enc .data (),
1168+ /* .no_alloc =*/ true , // skip allocating as we use ggml_alloc to allocate exact memory requirements
11791169 };
11801170
1181- struct ggml_context * ctx0 = ggml_init (params );
1182- struct ggml_cgraph gf = {} ;
1171+ struct ggml_context * ctx0 = ggml_init (ggml_params );
1172+ struct ggml_cgraph * gf = ggml_new_graph (ctx0) ;
11831173
11841174 struct ggml_tensor * inp = ggml_new_tensor_4d (ctx0, GGML_TYPE_F32, n_img_size, n_img_size, 3 , 1 );
1185- {
1175+ ggml_allocr_alloc (state.allocr , inp);
1176+ if (!ggml_allocr_is_measure (state.allocr )) {
11861177 float * data = (float *) ggml_get_data (inp);
11871178
11881179 const int nx = img.nx ;
@@ -1226,8 +1217,6 @@ bool sam_encode_image(
12261217 for (int il = 0 ; il < n_enc_layer; ++il) {
12271218 const auto & layer = enc.layers [il];
12281219
1229- ggml_set_scratch (ctx0, { 0 , scr0_size, scr0, });
1230-
12311220 // norm
12321221 // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L168
12331222 {
@@ -1253,8 +1242,6 @@ bool sam_encode_image(
12531242 const int64_t W = cur->ne [1 ];
12541243 const int64_t H = cur->ne [2 ];
12551244
1256- ggml_set_scratch (ctx0, { 0 , scr1_size, scr1, });
1257-
12581245 // self-attention
12591246 {
12601247 cur = ggml_mul_mat (ctx0, layer.qkv_w , cur);
@@ -1290,8 +1277,6 @@ bool sam_encode_image(
12901277 V = ggml_cont (ctx0, ggml_permute (ctx0, V, 1 , 2 , 0 , 3 )); // transposed
12911278 V = ggml_reshape_3d (ctx0, V, W*H, n_enc_head_dim, B*n_enc_head);
12921279
1293- ggml_set_scratch (ctx0, { 0 , scr0_size, scr0, });
1294-
12951280 struct ggml_tensor * KQ = ggml_mul_mat (ctx0, K, Q);
12961281
12971282 struct ggml_tensor * KQ_scaled =
@@ -1341,8 +1326,6 @@ bool sam_encode_image(
13411326
13421327 cur = ggml_add (ctx0, inpL, cur);
13431328
1344- ggml_set_scratch (ctx0, { 0 , scr1_size, scr1, });
1345-
13461329 struct ggml_tensor * inpFF = cur;
13471330
13481331 // feed-forward network
@@ -1384,8 +1367,6 @@ bool sam_encode_image(
13841367 inpL = ggml_add (ctx0, cur, inpFF);
13851368 }
13861369
1387- ggml_set_scratch (ctx0, { 0 , scr0_size, scr0, });
1388-
13891370 cur = ggml_cont (ctx0, ggml_permute (ctx0, inpL, 2 , 0 , 1 , 3 ));
13901371
13911372 cur = ggml_conv_2d_sk_p0 (ctx0, enc.neck_conv_0 , cur);
@@ -1396,21 +1377,16 @@ bool sam_encode_image(
13961377
13971378 cur = sam_layer_norm_2d (ctx0, cur, n_enc_out_chans, enc.neck_norm_1_w , enc.neck_norm_1_b , hparams.eps );
13981379
1399- // TODO: avoid copy
14001380 cur = ggml_cpy (ctx0, cur, state.embd_img );
14011381
1402- ggml_set_scratch (ctx0, { 0 , 0 , nullptr , });
1403-
1404- // run the computation
1405- ggml_build_forward_expand (&gf, cur);
1406- ggml_graph_compute_with_ctx (ctx0, &gf, n_threads);
1407-
1382+ ggml_build_forward_expand (gf, cur);
14081383 ggml_disconnect_node_from_graph (state.embd_img );
14091384
14101385 // ggml_graph_print(&gf);
14111386
14121387 ggml_free (ctx0);
1413- return true ;
1388+
1389+ return gf;
14141390}
14151391
14161392
@@ -1619,8 +1595,6 @@ bool sam_decode_mask(
16191595 const auto & dec = model.dec ;
16201596 const int n_img_embd = hparams.n_img_embd ();
16211597
1622- // print_t_f32("embd_img", state.embd_img);
1623-
16241598 struct ggml_tensor * tokens = {};
16251599 {
16261600 // Concatenate output tokens
@@ -2024,8 +1998,8 @@ struct ggml_cgraph * sam_build_fast_graph(
20241998 sam_point point) {
20251999
20262000 struct ggml_init_params ggml_params = {
2027- /* .mem_size =*/ state.buf_compute .size (),
2028- /* .mem_buffer =*/ state.buf_compute .data (),
2001+ /* .mem_size =*/ state.buf_compute_fast .size (),
2002+ /* .mem_buffer =*/ state.buf_compute_fast .data (),
20292003 /* .no_alloc =*/ true , // skip allocating as we use ggml_alloc to allocate exact memory requirements
20302004 };
20312005
@@ -2174,58 +2148,84 @@ int main(int argc, char ** argv) {
21742148 }
21752149
21762150
2177- if (!sam_encode_image (model, state, img1, params.n_threads )) {
2178- fprintf (stderr, " %s: failed to encode image\n " , __func__);
2179- return 1 ;
2180- }
2181-
2151+ static const size_t tensor_alignment = 32 ;
21822152 {
2183- static const size_t tensor_alignment = 32 ;
2184- state.buf_compute .resize (ggml_tensor_overhead ()*GGML_MAX_NODES + ggml_graph_overhead ());
2153+ state.buf_compute_img_enc .resize (ggml_tensor_overhead ()*GGML_MAX_NODES + ggml_graph_overhead ());
21852154 state.allocr = ggml_allocr_new_measure (tensor_alignment);
2155+ struct ggml_cgraph * gf_measure = sam_encode_image (model, state, img1);
2156+ if (!gf_measure) {
2157+ fprintf (stderr, " %s: failed to encode image\n " , __func__);
2158+ return 1 ;
2159+ }
21862160
2187- // TODO: user input
2188- const sam_point pt = { 414 .375f , 162 .796875f , };
2189- {
2190- // measure memory requirements for the graph
2191- struct ggml_cgraph * gf = sam_build_fast_graph (model, state, img0.nx , img0.ny , pt);
2192- if (!gf) {
2193- fprintf (stderr, " %s: failed to build fast graph to measure\n " , __func__);
2194- return 1 ;
2195- }
2161+ size_t alloc_size = ggml_allocr_alloc_graph (state.allocr , gf_measure) + tensor_alignment;
2162+ ggml_allocr_free (state.allocr );
21962163
2197- size_t alloc_size = ggml_allocr_alloc_graph (state.allocr , gf) + tensor_alignment;
2198- ggml_allocr_free (state.allocr );
2164+ // recreate allocator with exact memory requirements
2165+ state.buf_alloc_img_enc .resize (alloc_size);
2166+ state.allocr = ggml_allocr_new (state.buf_alloc_img_enc .data (), state.buf_alloc_img_enc .size (), tensor_alignment);
21992167
2200- // recreate allocator with exact memory requirements
2201- state.buf_alloc .resize (alloc_size);
2202- state.allocr = ggml_allocr_new (state.buf_alloc .data (), state.buf_alloc .size (), tensor_alignment);
2168+ // compute the graph with the measured exact memory requirements from above
2169+ ggml_allocr_reset (state.allocr );
2170+
2171+ struct ggml_cgraph * gf = sam_encode_image (model, state, img1);
2172+ if (!gf) {
2173+ fprintf (stderr, " %s: failed to encode image\n " , __func__);
2174+ return 1 ;
22032175 }
22042176
2205- {
2206- // compute the graph with the measured exact memory requirements from above
2207- ggml_allocr_reset (state.allocr );
2177+ ggml_allocr_alloc_graph (state.allocr , gf);
22082178
2209- struct ggml_cgraph * gf = sam_build_fast_graph (model, state, img0.nx , img0.ny , pt);
2210- if (!gf) {
2211- fprintf (stderr, " %s: failed to build fast graph\n " , __func__);
2212- return 1 ;
2213- }
2179+ ggml_graph_compute_helper (state.work_buffer , gf, params.n_threads );
22142180
2215- ggml_allocr_alloc_graph (state. allocr , gf );
2181+ print_t_f32 ( " embd_img " , state. embd_img );
22162182
2217- ggml_graph_compute_helper (state.work_buffer , gf, params.n_threads );
2183+ ggml_allocr_free (state.allocr );
2184+ state.allocr = NULL ;
2185+ state.work_buffer .clear ();
2186+ }
2187+ {
2188+ state.buf_compute_fast .resize (ggml_tensor_overhead ()*GGML_MAX_NODES + ggml_graph_overhead ());
2189+ state.allocr = ggml_allocr_new_measure (tensor_alignment);
22182190
2219- // print_t_f32("iou_predictions", state.iou_predictions);
2220- // print_t_f32("low_res_masks", state.low_res_masks);
2191+ // TODO: user input
2192+ const sam_point pt = { 414 .375f , 162 .796875f , };
2193+ // measure memory requirements for the graph
2194+ struct ggml_cgraph * gf_measure = sam_build_fast_graph (model, state, img0.nx , img0.ny , pt);
2195+ if (!gf_measure) {
2196+ fprintf (stderr, " %s: failed to build fast graph to measure\n " , __func__);
2197+ return 1 ;
22212198 }
22222199
2223- if (!sam_write_masks (model.hparams , img0.nx , img0.ny , state)) {
2224- fprintf (stderr, " %s: failed to write masks\n " , __func__);
2200+ size_t alloc_size = ggml_allocr_alloc_graph (state.allocr , gf_measure) + tensor_alignment;
2201+ ggml_allocr_free (state.allocr );
2202+
2203+ // recreate allocator with exact memory requirements
2204+ state.buf_alloc_fast .resize (alloc_size);
2205+ state.allocr = ggml_allocr_new (state.buf_alloc_fast .data (), state.buf_alloc_fast .size (), tensor_alignment);
2206+
2207+ // compute the graph with the measured exact memory requirements from above
2208+ ggml_allocr_reset (state.allocr );
2209+
2210+ struct ggml_cgraph * gf = sam_build_fast_graph (model, state, img0.nx , img0.ny , pt);
2211+ if (!gf) {
2212+ fprintf (stderr, " %s: failed to build fast graph\n " , __func__);
22252213 return 1 ;
22262214 }
22272215
2216+ ggml_allocr_alloc_graph (state.allocr , gf);
2217+
2218+ ggml_graph_compute_helper (state.work_buffer , gf, params.n_threads );
2219+
2220+ // print_t_f32("iou_predictions", state.iou_predictions);
2221+ // print_t_f32("low_res_masks", state.low_res_masks);
22282222 ggml_allocr_free (state.allocr );
2223+ state.allocr = NULL ;
2224+ }
2225+
2226+ if (!sam_write_masks (model.hparams , img0.nx , img0.ny , state)) {
2227+ fprintf (stderr, " %s: failed to write masks\n " , __func__);
2228+ return 1 ;
22292229 }
22302230
22312231 // report timing