Skip to content

Commit 07fb735

Browse files
author
YavorGIvanov
committed
Attempt to make image encoder work with ggml-alloc
1 parent 272da96 commit 07fb735

File tree

1 file changed

+82
-82
lines changed

1 file changed

+82
-82
lines changed

sam.cpp

Lines changed: 82 additions & 82 deletions
Original file line number | Diff line number | Diff line change
@@ -236,8 +236,12 @@ struct sam_state {
236236
// buffer for `ggml_graph_plan.work_data`
237237
std::vector<uint8_t> work_buffer;
238238
// buffers to evaluate the model
239-
std::vector<uint8_t> buf_alloc;
240-
std::vector<uint8_t> buf_compute;
239+
std::vector<uint8_t> buf_alloc_img_enc;
240+
std::vector<uint8_t> buf_compute_img_enc;
241+
242+
std::vector<uint8_t> buf_alloc_fast;
243+
std::vector<uint8_t> buf_compute_fast;
244+
241245
struct ggml_allocr * allocr = {};
242246
};
243247

@@ -1142,11 +1146,10 @@ struct ggml_tensor* sam_layer_norm_2d(
11421146
return layer;
11431147
}
11441148

1145-
bool sam_encode_image(
1149+
struct ggml_cgraph * sam_encode_image(
11461150
const sam_model & model,
11471151
sam_state & state,
1148-
const sam_image_f32 & img,
1149-
int n_threads) {
1152+
const sam_image_f32 & img) {
11501153

11511154
const auto & hparams = model.hparams;
11521155
const auto & enc = model.enc_img;
@@ -1159,30 +1162,18 @@ bool sam_encode_image(
11591162
const int32_t n_img_size = hparams.n_img_size();
11601163
const int32_t n_window_size = hparams.n_window_size();
11611164

1162-
const size_t tensor_alignment = 32;
1163-
1164-
static size_t buf_size = 256u*1024*1024;
1165-
static void * buf = malloc(buf_size);
1166-
1167-
// use 2 scratch buffers
1168-
// TODO: very hacky solution - reimplement in a more elegant way
1169-
static size_t scr0_size = 2048u*1024*1024;
1170-
static void * scr0 = malloc(scr0_size);
1171-
1172-
static size_t scr1_size = 512u*1024*1024;
1173-
static void * scr1 = malloc(scr1_size);
1174-
1175-
struct ggml_init_params params = {
1176-
/*.mem_size =*/ buf_size,
1177-
/*.mem_buffer =*/ buf,
1178-
/*.no_alloc =*/ false,
1165+
struct ggml_init_params ggml_params = {
1166+
/*.mem_size =*/ state.buf_compute_img_enc.size(),
1167+
/*.mem_buffer =*/ state.buf_compute_img_enc.data(),
1168+
/*.no_alloc =*/ true, // skip allocating as we use ggml_alloc to allocate exact memory requirements
11791169
};
11801170

1181-
struct ggml_context * ctx0 = ggml_init(params);
1182-
struct ggml_cgraph gf = {};
1171+
struct ggml_context * ctx0 = ggml_init(ggml_params);
1172+
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
11831173

11841174
struct ggml_tensor * inp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_img_size, n_img_size, 3, 1);
1185-
{
1175+
ggml_allocr_alloc(state.allocr, inp);
1176+
if (!ggml_allocr_is_measure(state.allocr)) {
11861177
float * data = (float *) ggml_get_data(inp);
11871178

11881179
const int nx = img.nx;
@@ -1226,8 +1217,6 @@ bool sam_encode_image(
12261217
for (int il = 0; il < n_enc_layer; ++il) {
12271218
const auto & layer = enc.layers[il];
12281219

1229-
ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
1230-
12311220
// norm
12321221
// ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L168
12331222
{
@@ -1253,8 +1242,6 @@ bool sam_encode_image(
12531242
const int64_t W = cur->ne[1];
12541243
const int64_t H = cur->ne[2];
12551244

1256-
ggml_set_scratch(ctx0, { 0, scr1_size, scr1, });
1257-
12581245
// self-attention
12591246
{
12601247
cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
@@ -1290,8 +1277,6 @@ bool sam_encode_image(
12901277
V = ggml_cont (ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); // transposed
12911278
V = ggml_reshape_3d(ctx0, V, W*H, n_enc_head_dim, B*n_enc_head);
12921279

1293-
ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
1294-
12951280
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
12961281

12971282
struct ggml_tensor * KQ_scaled =
@@ -1341,8 +1326,6 @@ bool sam_encode_image(
13411326

13421327
cur = ggml_add(ctx0, inpL, cur);
13431328

1344-
ggml_set_scratch(ctx0, { 0, scr1_size, scr1, });
1345-
13461329
struct ggml_tensor * inpFF = cur;
13471330

13481331
// feed-forward network
@@ -1384,8 +1367,6 @@ bool sam_encode_image(
13841367
inpL = ggml_add(ctx0, cur, inpFF);
13851368
}
13861369

1387-
ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
1388-
13891370
cur = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 2, 0, 1, 3));
13901371

13911372
cur = ggml_conv_2d_sk_p0(ctx0, enc.neck_conv_0, cur);
@@ -1396,21 +1377,16 @@ bool sam_encode_image(
13961377

13971378
cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_1_w, enc.neck_norm_1_b, hparams.eps);
13981379

1399-
// TODO: avoid copy
14001380
cur = ggml_cpy(ctx0, cur, state.embd_img);
14011381

1402-
ggml_set_scratch(ctx0, { 0, 0, nullptr, });
1403-
1404-
// run the computation
1405-
ggml_build_forward_expand(&gf, cur);
1406-
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
1407-
1382+
ggml_build_forward_expand(gf, cur);
14081383
ggml_disconnect_node_from_graph(state.embd_img);
14091384

14101385
//ggml_graph_print(&gf);
14111386

14121387
ggml_free(ctx0);
1413-
return true;
1388+
1389+
return gf;
14141390
}
14151391

14161392

@@ -1619,8 +1595,6 @@ bool sam_decode_mask(
16191595
const auto & dec = model.dec;
16201596
const int n_img_embd = hparams.n_img_embd();
16211597

1622-
// print_t_f32("embd_img", state.embd_img);
1623-
16241598
struct ggml_tensor * tokens = {};
16251599
{
16261600
// Concatenate output tokens
@@ -2024,8 +1998,8 @@ struct ggml_cgraph * sam_build_fast_graph(
20241998
sam_point point) {
20251999

20262000
struct ggml_init_params ggml_params = {
2027-
/*.mem_size =*/ state.buf_compute.size(),
2028-
/*.mem_buffer =*/ state.buf_compute.data(),
2001+
/*.mem_size =*/ state.buf_compute_fast.size(),
2002+
/*.mem_buffer =*/ state.buf_compute_fast.data(),
20292003
/*.no_alloc =*/ true, // skip allocating as we use ggml_alloc to allocate exact memory requirements
20302004
};
20312005

@@ -2174,58 +2148,84 @@ int main(int argc, char ** argv) {
21742148
}
21752149

21762150

2177-
if (!sam_encode_image(model, state, img1, params.n_threads)) {
2178-
fprintf(stderr, "%s: failed to encode image\n", __func__);
2179-
return 1;
2180-
}
2181-
2151+
static const size_t tensor_alignment = 32;
21822152
{
2183-
static const size_t tensor_alignment = 32;
2184-
state.buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
2153+
state.buf_compute_img_enc.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
21852154
state.allocr = ggml_allocr_new_measure(tensor_alignment);
2155+
struct ggml_cgraph * gf_measure = sam_encode_image(model, state, img1);
2156+
if (!gf_measure) {
2157+
fprintf(stderr, "%s: failed to encode image\n", __func__);
2158+
return 1;
2159+
}
21862160

2187-
// TODO: user input
2188-
const sam_point pt = { 414.375f, 162.796875f, };
2189-
{
2190-
// measure memory requirements for the graph
2191-
struct ggml_cgraph * gf = sam_build_fast_graph(model, state, img0.nx, img0.ny, pt);
2192-
if (!gf) {
2193-
fprintf(stderr, "%s: failed to build fast graph to measure\n", __func__);
2194-
return 1;
2195-
}
2161+
size_t alloc_size = ggml_allocr_alloc_graph(state.allocr, gf_measure) + tensor_alignment;
2162+
ggml_allocr_free(state.allocr);
21962163

2197-
size_t alloc_size = ggml_allocr_alloc_graph(state.allocr, gf) + tensor_alignment;
2198-
ggml_allocr_free(state.allocr);
2164+
// recreate allocator with exact memory requirements
2165+
state.buf_alloc_img_enc.resize(alloc_size);
2166+
state.allocr = ggml_allocr_new(state.buf_alloc_img_enc.data(), state.buf_alloc_img_enc.size(), tensor_alignment);
21992167

2200-
// recreate allocator with exact memory requirements
2201-
state.buf_alloc.resize(alloc_size);
2202-
state.allocr = ggml_allocr_new(state.buf_alloc.data(), state.buf_alloc.size(), tensor_alignment);
2168+
// compute the graph with the measured exact memory requirements from above
2169+
ggml_allocr_reset(state.allocr);
2170+
2171+
struct ggml_cgraph * gf = sam_encode_image(model, state, img1);
2172+
if (!gf) {
2173+
fprintf(stderr, "%s: failed to encode image\n", __func__);
2174+
return 1;
22032175
}
22042176

2205-
{
2206-
// compute the graph with the measured exact memory requirements from above
2207-
ggml_allocr_reset(state.allocr);
2177+
ggml_allocr_alloc_graph(state.allocr, gf);
22082178

2209-
struct ggml_cgraph * gf = sam_build_fast_graph(model, state, img0.nx, img0.ny, pt);
2210-
if (!gf) {
2211-
fprintf(stderr, "%s: failed to build fast graph\n", __func__);
2212-
return 1;
2213-
}
2179+
ggml_graph_compute_helper(state.work_buffer, gf, params.n_threads);
22142180

2215-
ggml_allocr_alloc_graph(state.allocr, gf);
2181+
print_t_f32("embd_img", state.embd_img);
22162182

2217-
ggml_graph_compute_helper(state.work_buffer, gf, params.n_threads);
2183+
ggml_allocr_free(state.allocr);
2184+
state.allocr = NULL;
2185+
state.work_buffer.clear();
2186+
}
2187+
{
2188+
state.buf_compute_fast.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
2189+
state.allocr = ggml_allocr_new_measure(tensor_alignment);
22182190

2219-
//print_t_f32("iou_predictions", state.iou_predictions);
2220-
//print_t_f32("low_res_masks", state.low_res_masks);
2191+
// TODO: user input
2192+
const sam_point pt = { 414.375f, 162.796875f, };
2193+
// measure memory requirements for the graph
2194+
struct ggml_cgraph * gf_measure = sam_build_fast_graph(model, state, img0.nx, img0.ny, pt);
2195+
if (!gf_measure) {
2196+
fprintf(stderr, "%s: failed to build fast graph to measure\n", __func__);
2197+
return 1;
22212198
}
22222199

2223-
if (!sam_write_masks(model.hparams, img0.nx, img0.ny, state)) {
2224-
fprintf(stderr, "%s: failed to write masks\n", __func__);
2200+
size_t alloc_size = ggml_allocr_alloc_graph(state.allocr, gf_measure) + tensor_alignment;
2201+
ggml_allocr_free(state.allocr);
2202+
2203+
// recreate allocator with exact memory requirements
2204+
state.buf_alloc_fast.resize(alloc_size);
2205+
state.allocr = ggml_allocr_new(state.buf_alloc_fast.data(), state.buf_alloc_fast.size(), tensor_alignment);
2206+
2207+
// compute the graph with the measured exact memory requirements from above
2208+
ggml_allocr_reset(state.allocr);
2209+
2210+
struct ggml_cgraph * gf = sam_build_fast_graph(model, state, img0.nx, img0.ny, pt);
2211+
if (!gf) {
2212+
fprintf(stderr, "%s: failed to build fast graph\n", __func__);
22252213
return 1;
22262214
}
22272215

2216+
ggml_allocr_alloc_graph(state.allocr, gf);
2217+
2218+
ggml_graph_compute_helper(state.work_buffer, gf, params.n_threads);
2219+
2220+
//print_t_f32("iou_predictions", state.iou_predictions);
2221+
//print_t_f32("low_res_masks", state.low_res_masks);
22282222
ggml_allocr_free(state.allocr);
2223+
state.allocr = NULL;
2224+
}
2225+
2226+
if (!sam_write_masks(model.hparams, img0.nx, img0.ny, state)) {
2227+
fprintf(stderr, "%s: failed to write masks\n", __func__);
2228+
return 1;
22292229
}
22302230

22312231
// report timing

0 commit comments

Comments
 (0)