From 95b2949251b0d86e9245301e0267acc872b0e63c Mon Sep 17 00:00:00 2001 From: leejet Date: Fri, 31 Oct 2025 00:29:48 +0800 Subject: [PATCH 1/4] introduce GGMLRunnerContext --- clip.hpp | 91 ++++++------ common.hpp | 107 +++++++------- conditioner.hpp | 4 +- control.hpp | 41 +++--- esrgan.hpp | 36 ++--- flux.hpp | 346 ++++++++++++++++++++++----------------------- ggml_extend.hpp | 77 +++++----- ltxv.hpp | 2 +- mmdit.hpp | 206 +++++++++++++-------------- pmid.hpp | 124 ++++++++-------- qwen_image.hpp | 117 ++++++++------- qwenvl.hpp | 140 +++++++++--------- rope.hpp | 9 +- t5.hpp | 67 ++++----- tae.hpp | 31 ++-- unet.hpp | 100 +++++++------ vae.hpp | 84 +++++------ wan.hpp | 369 ++++++++++++++++++++++++------------------------ 18 files changed, 961 insertions(+), 990 deletions(-) diff --git a/clip.hpp b/clip.hpp index c2f195a63..dc891c77d 100644 --- a/clip.hpp +++ b/clip.hpp @@ -451,16 +451,16 @@ struct CLIPMLP : public GGMLBlock { } } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { // x: [N, n_token, d_model] auto fc1 = std::dynamic_pointer_cast(blocks["fc1"]); auto fc2 = std::dynamic_pointer_cast(blocks["fc2"]); x = fc1->forward(ctx, x); if (use_gelu) { - x = ggml_gelu_inplace(ctx, x); + x = ggml_gelu_inplace(ctx->ggml_ctx, x); } else { - x = ggml_gelu_quick_inplace(ctx, x); + x = ggml_gelu_quick_inplace(ctx->ggml_ctx, x); } x = fc2->forward(ctx, x); return x; @@ -488,15 +488,15 @@ struct CLIPLayer : public GGMLBlock { blocks["mlp"] = std::shared_ptr(new CLIPMLP(d_model, intermediate_size)); } - struct ggml_tensor* forward(struct ggml_context* ctx, ggml_backend_t backend, struct ggml_tensor* x, bool mask = true) { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, bool mask = true) { // x: [N, n_token, d_model] auto self_attn = std::dynamic_pointer_cast(blocks["self_attn"]); auto layer_norm1 = std::dynamic_pointer_cast(blocks["layer_norm1"]); auto layer_norm2 = std::dynamic_pointer_cast(blocks["layer_norm2"]); auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); - x = ggml_add(ctx, x, self_attn->forward(ctx, backend, layer_norm1->forward(ctx, x), mask)); - x = ggml_add(ctx, x, mlp->forward(ctx, layer_norm2->forward(ctx, x))); + x = ggml_add(ctx->ggml_ctx, x, self_attn->forward(ctx, layer_norm1->forward(ctx, x), mask)); + x = ggml_add(ctx->ggml_ctx, x, mlp->forward(ctx, layer_norm2->forward(ctx, x))); return x; } }; @@ -517,8 +517,7 @@ struct CLIPEncoder : public GGMLBlock { } } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, int clip_skip = -1, bool mask = true) { @@ -536,7 +535,7 @@ struct CLIPEncoder : public GGMLBlock { } std::string name = "layers." 
+ std::to_string(i); auto layer = std::dynamic_pointer_cast(blocks[name]); - x = layer->forward(ctx, backend, x, mask); // [N, n_token, d_model] + x = layer->forward(ctx, x, mask); // [N, n_token, d_model] // LOG_DEBUG("layer %d", i); } return x; @@ -578,7 +577,7 @@ class CLIPEmbeddings : public GGMLBlock { return params["token_embedding.weight"]; } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* input_ids, struct ggml_tensor* custom_embed_weight) { // input_ids: [N, n_token] @@ -586,12 +585,12 @@ class CLIPEmbeddings : public GGMLBlock { auto position_embed_weight = params["position_embedding.weight"]; GGML_ASSERT(input_ids->ne[0] == position_embed_weight->ne[1]); - input_ids = ggml_reshape_3d(ctx, input_ids, input_ids->ne[0], 1, input_ids->ne[1]); - auto token_embedding = ggml_get_rows(ctx, custom_embed_weight != nullptr ? custom_embed_weight : token_embed_weight, input_ids); - token_embedding = ggml_reshape_3d(ctx, token_embedding, token_embedding->ne[0], token_embedding->ne[1], token_embedding->ne[3]); + input_ids = ggml_reshape_3d(ctx->ggml_ctx, input_ids, input_ids->ne[0], 1, input_ids->ne[1]); + auto token_embedding = ggml_get_rows(ctx->ggml_ctx, custom_embed_weight != nullptr ? custom_embed_weight : token_embed_weight, input_ids); + token_embedding = ggml_reshape_3d(ctx->ggml_ctx, token_embedding, token_embedding->ne[0], token_embedding->ne[1], token_embedding->ne[3]); // token_embedding + position_embedding - auto x = ggml_add(ctx, + auto x = ggml_add(ctx->ggml_ctx, token_embedding, position_embed_weight); // [N, n_token, embed_dim] return x; @@ -629,7 +628,7 @@ class CLIPVisionEmbeddings : public GGMLBlock { num_positions = num_patches + 1; } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values) { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* pixel_values) { // pixel_values: [N, num_channels, image_size, image_size] // return: [N, num_positions, embed_dim] GGML_ASSERT(pixel_values->ne[0] == image_size && pixel_values->ne[1] == image_size && pixel_values->ne[2] == num_channels); @@ -641,18 +640,18 @@ class CLIPVisionEmbeddings : public GGMLBlock { // concat(patch_embedding, class_embedding) + position_embedding struct ggml_tensor* patch_embedding; int64_t N = pixel_values->ne[3]; - patch_embedding = ggml_ext_conv_2d(ctx, pixel_values, patch_embed_weight, nullptr, patch_size, patch_size); // [N, embed_dim, image_size // pacht_size, image_size // pacht_size] - patch_embedding = ggml_reshape_3d(ctx, patch_embedding, num_patches, embed_dim, N); // [N, embed_dim, num_patches] - patch_embedding = ggml_cont(ctx, ggml_permute(ctx, patch_embedding, 1, 0, 2, 3)); // [N, num_patches, embed_dim] - patch_embedding = ggml_reshape_4d(ctx, patch_embedding, 1, embed_dim, num_patches, N); // [N, num_patches, embed_dim, 1] - - struct ggml_tensor* class_embedding = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, embed_dim, N); - class_embedding = ggml_repeat(ctx, class_embed_weight, class_embedding); // [N, embed_dim] - class_embedding = ggml_reshape_4d(ctx, class_embedding, 1, embed_dim, 1, N); // [N, 1, embed_dim, 1] - - struct ggml_tensor* x = ggml_concat(ctx, class_embedding, patch_embedding, 2); // [N, num_positions, embed_dim, 1] - x = ggml_reshape_3d(ctx, x, embed_dim, num_positions, N); // [N, num_positions, embed_dim] - x = ggml_add(ctx, x, position_embed_weight); + patch_embedding = ggml_ext_conv_2d(ctx->ggml_ctx, pixel_values, patch_embed_weight, nullptr, 
patch_size, patch_size); // [N, embed_dim, image_size // pacht_size, image_size // pacht_size] + patch_embedding = ggml_reshape_3d(ctx->ggml_ctx, patch_embedding, num_patches, embed_dim, N); // [N, embed_dim, num_patches] + patch_embedding = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, patch_embedding, 1, 0, 2, 3)); // [N, num_patches, embed_dim] + patch_embedding = ggml_reshape_4d(ctx->ggml_ctx, patch_embedding, 1, embed_dim, num_patches, N); // [N, num_patches, embed_dim, 1] + + struct ggml_tensor* class_embedding = ggml_new_tensor_2d(ctx->ggml_ctx, GGML_TYPE_F32, embed_dim, N); + class_embedding = ggml_repeat(ctx->ggml_ctx, class_embed_weight, class_embedding); // [N, embed_dim] + class_embedding = ggml_reshape_4d(ctx->ggml_ctx, class_embedding, 1, embed_dim, 1, N); // [N, 1, embed_dim, 1] + + struct ggml_tensor* x = ggml_concat(ctx->ggml_ctx, class_embedding, patch_embedding, 2); // [N, num_positions, embed_dim, 1] + x = ggml_reshape_3d(ctx->ggml_ctx, x, embed_dim, num_positions, N); // [N, num_positions, embed_dim] + x = ggml_add(ctx->ggml_ctx, x, position_embed_weight); return x; // [N, num_positions, embed_dim] } }; @@ -714,8 +713,7 @@ class CLIPTextModel : public GGMLBlock { return embeddings->get_token_embed_weight(); } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* input_ids, struct ggml_tensor* tkn_embeddings, size_t max_token_idx = 0, @@ -727,16 +725,16 @@ class CLIPTextModel : public GGMLBlock { auto final_layer_norm = std::dynamic_pointer_cast(blocks["final_layer_norm"]); auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size] - x = encoder->forward(ctx, backend, x, return_pooled ? -1 : clip_skip, true); + x = encoder->forward(ctx, x, return_pooled ? 
-1 : clip_skip, true); if (return_pooled || with_final_ln) { x = final_layer_norm->forward(ctx, x); } if (return_pooled) { auto text_projection = params["text_projection"]; - ggml_tensor* pooled = ggml_view_1d(ctx, x, hidden_size, x->nb[1] * max_token_idx); + ggml_tensor* pooled = ggml_view_1d(ctx->ggml_ctx, x, hidden_size, x->nb[1] * max_token_idx); if (text_projection != nullptr) { - pooled = ggml_ext_linear(ctx, pooled, text_projection, nullptr); + pooled = ggml_ext_linear(ctx->ggml_ctx, pooled, text_projection, nullptr); } else { LOG_DEBUG("identity projection"); } @@ -779,8 +777,7 @@ class CLIPVisionModel : public GGMLBlock { blocks["post_layernorm"] = std::shared_ptr(new LayerNorm(hidden_size)); } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* pixel_values, bool return_pooled = true, int clip_skip = -1) { @@ -792,14 +789,14 @@ class CLIPVisionModel : public GGMLBlock { auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim] x = pre_layernorm->forward(ctx, x); - x = encoder->forward(ctx, backend, x, clip_skip, false); + x = encoder->forward(ctx, x, clip_skip, false); // print_ggml_tensor(x, true, "ClipVisionModel x: "); auto last_hidden_state = x; x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size] GGML_ASSERT(x->ne[3] == 1); if (return_pooled) { - ggml_tensor* pooled = ggml_cont(ctx, ggml_view_2d(ctx, x, x->ne[0], x->ne[2], x->nb[2], 0)); + ggml_tensor* pooled = ggml_cont(ctx->ggml_ctx, ggml_view_2d(ctx->ggml_ctx, x, x->ne[0], x->ne[2], x->nb[2], 0)); return pooled; // [N, hidden_size] } else { // return x; // [N, n_token, hidden_size] @@ -831,12 +828,12 @@ class CLIPProjection : public UnaryBlock { out_features(out_features), transpose_weight(transpose_weight) {} - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { struct ggml_tensor* w = params["weight"]; if (transpose_weight) { - w = ggml_cont(ctx, ggml_transpose(ctx, w)); + w = ggml_cont(ctx->ggml_ctx, ggml_transpose(ctx->ggml_ctx, w)); } - return ggml_ext_linear(ctx, x, w, nullptr); + return ggml_ext_linear(ctx->ggml_ctx, x, w, nullptr); } }; @@ -860,8 +857,7 @@ class CLIPVisionModelProjection : public GGMLBlock { blocks["visual_projection"] = std::shared_ptr(new CLIPProjection(hidden_size, projection_dim, transpose_proj_w)); } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* pixel_values, bool return_pooled = true, int clip_skip = -1) { @@ -870,7 +866,7 @@ class CLIPVisionModelProjection : public GGMLBlock { auto vision_model = std::dynamic_pointer_cast(blocks["vision_model"]); auto visual_projection = std::dynamic_pointer_cast(blocks["visual_projection"]); - auto x = vision_model->forward(ctx, backend, pixel_values, return_pooled, clip_skip); // [N, hidden_size] or [N, n_token, hidden_size] + auto x = vision_model->forward(ctx, pixel_values, return_pooled, clip_skip); // [N, hidden_size] or [N, n_token, hidden_size] if (return_pooled) { x = visual_projection->forward(ctx, x); // [N, projection_dim] @@ -902,8 +898,7 @@ struct CLIPTextModelRunner : public GGMLRunner { model.get_param_tensors(tensors, prefix); } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct 
ggml_tensor* input_ids, struct ggml_tensor* embeddings, size_t max_token_idx = 0, @@ -913,10 +908,10 @@ struct CLIPTextModelRunner : public GGMLRunner { size_t n_token = input_ids->ne[0]; if (input_ids->ne[0] > model.n_token) { GGML_ASSERT(input_ids->ne[0] % model.n_token == 0); - input_ids = ggml_reshape_2d(ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token); + input_ids = ggml_reshape_2d(ctx->ggml_ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token); } - return model.forward(ctx, backend, input_ids, embeddings, max_token_idx, return_pooled, clip_skip); + return model.forward(ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip); } struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids, @@ -943,7 +938,9 @@ struct CLIPTextModelRunner : public GGMLRunner { embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1); } - struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, embeddings, max_token_idx, return_pooled, clip_skip); + auto runner_ctx = get_context(); + + struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, embeddings, max_token_idx, return_pooled, clip_skip); ggml_build_forward_expand(gf, hidden_states); diff --git a/common.hpp b/common.hpp index 785df5785..cb07e8dc0 100644 --- a/common.hpp +++ b/common.hpp @@ -23,12 +23,12 @@ class DownSampleBlock : public GGMLBlock { } } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { // x: [N, channels, h, w] if (vae_downsample) { auto conv = std::dynamic_pointer_cast(blocks["conv"]); - x = ggml_pad(ctx, x, 1, 1, 0, 0); + x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0); x = conv->forward(ctx, x); } else { auto conv = std::dynamic_pointer_cast(blocks["op"]); @@ -52,12 +52,12 @@ class UpSampleBlock : public GGMLBlock { blocks["conv"] = std::shared_ptr(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1})); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { // x: [N, channels, h, w] auto conv = std::dynamic_pointer_cast(blocks["conv"]); - x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST); // [N, channels, h*2, w*2] - x = conv->forward(ctx, x); // [N, out_channels, h*2, w*2] + x = ggml_upscale(ctx->ggml_ctx, x, 2, GGML_SCALE_MODE_NEAREST); // [N, channels, h*2, w*2] + x = conv->forward(ctx, x); // [N, out_channels, h*2, w*2] return x; } }; @@ -121,7 +121,7 @@ class ResBlock : public GGMLBlock { } } - virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* emb = nullptr) { + virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* emb = nullptr) { // For dims==3, we reduce dimension from 5d to 4d by merging h and w, in order not to change ggml // [N, c, t, h, w] => [N, c, t, h * w] // x: [N, channels, h, w] if dims == 2 else [N, channels, t, h, w] @@ -137,32 +137,32 @@ class ResBlock : public GGMLBlock { // in_layers auto h = in_layers_0->forward(ctx, x); - h = ggml_silu_inplace(ctx, h); + h = ggml_silu_inplace(ctx->ggml_ctx, h); h = in_layers_2->forward(ctx, h); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w] // emb_layers if (!skip_t_emb) { auto emb_layer_1 = std::dynamic_pointer_cast(blocks["emb_layers.1"]); - auto emb_out = ggml_silu(ctx, emb); + auto emb_out = ggml_silu(ctx->ggml_ctx, emb); 
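// The pattern this hunk (and every other hunk in the patch) applies: a block's
// forward() now takes a single GGMLRunnerContext* and forwards it to child
// blocks unchanged, while direct ggml_* calls unwrap the raw context via
// ctx->ggml_ctx. A minimal sketch of the convention -- toy_forward and the
// "fc" block are illustrative names, not part of this patch:
struct ggml_tensor* toy_forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
    auto fc = std::dynamic_pointer_cast<Linear>(blocks["fc"]);
    x = fc->forward(ctx, x);                  // child block: pass the wrapper through
    x = ggml_silu_inplace(ctx->ggml_ctx, x);  // raw ggml op: unwrap to ggml_context*
    return x;
}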
emb_out = emb_layer_1->forward(ctx, emb_out); // [N, out_channels] if dims == 2 else [N, t, out_channels] if (dims == 2) { - emb_out = ggml_reshape_4d(ctx, emb_out, 1, 1, emb_out->ne[0], emb_out->ne[1]); // [N, out_channels, 1, 1] + emb_out = ggml_reshape_4d(ctx->ggml_ctx, emb_out, 1, 1, emb_out->ne[0], emb_out->ne[1]); // [N, out_channels, 1, 1] } else { - emb_out = ggml_reshape_4d(ctx, emb_out, 1, emb_out->ne[0], emb_out->ne[1], emb_out->ne[2]); // [N, t, out_channels, 1] + emb_out = ggml_reshape_4d(ctx->ggml_ctx, emb_out, 1, emb_out->ne[0], emb_out->ne[1], emb_out->ne[2]); // [N, t, out_channels, 1] if (exchange_temb_dims) { // emb_out = rearrange(emb_out, "b t c ... -> b c t ...") - emb_out = ggml_cont(ctx, ggml_permute(ctx, emb_out, 0, 2, 1, 3)); // [N, out_channels, t, 1] + emb_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, emb_out, 0, 2, 1, 3)); // [N, out_channels, t, 1] } } - h = ggml_add(ctx, h, emb_out); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w] + h = ggml_add(ctx->ggml_ctx, h, emb_out); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w] } // out_layers h = out_layers_0->forward(ctx, h); - h = ggml_silu_inplace(ctx, h); + h = ggml_silu_inplace(ctx->ggml_ctx, h); // dropout, skip for inference h = out_layers_3->forward(ctx, h); @@ -172,7 +172,7 @@ class ResBlock : public GGMLBlock { x = skip_connection->forward(ctx, x); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w] } - h = ggml_add(ctx, h, x); + h = ggml_add(ctx->ggml_ctx, h, x); return h; // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w] } }; @@ -193,24 +193,24 @@ class GEGLU : public UnaryBlock { GEGLU(int64_t dim_in, int64_t dim_out) : dim_in(dim_in), dim_out(dim_out) {} - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { // x: [ne3, ne2, ne1, dim_in] // return: [ne3, ne2, ne1, dim_out] struct ggml_tensor* w = params["proj.weight"]; struct ggml_tensor* b = params["proj.bias"]; - auto x_w = ggml_view_2d(ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], 0); // [dim_out, dim_in] - auto x_b = ggml_view_1d(ctx, b, b->ne[0] / 2, 0); // [dim_out, dim_in] - auto gate_w = ggml_view_2d(ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], w->nb[1] * w->ne[1] / 2); // [dim_out, ] - auto gate_b = ggml_view_1d(ctx, b, b->ne[0] / 2, b->nb[0] * b->ne[0] / 2); // [dim_out, ] + auto x_w = ggml_view_2d(ctx->ggml_ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], 0); // [dim_out, dim_in] + auto x_b = ggml_view_1d(ctx->ggml_ctx, b, b->ne[0] / 2, 0); // [dim_out, dim_in] + auto gate_w = ggml_view_2d(ctx->ggml_ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], w->nb[1] * w->ne[1] / 2); // [dim_out, ] + auto gate_b = ggml_view_1d(ctx->ggml_ctx, b, b->ne[0] / 2, b->nb[0] * b->ne[0] / 2); // [dim_out, ] auto x_in = x; - x = ggml_ext_linear(ctx, x_in, x_w, x_b); // [ne3, ne2, ne1, dim_out] - auto gate = ggml_ext_linear(ctx, x_in, gate_w, gate_b); // [ne3, ne2, ne1, dim_out] + x = ggml_ext_linear(ctx->ggml_ctx, x_in, x_w, x_b); // [ne3, ne2, ne1, dim_out] + auto gate = ggml_ext_linear(ctx->ggml_ctx, x_in, gate_w, gate_b); // [ne3, ne2, ne1, dim_out] - gate = ggml_gelu_inplace(ctx, gate); + gate = ggml_gelu_inplace(ctx->ggml_ctx, gate); - x = ggml_mul(ctx, x, gate); // [ne3, ne2, ne1, dim_out] + x = ggml_mul(ctx->ggml_ctx, x, gate); // [ne3, ne2, ne1, dim_out] return x; } @@ -222,13 +222,13 @@ class GELU : public UnaryBlock { blocks["proj"] = 
std::shared_ptr(new Linear(dim_in, dim_out, bias)); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { // x: [ne3, ne2, ne1, dim_in] // return: [ne3, ne2, ne1, dim_out] auto proj = std::dynamic_pointer_cast(blocks["proj"]); x = proj->forward(ctx, x); - x = ggml_gelu_inplace(ctx, x); + x = ggml_gelu_inplace(ctx->ggml_ctx, x); return x; } }; @@ -262,7 +262,7 @@ class FeedForward : public GGMLBlock { blocks["net.2"] = std::shared_ptr(new Linear(inner_dim, dim_out, true, false, false, scale)); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { // x: [ne3, ne2, ne1, dim] // return: [ne3, ne2, ne1, dim_out] @@ -304,8 +304,7 @@ class CrossAttention : public GGMLBlock { // to_out_1 is nn.Dropout(), skip for inference } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* context) { // x: [N, n_token, query_dim] @@ -325,7 +324,7 @@ class CrossAttention : public GGMLBlock { auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim] auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim] - x = ggml_ext_attention_ext(ctx, backend, q, k, v, n_head, nullptr, false, false, flash_attn); // [N, n_token, inner_dim] + x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, false, flash_attn); // [N, n_token, inner_dim] x = to_out_0->forward(ctx, x); // [N, n_token, query_dim] return x; @@ -364,8 +363,7 @@ class BasicTransformerBlock : public GGMLBlock { } } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* context) { // x: [N, n_token, query_dim] @@ -387,21 +385,21 @@ class BasicTransformerBlock : public GGMLBlock { x = norm_in->forward(ctx, x); x = ff_in->forward(ctx, x); // self.is_res is always True - x = ggml_add(ctx, x, x_skip); + x = ggml_add(ctx->ggml_ctx, x, x_skip); } auto r = x; x = norm1->forward(ctx, x); - x = attn1->forward(ctx, backend, x, x); // self-attention - x = ggml_add(ctx, x, r); + x = attn1->forward(ctx, x, x); // self-attention + x = ggml_add(ctx->ggml_ctx, x, r); r = x; x = norm2->forward(ctx, x); - x = attn2->forward(ctx, backend, x, context); // cross-attention - x = ggml_add(ctx, x, r); + x = attn2->forward(ctx, x, context); // cross-attention + x = ggml_add(ctx->ggml_ctx, x, r); r = x; x = norm3->forward(ctx, x); x = ff->forward(ctx, x); - x = ggml_add(ctx, x, r); + x = ggml_add(ctx->ggml_ctx, x, r); return x; } @@ -441,8 +439,7 @@ class SpatialTransformer : public GGMLBlock { blocks["proj_out"] = std::shared_ptr(new Conv2d(inner_dim, in_channels, {1, 1})); } - virtual struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* context) { // x: [N, in_channels, h, w] @@ -460,23 +457,23 @@ class SpatialTransformer : public GGMLBlock { x = norm->forward(ctx, x); x = proj_in->forward(ctx, x); // [N, inner_dim, h, w] - x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim] - x = ggml_reshape_3d(ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim] + x = ggml_cont(ctx->ggml_ctx, 
ggml_permute(ctx->ggml_ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim] + x = ggml_reshape_3d(ctx->ggml_ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim] for (int i = 0; i < depth; i++) { std::string name = "transformer_blocks." + std::to_string(i); auto transformer_block = std::dynamic_pointer_cast(blocks[name]); - x = transformer_block->forward(ctx, backend, x, context); + x = transformer_block->forward(ctx, x, context); } - x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w] - x = ggml_reshape_4d(ctx, x, w, h, inner_dim, n); // [N, inner_dim, h, w] + x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w] + x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, inner_dim, n); // [N, inner_dim, h, w] // proj_out x = proj_out->forward(ctx, x); // [N, in_channels, h, w] - x = ggml_add(ctx, x, x_in); + x = ggml_add(ctx->ggml_ctx, x, x_in); return x; } }; @@ -503,14 +500,14 @@ class AlphaBlender : public GGMLBlock { // since mix_factor.shape is [1,], we don't need rearrange using rearrange_pattern } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x_spatial, struct ggml_tensor* x_temporal) { // image_only_indicator is always tensor([0.]) float alpha = get_alpha(); - auto x = ggml_add(ctx, - ggml_scale(ctx, x_spatial, alpha), - ggml_scale(ctx, x_temporal, 1.0f - alpha)); + auto x = ggml_add(ctx->ggml_ctx, + ggml_scale(ctx->ggml_ctx, x_spatial, alpha), + ggml_scale(ctx->ggml_ctx, x_temporal, 1.0f - alpha)); return x; } }; @@ -528,7 +525,7 @@ class VideoResBlock : public ResBlock { blocks["time_mixer"] = std::shared_ptr(new AlphaBlender()); } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* emb, int num_video_frames) { @@ -546,18 +543,18 @@ class VideoResBlock : public ResBlock { int64_t H = x->ne[1]; int64_t W = x->ne[0]; - x = ggml_reshape_4d(ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w) - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w) + x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w) + x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w) auto x_mix = x; - emb = ggml_reshape_4d(ctx, emb, emb->ne[0], T, B, emb->ne[3]); // (b t) ... -> b t ... + emb = ggml_reshape_4d(ctx->ggml_ctx, emb, emb->ne[0], T, B, emb->ne[3]); // (b t) ... -> b t ... 
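// Where a block genuinely needs the backend (the attention helpers), it is no
// longer threaded through as a separate ggml_backend_t parameter; both fields
// come out of the same wrapper, as in the CrossAttention hunk above:
//
//   x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head,
//                              nullptr, false, false, flash_attn);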
x = time_stack->forward(ctx, x, emb); // b t c (h w) x = time_mixer->forward(ctx, x_mix, x); // b t c (h w) - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w) - x = ggml_reshape_4d(ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w + x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w) + x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w return x; } diff --git a/conditioner.hpp b/conditioner.hpp index d6f6efa60..86cdfb87f 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -641,7 +641,9 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner { pixel_values = to_backend(pixel_values); - struct ggml_tensor* hidden_states = vision_model.forward(compute_ctx, runtime_backend, pixel_values, return_pooled, clip_skip); + auto runner_ctx = get_context(); + + struct ggml_tensor* hidden_states = vision_model.forward(&runner_ctx, pixel_values, return_pooled, clip_skip); ggml_build_forward_expand(gf, hidden_states); diff --git a/control.hpp b/control.hpp index 9cdf43d7f..27eee7b27 100644 --- a/control.hpp +++ b/control.hpp @@ -165,7 +165,7 @@ class ControlNetBlock : public GGMLBlock { } struct ggml_tensor* resblock_forward(std::string name, - struct ggml_context* ctx, + GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* emb) { auto block = std::dynamic_pointer_cast(blocks[name]); @@ -173,15 +173,14 @@ class ControlNetBlock : public GGMLBlock { } struct ggml_tensor* attention_layer_forward(std::string name, - struct ggml_context* ctx, - ggml_backend_t backend, + GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* context) { auto block = std::dynamic_pointer_cast(blocks[name]); - return block->forward(ctx, backend, x, context); + return block->forward(ctx, x, context); } - struct ggml_tensor* input_hint_block_forward(struct ggml_context* ctx, + struct ggml_tensor* input_hint_block_forward(GGMLRunnerContext* ctx, struct ggml_tensor* hint, struct ggml_tensor* emb, struct ggml_tensor* context) { @@ -193,14 +192,13 @@ class ControlNetBlock : public GGMLBlock { h = block->forward(ctx, h); } else { - h = ggml_silu_inplace(ctx, h); + h = ggml_silu_inplace(ctx->ggml_ctx, h); } } return h; } - std::vector forward(struct ggml_context* ctx, - ggml_backend_t backend, + std::vector forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* hint, struct ggml_tensor* guided_hint, @@ -213,13 +211,13 @@ class ControlNetBlock : public GGMLBlock { // y: [N, adm_in_channels] or [1, adm_in_channels] if (context != nullptr) { if (context->ne[2] != x->ne[3]) { - context = ggml_repeat(ctx, context, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, context->ne[0], context->ne[1], x->ne[3])); + context = ggml_repeat(ctx->ggml_ctx, context, ggml_new_tensor_3d(ctx->ggml_ctx, GGML_TYPE_F32, context->ne[0], context->ne[1], x->ne[3])); } } if (y != nullptr) { if (y->ne[1] != x->ne[3]) { - y = ggml_repeat(ctx, y, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, y->ne[0], x->ne[3])); + y = ggml_repeat(ctx->ggml_ctx, y, ggml_new_tensor_2d(ctx->ggml_ctx, GGML_TYPE_F32, y->ne[0], x->ne[3])); } } @@ -230,10 +228,10 @@ class ControlNetBlock : public GGMLBlock { auto middle_block_out = std::dynamic_pointer_cast(blocks["middle_block_out.0"]); - auto t_emb = ggml_ext_timestep_embedding(ctx, timesteps, model_channels); // [N, model_channels] + auto t_emb = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, model_channels); // [N, model_channels] auto emb = time_embed_0->forward(ctx, t_emb); - 
emb = ggml_silu_inplace(ctx, emb); + emb = ggml_silu_inplace(ctx->ggml_ctx, emb); emb = time_embed_2->forward(ctx, emb); // [N, time_embed_dim] // SDXL/SVD @@ -242,10 +240,10 @@ class ControlNetBlock : public GGMLBlock { auto label_embed_2 = std::dynamic_pointer_cast(blocks["label_emb.0.2"]); auto label_emb = label_embed_0->forward(ctx, y); - label_emb = ggml_silu_inplace(ctx, label_emb); + label_emb = ggml_silu_inplace(ctx->ggml_ctx, label_emb); label_emb = label_embed_2->forward(ctx, label_emb); // [N, time_embed_dim] - emb = ggml_add(ctx, emb, label_emb); // [N, time_embed_dim] + emb = ggml_add(ctx->ggml_ctx, emb, label_emb); // [N, time_embed_dim] } std::vector outs; @@ -259,7 +257,7 @@ class ControlNetBlock : public GGMLBlock { // input block 0 auto h = input_blocks_0_0->forward(ctx, x); - h = ggml_add(ctx, h, guided_hint); + h = ggml_add(ctx->ggml_ctx, h, guided_hint); outs.push_back(zero_convs_0->forward(ctx, h)); // input block 1-11 @@ -274,7 +272,7 @@ class ControlNetBlock : public GGMLBlock { h = resblock_forward(name, ctx, h, emb); // [N, mult*model_channels, h, w] if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1"; - h = attention_layer_forward(name, ctx, backend, h, context); // [N, mult*model_channels, h, w] + h = attention_layer_forward(name, ctx, h, context); // [N, mult*model_channels, h, w] } auto zero_conv = std::dynamic_pointer_cast(blocks["zero_convs." + std::to_string(input_block_idx) + ".0"]); @@ -298,9 +296,9 @@ class ControlNetBlock : public GGMLBlock { // [N, 4*model_channels, h/8, w/8] // middle_block - h = resblock_forward("middle_block.0", ctx, h, emb); // [N, 4*model_channels, h/8, w/8] - h = attention_layer_forward("middle_block.1", ctx, backend, h, context); // [N, 4*model_channels, h/8, w/8] - h = resblock_forward("middle_block.2", ctx, h, emb); // [N, 4*model_channels, h/8, w/8] + h = resblock_forward("middle_block.0", ctx, h, emb); // [N, 4*model_channels, h/8, w/8] + h = attention_layer_forward("middle_block.1", ctx, h, context); // [N, 4*model_channels, h/8, w/8] + h = resblock_forward("middle_block.2", ctx, h, emb); // [N, 4*model_channels, h/8, w/8] // out outs.push_back(middle_block_out->forward(ctx, h)); @@ -404,8 +402,9 @@ struct ControlNet : public GGMLRunner { y = to_backend(y); timesteps = to_backend(timesteps); - auto outs = control_net.forward(compute_ctx, - runtime_backend, + auto runner_ctx = get_context(); + + auto outs = control_net.forward(&runner_ctx, x, hint, guided_hint_cached ? 
guided_hint : nullptr, diff --git a/esrgan.hpp b/esrgan.hpp index 21689ffa4..10fbc0643 100644 --- a/esrgan.hpp +++ b/esrgan.hpp @@ -27,11 +27,11 @@ class ResidualDenseBlock : public GGMLBlock { blocks["conv5"] = std::shared_ptr(new Conv2d(num_feat + 4 * num_grow_ch, num_feat, {3, 3}, {1, 1}, {1, 1})); } - struct ggml_tensor* lrelu(struct ggml_context* ctx, struct ggml_tensor* x) { - return ggml_leaky_relu(ctx, x, 0.2f, true); + struct ggml_tensor* lrelu(GGMLRunnerContext* ctx, struct ggml_tensor* x) { + return ggml_leaky_relu(ctx->ggml_ctx, x, 0.2f, true); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { // x: [n, num_feat, h, w] // return: [n, num_feat, h, w] @@ -42,16 +42,16 @@ class ResidualDenseBlock : public GGMLBlock { auto conv5 = std::dynamic_pointer_cast(blocks["conv5"]); auto x1 = lrelu(ctx, conv1->forward(ctx, x)); - auto x_cat = ggml_concat(ctx, x, x1, 2); + auto x_cat = ggml_concat(ctx->ggml_ctx, x, x1, 2); auto x2 = lrelu(ctx, conv2->forward(ctx, x_cat)); - x_cat = ggml_concat(ctx, x_cat, x2, 2); + x_cat = ggml_concat(ctx->ggml_ctx, x_cat, x2, 2); auto x3 = lrelu(ctx, conv3->forward(ctx, x_cat)); - x_cat = ggml_concat(ctx, x_cat, x3, 2); + x_cat = ggml_concat(ctx->ggml_ctx, x_cat, x3, 2); auto x4 = lrelu(ctx, conv4->forward(ctx, x_cat)); - x_cat = ggml_concat(ctx, x_cat, x4, 2); + x_cat = ggml_concat(ctx->ggml_ctx, x_cat, x4, 2); auto x5 = conv5->forward(ctx, x_cat); - x5 = ggml_add(ctx, ggml_scale(ctx, x5, 0.2f), x); + x5 = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, x5, 0.2f), x); return x5; } }; @@ -64,7 +64,7 @@ class RRDB : public GGMLBlock { blocks["rdb3"] = std::shared_ptr(new ResidualDenseBlock(num_feat, num_grow_ch)); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { // x: [n, num_feat, h, w] // return: [n, num_feat, h, w] @@ -76,7 +76,7 @@ class RRDB : public GGMLBlock { out = rdb2->forward(ctx, out); out = rdb3->forward(ctx, out); - out = ggml_add(ctx, ggml_scale(ctx, out, 0.2f), x); + out = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, out, 0.2f), x); return out; } }; @@ -112,11 +112,11 @@ class RRDBNet : public GGMLBlock { int get_scale() { return scale; } int get_num_block() { return num_block; } - struct ggml_tensor* lrelu(struct ggml_context* ctx, struct ggml_tensor* x) { - return ggml_leaky_relu(ctx, x, 0.2f, true); + struct ggml_tensor* lrelu(GGMLRunnerContext* ctx, struct ggml_tensor* x) { + return ggml_leaky_relu(ctx->ggml_ctx, x, 0.2f, true); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { // x: [n, num_in_ch, h, w] // return: [n, num_out_ch, h*scale, w*scale] auto conv_first = std::dynamic_pointer_cast(blocks["conv_first"]); @@ -133,14 +133,14 @@ class RRDBNet : public GGMLBlock { body_feat = block->forward(ctx, body_feat); } body_feat = conv_body->forward(ctx, body_feat); - feat = ggml_add(ctx, feat, body_feat); + feat = ggml_add(ctx->ggml_ctx, feat, body_feat); // upsample if (scale >= 2) { auto conv_up1 = std::dynamic_pointer_cast(blocks["conv_up1"]); - feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST))); + feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx->ggml_ctx, feat, 2, GGML_SCALE_MODE_NEAREST))); if (scale == 4) { auto conv_up2 = 
std::dynamic_pointer_cast(blocks["conv_up2"]); - feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST))); + feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx->ggml_ctx, feat, 2, GGML_SCALE_MODE_NEAREST))); } } // for all scales @@ -359,7 +359,9 @@ struct ESRGAN : public GGMLRunner { constexpr int kGraphNodes = 1 << 16; // 65k struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, kGraphNodes, /*grads*/ false); x = to_backend(x); - struct ggml_tensor* out = rrdb_net->forward(compute_ctx, x); + + auto runner_ctx = get_context(); + struct ggml_tensor* out = rrdb_net->forward(&runner_ctx, x); ggml_build_forward_expand(gf, out); return gf; } diff --git a/flux.hpp b/flux.hpp index 538b877f4..ce0cf4763 100644 --- a/flux.hpp +++ b/flux.hpp @@ -19,14 +19,14 @@ namespace Flux { blocks["out_layer"] = std::shared_ptr(new Linear(hidden_dim, hidden_dim, true)); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { // x: [..., in_dim] // return: [..., hidden_dim] auto in_layer = std::dynamic_pointer_cast(blocks["in_layer"]); auto out_layer = std::dynamic_pointer_cast(blocks["out_layer"]); x = in_layer->forward(ctx, x); - x = ggml_silu_inplace(ctx, x); + x = ggml_silu_inplace(ctx->ggml_ctx, x); x = out_layer->forward(ctx, x); return x; } @@ -48,10 +48,10 @@ namespace Flux { : hidden_size(hidden_size), eps(eps) {} - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { struct ggml_tensor* w = params["scale"]; - x = ggml_rms_norm(ctx, x, eps); - x = ggml_mul(ctx, x, w); + x = ggml_rms_norm(ctx->ggml_ctx, x, eps); + x = ggml_mul(ctx->ggml_ctx, x, w); return x; } }; @@ -63,7 +63,7 @@ namespace Flux { blocks["key_norm"] = std::shared_ptr(new RMSNorm(dim)); } - struct ggml_tensor* query_norm(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* query_norm(GGMLRunnerContext* ctx, struct ggml_tensor* x) { // x: [..., dim] // return: [..., dim] auto norm = std::dynamic_pointer_cast(blocks["query_norm"]); @@ -72,7 +72,7 @@ namespace Flux { return x; } - struct ggml_tensor* key_norm(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* key_norm(GGMLRunnerContext* ctx, struct ggml_tensor* x) { // x: [..., dim] // return: [..., dim] auto norm = std::dynamic_pointer_cast(blocks["key_norm"]); @@ -99,39 +99,38 @@ namespace Flux { blocks["proj"] = std::shared_ptr(new Linear(dim, dim)); } - std::vector pre_attention(struct ggml_context* ctx, struct ggml_tensor* x) { + std::vector pre_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) { auto qkv_proj = std::dynamic_pointer_cast(blocks["qkv"]); auto norm = std::dynamic_pointer_cast(blocks["norm"]); auto qkv = qkv_proj->forward(ctx, x); - auto qkv_vec = split_qkv(ctx, qkv); + auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv); int64_t head_dim = qkv_vec[0]->ne[0] / num_heads; - auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); - auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); - auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); + auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); + auto k = ggml_reshape_4d(ctx->ggml_ctx, 
qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); + auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); q = norm->query_norm(ctx, q); k = norm->key_norm(ctx, k); return {q, k, v}; } - struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* post_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) { auto proj = std::dynamic_pointer_cast(blocks["proj"]); x = proj->forward(ctx, x); // [N, n_token, dim] return x; } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* pe, struct ggml_tensor* mask) { // x: [N, n_token, dim] // pe: [n_token, d_head/2, 2, 2] // return [N, n_token, dim] - auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head] - x = Rope::attention(ctx, backend, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn); // [N, n_token, dim] - x = post_attention(ctx, x); // [N, n_token, dim] + auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head] + x = Rope::attention(ctx, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn); // [N, n_token, dim] + x = post_attention(ctx, x); // [N, n_token, dim] return x; } }; @@ -144,11 +143,11 @@ namespace Flux { ModulationOut(ggml_tensor* shift = nullptr, ggml_tensor* scale = nullptr, ggml_tensor* gate = nullptr) : shift(shift), scale(scale), gate(gate) {} - ModulationOut(struct ggml_context* ctx, ggml_tensor* vec, int64_t offset) { + ModulationOut(GGMLRunnerContext* ctx, ggml_tensor* vec, int64_t offset) { int64_t stride = vec->nb[1] * vec->ne[1]; - shift = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim] - scale = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim] - gate = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 2)); // [N, dim] + shift = ggml_view_2d(ctx->ggml_ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim] + scale = ggml_view_2d(ctx->ggml_ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim] + gate = ggml_view_2d(ctx->ggml_ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 2)); // [N, dim] } }; @@ -164,16 +163,16 @@ namespace Flux { blocks["lin"] = std::shared_ptr(new Linear(dim, dim * multiplier)); } - std::vector forward(struct ggml_context* ctx, struct ggml_tensor* vec) { + std::vector forward(GGMLRunnerContext* ctx, struct ggml_tensor* vec) { // x: [N, dim] // return: [ModulationOut, ModulationOut] auto lin = std::dynamic_pointer_cast(blocks["lin"]); - auto out = ggml_silu(ctx, vec); + auto out = ggml_silu(ctx->ggml_ctx, vec); out = lin->forward(ctx, out); // [N, multiplier*dim] - auto m = ggml_reshape_3d(ctx, out, vec->ne[0], multiplier, vec->ne[1]); // [N, multiplier, dim] - m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [multiplier, N, dim] + auto m = ggml_reshape_3d(ctx->ggml_ctx, out, vec->ne[0], multiplier, vec->ne[1]); // [N, multiplier, dim] + m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [multiplier, N, dim] ModulationOut m_0 = ModulationOut(ctx, m, 0); if (is_double) { @@ -236,7 +235,7 @@ namespace Flux { blocks["txt_mlp.2"] = std::shared_ptr(new Linear(mlp_hidden_dim, hidden_size)); } - std::vector get_distil_img_mod(struct ggml_context* ctx, struct ggml_tensor* vec) { + std::vector 
get_distil_img_mod(GGMLRunnerContext* ctx, struct ggml_tensor* vec) { // TODO: not hardcoded? const int single_blocks_count = 38; const int double_blocks_count = 19; @@ -245,7 +244,7 @@ namespace Flux { return {ModulationOut(ctx, vec, offset), ModulationOut(ctx, vec, offset + 3)}; } - std::vector get_distil_txt_mod(struct ggml_context* ctx, struct ggml_tensor* vec) { + std::vector get_distil_txt_mod(GGMLRunnerContext* ctx, struct ggml_tensor* vec) { // TODO: not hardcoded? const int single_blocks_count = 38; const int double_blocks_count = 19; @@ -254,8 +253,7 @@ namespace Flux { return {ModulationOut(ctx, vec, offset), ModulationOut(ctx, vec, offset + 3)}; } - std::pair forward(struct ggml_context* ctx, - ggml_backend_t backend, + std::pair forward(GGMLRunnerContext* ctx, struct ggml_tensor* img, struct ggml_tensor* txt, struct ggml_tensor* vec, @@ -300,7 +298,7 @@ namespace Flux { // prepare image for attention auto img_modulated = img_norm1->forward(ctx, img); - img_modulated = Flux::modulate(ctx, img_modulated, img_mod1.shift, img_mod1.scale); + img_modulated = Flux::modulate(ctx->ggml_ctx, img_modulated, img_mod1.shift, img_mod1.scale); auto img_qkv = img_attn->pre_attention(ctx, img_modulated); // q,k,v: [N, n_img_token, n_head, d_head] auto img_q = img_qkv[0]; auto img_k = img_qkv[1]; @@ -308,55 +306,55 @@ namespace Flux { // prepare txt for attention auto txt_modulated = txt_norm1->forward(ctx, txt); - txt_modulated = Flux::modulate(ctx, txt_modulated, txt_mod1.shift, txt_mod1.scale); + txt_modulated = Flux::modulate(ctx->ggml_ctx, txt_modulated, txt_mod1.shift, txt_mod1.scale); auto txt_qkv = txt_attn->pre_attention(ctx, txt_modulated); // q,k,v: [N, n_txt_token, n_head, d_head] auto txt_q = txt_qkv[0]; auto txt_k = txt_qkv[1]; auto txt_v = txt_qkv[2]; // run actual attention - auto q = ggml_concat(ctx, txt_q, img_q, 2); // [N, n_txt_token + n_img_token, n_head, d_head] - auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] - auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] + auto q = ggml_concat(ctx->ggml_ctx, txt_q, img_q, 2); // [N, n_txt_token + n_img_token, n_head, d_head] + auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] + auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] - auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head] - attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] - auto txt_attn_out = ggml_view_3d(ctx, + auto attn = Rope::attention(ctx, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head] + attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] + auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx, attn, attn->ne[0], attn->ne[1], txt->ne[1], attn->nb[1], attn->nb[2], - 0); // [n_txt_token, N, hidden_size] - txt_attn_out = ggml_cont(ctx, ggml_permute(ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size] - auto img_attn_out = ggml_view_3d(ctx, + 0); // [n_txt_token, N, hidden_size] + txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size] + auto img_attn_out = ggml_view_3d(ctx->ggml_ctx, attn, attn->ne[0], attn->ne[1], img->ne[1], attn->nb[1], attn->nb[2], - attn->nb[2] 
* txt->ne[1]); // [n_img_token, N, hidden_size] - img_attn_out = ggml_cont(ctx, ggml_permute(ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] + attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size] + img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] // calculate the img bloks - img = ggml_add(ctx, img, ggml_mul(ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate)); + img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate)); - auto img_mlp_out = img_mlp_0->forward(ctx, Flux::modulate(ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale)); - img_mlp_out = ggml_gelu_inplace(ctx, img_mlp_out); + auto img_mlp_out = img_mlp_0->forward(ctx, Flux::modulate(ctx->ggml_ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale)); + img_mlp_out = ggml_gelu_inplace(ctx->ggml_ctx, img_mlp_out); img_mlp_out = img_mlp_2->forward(ctx, img_mlp_out); - img = ggml_add(ctx, img, ggml_mul(ctx, img_mlp_out, img_mod2.gate)); + img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_mlp_out, img_mod2.gate)); // calculate the txt bloks - txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate)); + txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate)); - auto txt_mlp_out = txt_mlp_0->forward(ctx, Flux::modulate(ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale)); - txt_mlp_out = ggml_gelu_inplace(ctx, txt_mlp_out); + auto txt_mlp_out = txt_mlp_0->forward(ctx, Flux::modulate(ctx->ggml_ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale)); + txt_mlp_out = ggml_gelu_inplace(ctx->ggml_ctx, txt_mlp_out); txt_mlp_out = txt_mlp_2->forward(ctx, txt_mlp_out); - txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_mlp_out, txt_mod2.gate)); + txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_mlp_out, txt_mod2.gate)); return {img, txt}; } @@ -397,13 +395,12 @@ namespace Flux { } } - ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) { + ModulationOut get_distil_mod(GGMLRunnerContext* ctx, struct ggml_tensor* vec) { int64_t offset = 3 * idx; return ModulationOut(ctx, vec, offset); } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* vec, struct ggml_tensor* pe, @@ -424,42 +421,42 @@ namespace Flux { mod = modulation->forward(ctx, vec)[0]; } - auto x_mod = Flux::modulate(ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale); - auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim] - qkv_mlp = ggml_cont(ctx, ggml_permute(ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token] + auto x_mod = Flux::modulate(ctx->ggml_ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale); + auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim] + qkv_mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token] - auto qkv = ggml_view_3d(ctx, + auto qkv = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, qkv_mlp->ne[0], qkv_mlp->ne[1], hidden_size * 3, qkv_mlp->nb[1], qkv_mlp->nb[2], - 0); // [hidden_size * 3 , N, n_token] - qkv = ggml_cont(ctx, ggml_permute(ctx, qkv, 1, 2, 0, 3)); // [N, n_token, 
hidden_size * 3] - auto mlp = ggml_view_3d(ctx, + 0); // [hidden_size * 3 , N, n_token] + qkv = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv, 1, 2, 0, 3)); // [N, n_token, hidden_size * 3] + auto mlp = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, qkv_mlp->ne[0], qkv_mlp->ne[1], mlp_hidden_dim, qkv_mlp->nb[1], qkv_mlp->nb[2], - qkv_mlp->nb[2] * hidden_size * 3); // [mlp_hidden_dim , N, n_token] - mlp = ggml_cont(ctx, ggml_permute(ctx, mlp, 1, 2, 0, 3)); // [N, n_token, mlp_hidden_dim] + qkv_mlp->nb[2] * hidden_size * 3); // [mlp_hidden_dim , N, n_token] + mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, mlp, 1, 2, 0, 3)); // [N, n_token, mlp_hidden_dim] - auto qkv_vec = split_qkv(ctx, qkv); // q,k,v: [N, n_token, hidden_size] + auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv); // q,k,v: [N, n_token, hidden_size] int64_t head_dim = hidden_size / num_heads; - auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head] - auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head] - auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head] + auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head] + auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head] + auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head] q = norm->query_norm(ctx, q); k = norm->key_norm(ctx, k); - auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, hidden_size] + auto attn = Rope::attention(ctx, q, k, v, pe, mask, flash_attn); // [N, n_token, hidden_size] - auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim] - auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size] + auto attn_mlp = ggml_concat(ctx->ggml_ctx, attn, ggml_gelu_inplace(ctx->ggml_ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim] + auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size] - output = ggml_add(ctx, x, ggml_mul(ctx, output, mod.gate)); + output = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, output, mod.gate)); return output; } }; @@ -480,16 +477,16 @@ namespace Flux { } } - ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) { + ModulationOut get_distil_mod(GGMLRunnerContext* ctx, struct ggml_tensor* vec) { int64_t offset = vec->ne[2] - 2; int64_t stride = vec->nb[1] * vec->ne[1]; - auto shift = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim] - auto scale = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim] + auto shift = ggml_view_2d(ctx->ggml_ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim] + auto scale = ggml_view_2d(ctx->ggml_ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim] // No gate return {shift, scale, nullptr}; } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* c) { // x: [N, n_token, hidden_size] @@ -505,16 
+502,16 @@ namespace Flux { } else { auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); - auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size] - m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size] - m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size] + auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size] + m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size] + m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size] int64_t offset = m->nb[1] * m->ne[1]; - shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] - scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] + scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] } - x = Flux::modulate(ctx, norm_final->forward(ctx, x), shift, scale); + x = Flux::modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale); x = linear->forward(ctx, x); return x; @@ -533,7 +530,7 @@ namespace Flux { blocks["out_proj"] = std::shared_ptr(new Linear(inner_size, hidden_size, true)); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { auto in_proj = std::dynamic_pointer_cast(blocks["in_proj"]); auto out_proj = std::dynamic_pointer_cast(blocks["out_proj"]); @@ -541,7 +538,7 @@ namespace Flux { for (int i = 0; i < n_layers; i++) { auto norm = std::dynamic_pointer_cast(blocks["norms." + std::to_string(i)]); auto embed = std::dynamic_pointer_cast(blocks["layers." 
+ std::to_string(i)]); - x = ggml_add_inplace(ctx, x, embed->forward(ctx, norm->forward(ctx, x))); + x = ggml_add_inplace(ctx->ggml_ctx, x, embed->forward(ctx, norm->forward(ctx, x))); } x = out_proj->forward(ctx, x); @@ -556,7 +553,7 @@ namespace Flux { blocks["embedder.0"] = std::make_shared(in_channels + max_freqs * max_freqs, hidden_size_input); } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* dct) { // x: (B, P^2, C) @@ -564,8 +561,8 @@ namespace Flux { // return: (B, P^2, hidden_size_input) auto embedder = std::dynamic_pointer_cast(blocks["embedder.0"]); - dct = ggml_repeat_4d(ctx, dct, dct->ne[0], dct->ne[1], x->ne[2], x->ne[3]); - x = ggml_concat(ctx, x, dct, 0); + dct = ggml_repeat_4d(ctx->ggml_ctx, dct, dct->ne[0], dct->ne[1], x->ne[2], x->ne[3]); + x = ggml_concat(ctx->ggml_ctx, x, dct, 0); x = embedder->forward(ctx, x); return x; @@ -583,7 +580,7 @@ namespace Flux { blocks["norm"] = std::make_shared(hidden_size_x); } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* s) { // x: (batch_size, n_token, hidden_size_x) @@ -596,31 +593,31 @@ namespace Flux { int64_t hidden_size_x = x->ne[0]; auto mlp_params = param_generator->forward(ctx, s); - auto fc_params = ggml_ext_chunk(ctx, mlp_params, 3, 0); - auto fc1_gate = ggml_reshape_3d(ctx, fc_params[0], hidden_size_x * mlp_ratio, hidden_size_x, batch_size); - auto fc1_value = ggml_reshape_3d(ctx, fc_params[1], hidden_size_x * mlp_ratio, hidden_size_x, batch_size); - auto fc2 = ggml_reshape_3d(ctx, fc_params[2], hidden_size_x, mlp_ratio * hidden_size_x, batch_size); - - fc1_gate = ggml_cont(ctx, ggml_ext_torch_permute(ctx, fc1_gate, 1, 0, 2, 3)); // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x] - fc1_gate = ggml_l2_norm(ctx, fc1_gate, 1e-12f); - fc1_value = ggml_cont(ctx, ggml_ext_torch_permute(ctx, fc1_value, 1, 0, 2, 3)); // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x] - fc1_value = ggml_l2_norm(ctx, fc1_value, 1e-12f); - fc2 = ggml_cont(ctx, ggml_ext_torch_permute(ctx, fc2, 1, 0, 2, 3)); // [batch_size, hidden_size_x, hidden_size_x*mlp_ratio] - fc2 = ggml_l2_norm(ctx, fc2, 1e-12f); + auto fc_params = ggml_ext_chunk(ctx->ggml_ctx, mlp_params, 3, 0); + auto fc1_gate = ggml_reshape_3d(ctx->ggml_ctx, fc_params[0], hidden_size_x * mlp_ratio, hidden_size_x, batch_size); + auto fc1_value = ggml_reshape_3d(ctx->ggml_ctx, fc_params[1], hidden_size_x * mlp_ratio, hidden_size_x, batch_size); + auto fc2 = ggml_reshape_3d(ctx->ggml_ctx, fc_params[2], hidden_size_x, mlp_ratio * hidden_size_x, batch_size); + + fc1_gate = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, fc1_gate, 1, 0, 2, 3)); // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x] + fc1_gate = ggml_l2_norm(ctx->ggml_ctx, fc1_gate, 1e-12f); + fc1_value = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, fc1_value, 1, 0, 2, 3)); // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x] + fc1_value = ggml_l2_norm(ctx->ggml_ctx, fc1_value, 1e-12f); + fc2 = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, fc2, 1, 0, 2, 3)); // [batch_size, hidden_size_x, hidden_size_x*mlp_ratio] + fc2 = ggml_l2_norm(ctx->ggml_ctx, fc2, 1e-12f); auto res_x = x; x = norm->forward(ctx, x); // [batch_size, n_token, hidden_size_x] - auto x1 = ggml_mul_mat(ctx, fc1_gate, x); // [batch_size, n_token, hidden_size_x*mlp_ratio] - x1 = 
ggml_silu_inplace(ctx, x1);
+            auto x1 = ggml_mul_mat(ctx->ggml_ctx, fc1_gate, x); // [batch_size, n_token, hidden_size_x*mlp_ratio]
+            x1 = ggml_silu_inplace(ctx->ggml_ctx, x1);
 
-            auto x2 = ggml_mul_mat(ctx, fc1_value, x); // [batch_size, n_token, hidden_size_x*mlp_ratio]
+            auto x2 = ggml_mul_mat(ctx->ggml_ctx, fc1_value, x); // [batch_size, n_token, hidden_size_x*mlp_ratio]
 
-            x = ggml_mul_inplace(ctx, x1, x2); // [batch_size, n_token, hidden_size_x*mlp_ratio]
+            x = ggml_mul_inplace(ctx->ggml_ctx, x1, x2); // [batch_size, n_token, hidden_size_x*mlp_ratio]
 
-            x = ggml_mul_mat(ctx, fc2, x); // [batch_size, n_token, hidden_size_x]
+            x = ggml_mul_mat(ctx->ggml_ctx, fc2, x); // [batch_size, n_token, hidden_size_x]
 
-            x = ggml_add_inplace(ctx, x, res_x);
+            x = ggml_add_inplace(ctx->ggml_ctx, x, res_x);
 
             return x;
         }
@@ -633,7 +630,7 @@ namespace Flux {
             blocks["linear"] = std::make_shared(hidden_size, out_channels);
         }
 
-        struct ggml_tensor* forward(struct ggml_context* ctx,
+        struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                     struct ggml_tensor* x) {
             auto norm = std::dynamic_pointer_cast(blocks["norm"]);
             auto linear = std::dynamic_pointer_cast(blocks["linear"]);
@@ -652,15 +649,15 @@ namespace Flux {
             blocks["conv"] = std::make_shared(hidden_size, out_channels, std::pair{3, 3}, std::pair{1, 1}, std::pair{1, 1});
         }
 
-        struct ggml_tensor* forward(struct ggml_context* ctx,
+        struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                     struct ggml_tensor* x) {
             // x: [N, C, H, W]
             auto norm = std::dynamic_pointer_cast(blocks["norm"]);
             auto conv = std::dynamic_pointer_cast(blocks["conv"]);
 
-            x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [N, H, W, C]
+            x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 2, 0, 1, 3)); // [N, H, W, C]
             x = norm->forward(ctx, x);
-            x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [N, C, H, W]
+            x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 2, 0, 3)); // [N, C, H, W]
             x = conv->forward(ctx, x);
 
             return x;
@@ -828,8 +825,7 @@ namespace Flux {
             return x;
         }
 
-        struct ggml_tensor* forward_orig(struct ggml_context* ctx,
-                                         ggml_backend_t backend,
+        struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx,
                                          struct ggml_tensor* img,
                                          struct ggml_tensor* txt,
                                          struct ggml_tensor* timesteps,
@@ -851,41 +847,41 @@ namespace Flux {
             if (params.is_chroma) {
                 int64_t mod_index_length = 344;
                 auto approx = std::dynamic_pointer_cast(blocks["distilled_guidance_layer"]);
-                auto distill_timestep = ggml_ext_timestep_embedding(ctx, timesteps, 16, 10000, 1000.f);
-                auto distill_guidance = ggml_ext_timestep_embedding(ctx, guidance, 16, 10000, 1000.f);
+                auto distill_timestep = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 16, 10000, 1000.f);
+                auto distill_guidance = ggml_ext_timestep_embedding(ctx->ggml_ctx, guidance, 16, 10000, 1000.f);
 
                 // auto mod_index_arange = ggml_arange(ctx, 0, (float)mod_index_length, 1);
                 // ggml_arange not working on a lot of backends, precomputing it on CPU instead
                 GGML_ASSERT(mod_index_arange != nullptr);
-                auto modulation_index = ggml_ext_timestep_embedding(ctx, mod_index_arange, 32, 10000, 1000.f); // [1, 344, 32]
+                auto modulation_index = ggml_ext_timestep_embedding(ctx->ggml_ctx, mod_index_arange, 32, 10000, 1000.f); // [1, 344, 32]
 
                 // Batch broadcast (will it ever be useful)
-                modulation_index = ggml_repeat(ctx, modulation_index, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, modulation_index->ne[0], modulation_index->ne[1], img->ne[2])); // [N, 344, 32]
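
The hunks in this file all follow one mechanical recipe: forward() now takes a GGMLRunnerContext* instead of a raw ggml_context* plus a loose ggml_backend_t, direct graph ops unwrap ctx->ggml_ctx, and calls into sub-blocks pass ctx through untouched. A minimal sketch of the shape, assuming a hypothetical block holding a single Linear layer (MyBlock and the "fc" key are illustrative names, not code from this patch):

    struct ggml_tensor* MyBlock::forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
        auto fc = std::dynamic_pointer_cast<Linear>(blocks["fc"]);
        x = fc->forward(ctx, x);                  // sub-blocks receive the wrapper unchanged
        x = ggml_silu_inplace(ctx->ggml_ctx, x);  // raw ggml ops use the inner ggml_context
        return x;                                 // backend-aware helpers would read ctx->backend
    }

+                modulation_index = ggml_repeat(ctx->ggml_ctx, 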
modulation_index, ggml_new_tensor_3d(ctx->ggml_ctx, GGML_TYPE_F32, modulation_index->ne[0], modulation_index->ne[1], img->ne[2])); // [N, 344, 32] - auto timestep_guidance = ggml_concat(ctx, distill_timestep, distill_guidance, 0); // [N, 1, 32] - timestep_guidance = ggml_repeat(ctx, timestep_guidance, modulation_index); // [N, 344, 32] + auto timestep_guidance = ggml_concat(ctx->ggml_ctx, distill_timestep, distill_guidance, 0); // [N, 1, 32] + timestep_guidance = ggml_repeat(ctx->ggml_ctx, timestep_guidance, modulation_index); // [N, 344, 32] - vec = ggml_concat(ctx, timestep_guidance, modulation_index, 0); // [N, 344, 64] + vec = ggml_concat(ctx->ggml_ctx, timestep_guidance, modulation_index, 0); // [N, 344, 64] // Permute for consistency with non-distilled modulation implementation - vec = ggml_cont(ctx, ggml_permute(ctx, vec, 0, 2, 1, 3)); // [344, N, 64] - vec = approx->forward(ctx, vec); // [344, N, hidden_size] + vec = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, vec, 0, 2, 1, 3)); // [344, N, 64] + vec = approx->forward(ctx, vec); // [344, N, hidden_size] if (y != nullptr) { - txt_img_mask = ggml_pad(ctx, y, img->ne[1], 0, 0, 0); + txt_img_mask = ggml_pad(ctx->ggml_ctx, y, img->ne[1], 0, 0, 0); } } else { auto time_in = std::dynamic_pointer_cast(blocks["time_in"]); auto vector_in = std::dynamic_pointer_cast(blocks["vector_in"]); - vec = time_in->forward(ctx, ggml_ext_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f)); + vec = time_in->forward(ctx, ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 256, 10000, 1000.f)); if (params.guidance_embed) { GGML_ASSERT(guidance != nullptr); auto guidance_in = std::dynamic_pointer_cast(blocks["guidance_in"]); // bf16 and fp16 result is different - auto g_in = ggml_ext_timestep_embedding(ctx, guidance, 256, 10000, 1000.f); - vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in)); + auto g_in = ggml_ext_timestep_embedding(ctx->ggml_ctx, guidance, 256, 10000, 1000.f); + vec = ggml_add(ctx->ggml_ctx, vec, guidance_in->forward(ctx, g_in)); } - vec = ggml_add(ctx, vec, vector_in->forward(ctx, y)); + vec = ggml_add(ctx->ggml_ctx, vec, vector_in->forward(ctx, y)); } txt = txt_in->forward(ctx, txt); @@ -897,31 +893,31 @@ namespace Flux { auto block = std::dynamic_pointer_cast(blocks["double_blocks." + std::to_string(i)]); - auto img_txt = block->forward(ctx, backend, img, txt, vec, pe, txt_img_mask); + auto img_txt = block->forward(ctx, img, txt, vec, pe, txt_img_mask); img = img_txt.first; // [N, n_img_token, hidden_size] txt = img_txt.second; // [N, n_txt_token, hidden_size] } - auto txt_img = ggml_concat(ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size] + auto txt_img = ggml_concat(ctx->ggml_ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size] for (int i = 0; i < params.depth_single_blocks; i++) { if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i + params.depth) != skip_layers.end()) { continue; } auto block = std::dynamic_pointer_cast(blocks["single_blocks." 
+ std::to_string(i)]); - txt_img = block->forward(ctx, backend, txt_img, vec, pe, txt_img_mask); + txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask); } - txt_img = ggml_cont(ctx, ggml_permute(ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] - img = ggml_view_3d(ctx, + txt_img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] + img = ggml_view_3d(ctx->ggml_ctx, txt_img, txt_img->ne[0], txt_img->ne[1], img->ne[1], txt_img->nb[1], txt_img->nb[2], - txt_img->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size] - img = ggml_cont(ctx, ggml_permute(ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] + txt_img->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size] + img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] if (final_layer) { img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels) @@ -930,8 +926,7 @@ namespace Flux { return img; } - struct ggml_tensor* forward_chroma_radiance(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward_chroma_radiance(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* timestep, struct ggml_tensor* context, @@ -952,32 +947,32 @@ namespace Flux { int pad_h = (patch_size - H % patch_size) % patch_size; int pad_w = (patch_size - W % patch_size) % patch_size; - auto img = pad_to_patch_size(ctx, x); + auto img = pad_to_patch_size(ctx->ggml_ctx, x); auto orig_img = img; auto img_in_patch = std::dynamic_pointer_cast(blocks["img_in_patch"]); - img = img_in_patch->forward(ctx, img); // [N, hidden_size, H/patch_size, W/patch_size] - img = ggml_reshape_3d(ctx, img, img->ne[0] * img->ne[1], img->ne[2], img->ne[3]); // [N, hidden_size, H/patch_size*W/patch_size] - img = ggml_cont(ctx, ggml_ext_torch_permute(ctx, img, 1, 0, 2, 3)); // [N, H/patch_size*W/patch_size, hidden_size] + img = img_in_patch->forward(ctx, img); // [N, hidden_size, H/patch_size, W/patch_size] + img = ggml_reshape_3d(ctx->ggml_ctx, img, img->ne[0] * img->ne[1], img->ne[2], img->ne[3]); // [N, hidden_size, H/patch_size*W/patch_size] + img = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, img, 1, 0, 2, 3)); // [N, H/patch_size*W/patch_size, hidden_size] - auto out = forward_orig(ctx, backend, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, n_img_token, hidden_size] + auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, n_img_token, hidden_size] // nerf decode auto nerf_image_embedder = std::dynamic_pointer_cast(blocks["nerf_image_embedder"]); auto nerf_final_layer_conv = std::dynamic_pointer_cast(blocks["nerf_final_layer_conv"]); - auto nerf_pixels = patchify(ctx, orig_img); // [N, num_patches, C * patch_size * patch_size] + auto nerf_pixels = patchify(ctx->ggml_ctx, orig_img); // [N, num_patches, C * patch_size * patch_size] int64_t num_patches = nerf_pixels->ne[1]; - nerf_pixels = ggml_reshape_3d(ctx, + nerf_pixels = ggml_reshape_3d(ctx->ggml_ctx, nerf_pixels, nerf_pixels->ne[0] / C, C, - nerf_pixels->ne[1] * nerf_pixels->ne[2]); // [N*num_patches, C, patch_size*patch_size] - nerf_pixels = ggml_cont(ctx, ggml_ext_torch_permute(ctx, nerf_pixels, 1, 0, 2, 3)); // [N*num_patches, patch_size*patch_size, C] + nerf_pixels->ne[1] * nerf_pixels->ne[2]); // [N*num_patches, C, patch_size*patch_size] + nerf_pixels = ggml_cont(ctx->ggml_ctx, 
ggml_ext_torch_permute(ctx->ggml_ctx, nerf_pixels, 1, 0, 2, 3)); // [N*num_patches, patch_size*patch_size, C] - auto nerf_hidden = ggml_reshape_2d(ctx, out, out->ne[0], out->ne[1] * out->ne[2]); // [N*num_patches, hidden_size] - auto img_dct = nerf_image_embedder->forward(ctx, nerf_pixels, dct); // [N*num_patches, patch_size*patch_size, nerf_hidden_size] + auto nerf_hidden = ggml_reshape_2d(ctx->ggml_ctx, out, out->ne[0], out->ne[1] * out->ne[2]); // [N*num_patches, hidden_size] + auto img_dct = nerf_image_embedder->forward(ctx, nerf_pixels, dct); // [N*num_patches, patch_size*patch_size, nerf_hidden_size] for (int i = 0; i < params.chroma_radiance_params.nerf_depth; i++) { auto block = std::dynamic_pointer_cast(blocks["nerf_blocks." + std::to_string(i)]); @@ -985,17 +980,16 @@ namespace Flux { img_dct = block->forward(ctx, img_dct, nerf_hidden); } - img_dct = ggml_cont(ctx, ggml_ext_torch_permute(ctx, img_dct, 1, 0, 2, 3)); // [N*num_patches, nerf_hidden_size, patch_size*patch_size] - img_dct = ggml_reshape_3d(ctx, img_dct, img_dct->ne[0] * img_dct->ne[1], num_patches, img_dct->ne[2] / num_patches); // [N, num_patches, nerf_hidden_size*patch_size*patch_size] - img_dct = unpatchify(ctx, img_dct, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, nerf_hidden_size, H, W] + img_dct = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, img_dct, 1, 0, 2, 3)); // [N*num_patches, nerf_hidden_size, patch_size*patch_size] + img_dct = ggml_reshape_3d(ctx->ggml_ctx, img_dct, img_dct->ne[0] * img_dct->ne[1], num_patches, img_dct->ne[2] / num_patches); // [N, num_patches, nerf_hidden_size*patch_size*patch_size] + img_dct = unpatchify(ctx->ggml_ctx, img_dct, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, nerf_hidden_size, H, W] out = nerf_final_layer_conv->forward(ctx, img_dct); // [N, C, H, W] return out; } - struct ggml_tensor* forward_flux_chroma(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward_flux_chroma(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* timestep, struct ggml_tensor* context, @@ -1016,58 +1010,57 @@ namespace Flux { int pad_h = (patch_size - H % patch_size) % patch_size; int pad_w = (patch_size - W % patch_size) % patch_size; - auto img = process_img(ctx, x); + auto img = process_img(ctx->ggml_ctx, x); uint64_t img_tokens = img->ne[1]; if (params.version == VERSION_FLUX_FILL) { GGML_ASSERT(c_concat != nullptr); - ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0); - ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C); + ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0); + ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C); - masked = process_img(ctx, masked); - mask = process_img(ctx, mask); + masked = process_img(ctx->ggml_ctx, masked); + mask = process_img(ctx->ggml_ctx, mask); - img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0); + img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, masked, mask, 0), 0); } else if (params.version == VERSION_FLEX_2) { GGML_ASSERT(c_concat != nullptr); - ggml_tensor* masked = ggml_view_4d(ctx, c_concat, 
c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0); - ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C); - ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1)); + ggml_tensor* masked = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0); + ggml_tensor* mask = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C); + ggml_tensor* control = ggml_view_4d(ctx->ggml_ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1)); - masked = process_img(ctx, masked); - mask = process_img(ctx, mask); - control = process_img(ctx, control); + masked = process_img(ctx->ggml_ctx, masked); + mask = process_img(ctx->ggml_ctx, mask); + control = process_img(ctx->ggml_ctx, control); - img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0); + img = ggml_concat(ctx->ggml_ctx, img, ggml_concat(ctx->ggml_ctx, ggml_concat(ctx->ggml_ctx, masked, mask, 0), control, 0), 0); } else if (params.version == VERSION_FLUX_CONTROLS) { GGML_ASSERT(c_concat != nullptr); - auto control = process_img(ctx, c_concat); - img = ggml_concat(ctx, img, control, 0); + auto control = process_img(ctx->ggml_ctx, c_concat); + img = ggml_concat(ctx->ggml_ctx, img, control, 0); } if (ref_latents.size() > 0) { for (ggml_tensor* ref : ref_latents) { - ref = process_img(ctx, ref); - img = ggml_concat(ctx, img, ref, 1); + ref = process_img(ctx->ggml_ctx, ref); + img = ggml_concat(ctx->ggml_ctx, img, ref, 1); } } - auto out = forward_orig(ctx, backend, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size] + auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size] if (out->ne[1] > img_tokens) { - out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size] - out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0); - out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size] + out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size] + out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0); + out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size] } // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2) - out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, C, H + pad_h, W + pad_w] + out = unpatchify(ctx->ggml_ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, C, H + pad_h, W + pad_w] return out; } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* timestep, struct ggml_tensor* context, @@ -1091,7 +1084,6 @@ namespace Flux { if 
(params.version == VERSION_CHROMA_RADIANCE) {
                 return forward_chroma_radiance(ctx,
-                                               backend,
                                                x,
                                                timestep,
                                                context,
@@ -1105,7 +1097,6 @@ namespace Flux {
                                        skip_layers);
             } else {
                 return forward_flux_chroma(ctx,
-                                           backend,
                                            x,
                                            timestep,
                                            context,
@@ -1323,8 +1314,9 @@ namespace Flux {
             set_backend_tensor_data(dct, dct_vec.data());
         }
 
-        struct ggml_tensor* out = flux.forward(compute_ctx,
-                                               runtime_backend,
+        auto runner_ctx = get_context();
+
+        struct ggml_tensor* out = flux.forward(&runner_ctx,
                                                x,
                                                timesteps,
                                                context,
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 625542e00..13c5e2be0 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -1462,6 +1462,11 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
 
 typedef std::map String2GGMLType;
 
+struct GGMLRunnerContext {
+    ggml_backend_t backend = nullptr;
+    ggml_context* ggml_ctx = nullptr;
+};
+
 struct GGMLRunner {
 protected:
     typedef std::function get_graph_cb_t;
@@ -1744,6 +1749,13 @@ struct GGMLRunner {
         free_cache_ctx_and_buffer();
     }
 
+    virtual GGMLRunnerContext get_context() {
+        GGMLRunnerContext runner_ctx;
+        runner_ctx.ggml_ctx = compute_ctx;
+        runner_ctx.backend = runtime_backend;
+        return runner_ctx;
+    }
+
     void reset_compute_ctx() {
         free_compute_ctx();
         alloc_compute_ctx();
     }
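
The runner side is the mirror image of the block-side change: graph-building code asks the GGMLRunner base class for a context once and hands its address down, as the FluxRunner hunk above now does. A sketch of what a hypothetical runner's graph builder looks like under this API (ToyRunner and toy_net are illustrative names, not part of the patch):

    struct ggml_cgraph* ToyRunner::build_graph(struct ggml_tensor* x) {
        struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
        x = to_backend(x);                // copy the input into the runtime backend's buffer
        auto runner_ctx = get_context();  // bundles compute_ctx and runtime_backend
        auto out = toy_net.forward(&runner_ctx, x);
        ggml_build_forward_expand(gf, out);
        return gf;
    }

@@ -1955,12 +1967,12 @@ class GGMLBlock {
 
 class UnaryBlock : public GGMLBlock {
 public:
-    virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) = 0;
+    virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) = 0;
 };
 
 class Identity : public UnaryBlock {
 public:
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
         return x;
     }
 };
@@ -1974,7 +1986,7 @@ class Linear : public UnaryBlock {
     bool force_prec_f32;
     float scale;
 
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
         enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
         if (in_features % ggml_blck_size(wtype) != 0 || force_f32) {
             wtype = GGML_TYPE_F32;
@@ -2000,13 +2012,13 @@ class Linear : public UnaryBlock {
           force_prec_f32(force_prec_f32),
           scale(scale) {}
 
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
         struct ggml_tensor* w = params["weight"];
         struct ggml_tensor* b = nullptr;
         if (bias) {
             b = params["bias"];
         }
-        return ggml_ext_linear(ctx, x, w, b, force_prec_f32, scale);
+        return ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
     }
 };
 
@@ -2022,7 +2034,7 @@ class Embedding : public UnaryBlock {
 protected:
     int64_t embedding_dim;
     int64_t num_embeddings;
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
         enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
         if (!support_get_rows(wtype)) {
             wtype = GGML_TYPE_F32;
@@ -2036,7 +2048,7 @@ class Embedding : public UnaryBlock {
           num_embeddings(num_embeddings) {
     }
 
-    struct ggml_tensor* forward(struct ggml_context* ctx,
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                 struct ggml_tensor* input_ids) {
         // input_ids: [N, n_token]
         auto weight = 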
params["weight"]; @@ -2044,11 +2056,11 @@ class Embedding : public UnaryBlock { // There are issues with ggml batch inference, so we are expanding it here first. // TODO: fix ggml batch inference int64_t n = input_ids->ne[1]; - input_ids = ggml_reshape_1d(ctx, input_ids, input_ids->ne[0] * input_ids->ne[1]); + input_ids = ggml_reshape_1d(ctx->ggml_ctx, input_ids, input_ids->ne[0] * input_ids->ne[1]); - input_ids = ggml_reshape_3d(ctx, input_ids, input_ids->ne[0], 1, input_ids->ne[1]); - auto embedding = ggml_get_rows(ctx, weight, input_ids); - embedding = ggml_reshape_3d(ctx, embedding, embedding->ne[0], embedding->ne[1] / n, n); + input_ids = ggml_reshape_3d(ctx->ggml_ctx, input_ids, input_ids->ne[0], 1, input_ids->ne[1]); + auto embedding = ggml_get_rows(ctx->ggml_ctx, weight, input_ids); + embedding = ggml_reshape_3d(ctx->ggml_ctx, embedding, embedding->ne[0], embedding->ne[1] / n, n); // [N, n_token, embedding_dim] return embedding; @@ -2067,7 +2079,7 @@ class Conv2d : public UnaryBlock { bool direct = false; float scale = 1.f; - void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") { + void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override { enum ggml_type wtype = GGML_TYPE_F16; params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels, out_channels); if (bias) { @@ -2104,13 +2116,13 @@ class Conv2d : public UnaryBlock { return "Conv2d"; } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { struct ggml_tensor* w = params["weight"]; struct ggml_tensor* b = nullptr; if (bias) { b = params["bias"]; } - return ggml_ext_conv_2d(ctx, + return ggml_ext_conv_2d(ctx->ggml_ctx, x, w, b, @@ -2135,7 +2147,7 @@ class Conv3dnx1x1 : public UnaryBlock { int64_t dilation; bool bias; - void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") { + void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override { enum ggml_type wtype = GGML_TYPE_F16; params["weight"] = ggml_new_tensor_4d(ctx, wtype, 1, kernel_size, in_channels, out_channels); // 5d => 4d if (bias) { @@ -2162,13 +2174,13 @@ class Conv3dnx1x1 : public UnaryBlock { // x: [N, IC, ID, IH*IW] // result: [N, OC, OD, OH*OW] - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { struct ggml_tensor* w = params["weight"]; struct ggml_tensor* b = nullptr; if (bias) { b = params["bias"]; } - return ggml_ext_conv_3d_nx1x1(ctx, x, w, b, stride, padding, dilation); + return ggml_ext_conv_3d_nx1x1(ctx->ggml_ctx, x, w, b, stride, padding, dilation); } }; @@ -2182,7 +2194,7 @@ class Conv3d : public UnaryBlock { std::tuple dilation; bool bias; - void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") { + void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override { enum ggml_type wtype = GGML_TYPE_F16; params["weight"] = ggml_new_tensor_4d(ctx, wtype, @@ -2211,13 +2223,13 @@ class Conv3d : public UnaryBlock { dilation(dilation), bias(bias) {} - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* 
forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
         struct ggml_tensor* w = params["weight"];
         struct ggml_tensor* b = nullptr;
         if (bias) {
             b = params["bias"];
         }
-        return ggml_ext_conv_3d(ctx, x, w, b, in_channels,
+        return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels,
                                 std::get<2>(stride), std::get<1>(stride), std::get<0>(stride),
                                 std::get<2>(padding), std::get<1>(padding), std::get<0>(padding),
                                 std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation));
@@ -2231,7 +2243,7 @@ class LayerNorm : public UnaryBlock {
     bool elementwise_affine;
     bool bias;
 
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
         if (elementwise_affine) {
             enum ggml_type wtype = GGML_TYPE_F32;
             params["weight"] = ggml_new_tensor_1d(ctx, wtype, normalized_shape);
@@ -2252,7 +2264,7 @@ class LayerNorm : public UnaryBlock {
           elementwise_affine(elementwise_affine),
           bias(bias) {}
 
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
         struct ggml_tensor* w = nullptr;
         struct ggml_tensor* b = nullptr;
 
@@ -2262,7 +2274,7 @@ class LayerNorm : public UnaryBlock {
                 b = params["bias"];
             }
         }
-        return ggml_ext_layer_norm(ctx, x, w, b, eps);
+        return ggml_ext_layer_norm(ctx->ggml_ctx, x, w, b, eps);
     }
 };
 
@@ -2273,7 +2285,7 @@ class GroupNorm : public GGMLBlock {
     float eps;
     bool affine;
 
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") override {
         if (affine) {
             enum ggml_type wtype = GGML_TYPE_F32;
             enum ggml_type bias_wtype = GGML_TYPE_F32;
@@ -2292,14 +2304,14 @@ class GroupNorm : public GGMLBlock {
           eps(eps),
          affine(affine) {}
 
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
         struct ggml_tensor* w = nullptr;
         struct ggml_tensor* b = nullptr;
         if (affine) {
             w = params["weight"];
             b = params["bias"];
         }
-        return ggml_ext_group_norm(ctx, x, w, b, num_groups);
+        return ggml_ext_group_norm(ctx->ggml_ctx, x, w, b, num_groups);
     }
 };
 
@@ -2314,7 +2326,7 @@ class RMSNorm : public UnaryBlock {
     int64_t hidden_size;
     float eps;
 
-    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
+    void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") override {
         enum ggml_type wtype = GGML_TYPE_F32;
         params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size);
     }
@@ -2325,10 +2337,10 @@ class RMSNorm : public UnaryBlock {
         : hidden_size(hidden_size),
           eps(eps) {}
 
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
         struct ggml_tensor* w = params["weight"];
-        x = ggml_rms_norm(ctx, x, eps);
-        x = ggml_mul_inplace(ctx, x, w);
+        x = ggml_rms_norm(ctx->ggml_ctx, x, eps);
+        x = ggml_mul_inplace(ctx->ggml_ctx, x, w);
         return x;
     }
 };
 
@@ -2364,8 +2376,7 @@ class MultiheadAttention : public GGMLBlock {
     }
 
     // x: [N, n_token, embed_dim]
-    struct ggml_tensor* forward(struct ggml_context* ctx,
-                                ggml_backend_t backend,
+    struct ggml_tensor* 
forward(GGMLRunnerContext* ctx,
                                struct ggml_tensor* x,
                                bool mask = false) {
         auto q_proj = std::dynamic_pointer_cast(blocks[q_proj_name]);
@@ -2377,7 +2388,7 @@ class MultiheadAttention : public GGMLBlock {
         struct ggml_tensor* k = k_proj->forward(ctx, x);
         struct ggml_tensor* v = v_proj->forward(ctx, x);
 
-        x = ggml_ext_attention_ext(ctx, backend, q, k, v, n_head, nullptr, mask); // [N, n_token, embed_dim]
+        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, mask); // [N, n_token, embed_dim]
 
         x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
         return x;
diff --git a/ltxv.hpp b/ltxv.hpp
index fdd190f02..0a2877a86 100644
--- a/ltxv.hpp
+++ b/ltxv.hpp
@@ -27,7 +27,7 @@ namespace LTXV {
                                                               bias));
         }
 
-        struct ggml_tensor* forward(struct ggml_context* ctx,
+        struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                     struct ggml_tensor* x,
                                     bool causal = true) {
             // x: [N*IC, ID, IH, IW]
diff --git a/mmdit.hpp b/mmdit.hpp
index f73c3c57b..2165c1483 100644
--- a/mmdit.hpp
+++ b/mmdit.hpp
@@ -27,13 +27,13 @@ struct Mlp : public GGMLBlock {
         blocks["fc2"] = std::shared_ptr(new Linear(hidden_features, out_features, bias));
     }
 
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
         // x: [N, n_token, in_features]
         auto fc1 = std::dynamic_pointer_cast(blocks["fc1"]);
         auto fc2 = std::dynamic_pointer_cast(blocks["fc2"]);
 
         x = fc1->forward(ctx, x);
-        x = ggml_gelu_inplace(ctx, x);
+        x = ggml_gelu_inplace(ctx->ggml_ctx, x);
         x = fc2->forward(ctx, x);
         return x;
     }
@@ -72,7 +72,7 @@ struct PatchEmbed : public GGMLBlock {
                                                    bias));
     }
 
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
         // x: [N, C, H, W]
         // return: [N, H*W, embed_dim]
         auto proj = std::dynamic_pointer_cast(blocks["proj"]);
@@ -82,13 +82,13 @@ struct PatchEmbed : public GGMLBlock {
             int64_t H = x->ne[1];
             int pad_h = (patch_size - H % patch_size) % patch_size;
             int pad_w = (patch_size - W % patch_size) % patch_size;
-            x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // TODO: reflect pad mode
+            x = ggml_pad(ctx->ggml_ctx, x, pad_w, pad_h, 0, 0); // TODO: reflect pad mode
         }
         x = proj->forward(ctx, x);
         if (flatten) {
-            x = ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
-            x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3));
+            x = ggml_reshape_3d(ctx->ggml_ctx, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
+            x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3));
         }
         return x;
     }
@@ -107,16 +107,16 @@ struct TimestepEmbedder : public GGMLBlock {
         blocks["mlp.2"] = std::shared_ptr(new Linear(hidden_size, hidden_size, true, true));
     }
 
-    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* t) {
+    struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* t) {
         // t: [N, ]
         // return: [N, hidden_size]
         auto mlp_0 = std::dynamic_pointer_cast(blocks["mlp.0"]);
         auto mlp_2 = std::dynamic_pointer_cast(blocks["mlp.2"]);
 
-        auto t_freq = ggml_ext_timestep_embedding(ctx, t, frequency_embedding_size); // [N, frequency_embedding_size]
+        auto t_freq = ggml_ext_timestep_embedding(ctx->ggml_ctx, t, frequency_embedding_size); // [N, frequency_embedding_size]
 
         auto t_emb = mlp_0->forward(ctx, t_freq);
-        t_emb = ggml_silu_inplace(ctx, t_emb);
+        t_emb = ggml_silu_inplace(ctx->ggml_ctx, t_emb);
         t_emb = mlp_2->forward(ctx, t_emb);
         return t_emb;
     }
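
The MultiheadAttention hunk above shows the other half of the design: helpers that pick a backend-specific kernel no longer take a loose ggml_backend_t argument, they read it out of the same context object. The call shape, with q/k/v/n_head/mask as illustrative stand-ins for whatever a block has prepared:

    // The backend handle lets the helper check what the device supports
    // (e.g. a flash-attention path), which is presumably why it now rides
    // inside GGMLRunnerContext instead of threading through every signature.
    x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, mask);

@@ -131,14 +131,14 @@ struct VectorEmbedder : public 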
GGMLBlock { blocks["mlp.2"] = std::shared_ptr(new Linear(hidden_size, hidden_size, true, true)); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { // x: [N, input_dim] // return: [N, hidden_size] auto mlp_0 = std::dynamic_pointer_cast(blocks["mlp.0"]); auto mlp_2 = std::dynamic_pointer_cast(blocks["mlp.2"]); x = mlp_0->forward(ctx, x); - x = ggml_silu_inplace(ctx, x); + x = ggml_silu_inplace(ctx->ggml_ctx, x); x = mlp_2->forward(ctx, x); return x; } @@ -173,15 +173,15 @@ class SelfAttention : public GGMLBlock { } } - std::vector pre_attention(struct ggml_context* ctx, struct ggml_tensor* x) { + std::vector pre_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) { auto qkv_proj = std::dynamic_pointer_cast(blocks["qkv"]); auto qkv = qkv_proj->forward(ctx, x); - auto qkv_vec = split_qkv(ctx, qkv); + auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv); int64_t head_dim = qkv_vec[0]->ne[0] / num_heads; - auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head] - auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head] - auto v = qkv_vec[2]; // [N, n_token, n_head*d_head] + auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head] + auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head] + auto v = qkv_vec[2]; // [N, n_token, n_head*d_head] if (qk_norm == "rms" || qk_norm == "ln") { auto ln_q = std::dynamic_pointer_cast(blocks["ln_q"]); @@ -190,13 +190,13 @@ class SelfAttention : public GGMLBlock { k = ln_k->forward(ctx, k); } - q = ggml_reshape_3d(ctx, q, q->ne[0] * q->ne[1], q->ne[2], q->ne[3]); // [N, n_token, n_head*d_head] - k = ggml_reshape_3d(ctx, k, k->ne[0] * k->ne[1], k->ne[2], k->ne[3]); // [N, n_token, n_head*d_head] + q = ggml_reshape_3d(ctx->ggml_ctx, q, q->ne[0] * q->ne[1], q->ne[2], q->ne[3]); // [N, n_token, n_head*d_head] + k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0] * k->ne[1], k->ne[2], k->ne[3]); // [N, n_token, n_head*d_head] return {q, k, v}; } - struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* post_attention(GGMLRunnerContext* ctx, struct ggml_tensor* x) { GGML_ASSERT(!pre_only); auto proj = std::dynamic_pointer_cast(blocks["proj"]); @@ -206,12 +206,11 @@ class SelfAttention : public GGMLBlock { } // x: [N, n_token, dim] - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { auto qkv = pre_attention(ctx, x); - x = ggml_ext_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, true); // [N, n_token, dim] - x = post_attention(ctx, x); // [N, n_token, dim] + x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, true); // [N, n_token, dim] + x = post_attention(ctx, x); // [N, n_token, dim] return x; } }; @@ -274,9 +273,9 @@ struct DismantledBlock : public GGMLBlock { blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, n_mods * hidden_size)); } - std::tuple, std::vector, std::vector> pre_attention_x(struct ggml_context* ctx, - struct ggml_tensor* x, - struct ggml_tensor* c) 
{ + std::tuple, std::vector, std::vector> pre_attention_x(GGMLRunnerContext* ctx, + struct ggml_tensor* x, + struct ggml_tensor* c) { GGML_ASSERT(self_attn); // x: [N, n_token, hidden_size] // c: [N, hidden_size] @@ -286,35 +285,35 @@ struct DismantledBlock : public GGMLBlock { auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); int64_t n_mods = 9; - auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, n_mods * hidden_size] - m = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size] - m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size] + auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size] + m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size] + m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size] int64_t offset = m->nb[1] * m->ne[1]; - auto shift_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] - auto scale_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] - auto gate_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size] + auto shift_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] + auto scale_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + auto gate_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size] - auto shift_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size] - auto scale_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size] - auto gate_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size] + auto shift_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size] + auto scale_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size] + auto gate_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size] - auto shift_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6); // [N, hidden_size] - auto scale_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7); // [N, hidden_size] - auto gate_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8); // [N, hidden_size] + auto shift_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6); // [N, hidden_size] + auto scale_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7); // [N, hidden_size] + auto gate_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8); // [N, hidden_size] auto x_norm = norm1->forward(ctx, x); - auto attn_in = modulate(ctx, x_norm, shift_msa, scale_msa); + auto attn_in = modulate(ctx->ggml_ctx, x_norm, shift_msa, scale_msa); auto qkv = attn->pre_attention(ctx, attn_in); - auto attn2_in = modulate(ctx, x_norm, shift_msa2, scale_msa2); + auto attn2_in = modulate(ctx->ggml_ctx, x_norm, shift_msa2, scale_msa2); auto qkv2 = attn2->pre_attention(ctx, attn2_in); return {qkv, qkv2, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2}}; } - std::pair, std::vector> pre_attention(struct ggml_context* ctx, + std::pair, std::vector> pre_attention(GGMLRunnerContext* ctx, 
struct ggml_tensor* x, struct ggml_tensor* c) { // x: [N, n_token, hidden_size] @@ -327,33 +326,33 @@ struct DismantledBlock : public GGMLBlock { if (pre_only) { n_mods = 2; } - auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, n_mods * hidden_size] - m = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size] - m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size] + auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size] + m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size] + m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size] int64_t offset = m->nb[1] * m->ne[1]; - auto shift_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] - auto scale_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + auto shift_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] + auto scale_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] if (!pre_only) { - auto gate_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size] - auto shift_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size] - auto scale_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size] - auto gate_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size] + auto gate_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size] + auto shift_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size] + auto scale_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size] + auto gate_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size] - auto attn_in = modulate(ctx, norm1->forward(ctx, x), shift_msa, scale_msa); + auto attn_in = modulate(ctx->ggml_ctx, norm1->forward(ctx, x), shift_msa, scale_msa); auto qkv = attn->pre_attention(ctx, attn_in); return {qkv, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp}}; } else { - auto attn_in = modulate(ctx, norm1->forward(ctx, x), shift_msa, scale_msa); + auto attn_in = modulate(ctx->ggml_ctx, norm1->forward(ctx, x), shift_msa, scale_msa); auto qkv = attn->pre_attention(ctx, attn_in); return {qkv, {nullptr, nullptr, nullptr, nullptr, nullptr}}; } } - struct ggml_tensor* post_attention_x(struct ggml_context* ctx, + struct ggml_tensor* post_attention_x(GGMLRunnerContext* ctx, struct ggml_tensor* attn_out, struct ggml_tensor* attn2_out, struct ggml_tensor* x, @@ -376,22 +375,22 @@ struct DismantledBlock : public GGMLBlock { auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); - gate_msa = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]); // [N, 1, hidden_size] - gate_mlp = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]); // [N, 1, hidden_size] - gate_msa2 = ggml_reshape_3d(ctx, gate_msa2, gate_msa2->ne[0], 1, gate_msa2->ne[1]); // [N, 1, hidden_size] + gate_msa = ggml_reshape_3d(ctx->ggml_ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]); // [N, 1, hidden_size] + gate_mlp = ggml_reshape_3d(ctx->ggml_ctx, gate_mlp, gate_mlp->ne[0], 1, 
gate_mlp->ne[1]); // [N, 1, hidden_size] + gate_msa2 = ggml_reshape_3d(ctx->ggml_ctx, gate_msa2, gate_msa2->ne[0], 1, gate_msa2->ne[1]); // [N, 1, hidden_size] attn_out = attn->post_attention(ctx, attn_out); attn2_out = attn2->post_attention(ctx, attn2_out); - x = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa)); - x = ggml_add(ctx, x, ggml_mul(ctx, attn2_out, gate_msa2)); - auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp)); - x = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp)); + x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, attn_out, gate_msa)); + x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, attn2_out, gate_msa2)); + auto mlp_out = mlp->forward(ctx, modulate(ctx->ggml_ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp)); + x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, mlp_out, gate_mlp)); return x; } - struct ggml_tensor* post_attention(struct ggml_context* ctx, + struct ggml_tensor* post_attention(GGMLRunnerContext* ctx, struct ggml_tensor* attn_out, struct ggml_tensor* x, struct ggml_tensor* gate_msa, @@ -411,20 +410,19 @@ struct DismantledBlock : public GGMLBlock { auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); - gate_msa = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]); // [N, 1, hidden_size] - gate_mlp = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]); // [N, 1, hidden_size] + gate_msa = ggml_reshape_3d(ctx->ggml_ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]); // [N, 1, hidden_size] + gate_mlp = ggml_reshape_3d(ctx->ggml_ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]); // [N, 1, hidden_size] attn_out = attn->post_attention(ctx, attn_out); - x = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa)); - auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp)); - x = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp)); + x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, attn_out, gate_msa)); + auto mlp_out = mlp->forward(ctx, modulate(ctx->ggml_ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp)); + x = ggml_add(ctx->ggml_ctx, x, ggml_mul(ctx->ggml_ctx, mlp_out, gate_mlp)); return x; } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* c) { // x: [N, n_token, hidden_size] @@ -441,8 +439,8 @@ struct DismantledBlock : public GGMLBlock { auto qkv2 = std::get<1>(qkv_intermediates); auto intermediates = std::get<2>(qkv_intermediates); - auto attn_out = ggml_ext_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim] - auto attn2_out = ggml_ext_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim] + auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim] + auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim] x = post_attention_x(ctx, attn_out, attn2_out, @@ -458,7 +456,7 @@ struct DismantledBlock : public GGMLBlock { auto qkv = qkv_intermediates.first; auto intermediates = qkv_intermediates.second; - auto attn_out = ggml_ext_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, 
false, flash_attn); // [N, n_token, dim] + auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim] x = post_attention(ctx, attn_out, intermediates[0], @@ -472,8 +470,7 @@ struct DismantledBlock : public GGMLBlock { }; __STATIC_INLINE__ std::pair -block_mixing(struct ggml_context* ctx, - ggml_backend_t backend, +block_mixing(GGMLRunnerContext* ctx, bool flash_attn, struct ggml_tensor* context, struct ggml_tensor* x, @@ -501,29 +498,29 @@ block_mixing(struct ggml_context* ctx, } std::vector qkv; for (int i = 0; i < 3; i++) { - qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1)); + qkv.push_back(ggml_concat(ctx->ggml_ctx, context_qkv[i], x_qkv[i], 1)); } - auto attn = ggml_ext_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, flash_attn); // [N, n_context + n_token, hidden_size] - attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size] - auto context_attn = ggml_view_3d(ctx, + auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, flash_attn); // [N, n_context + n_token, hidden_size] + attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size] + auto context_attn = ggml_view_3d(ctx->ggml_ctx, attn, attn->ne[0], attn->ne[1], context->ne[1], attn->nb[1], attn->nb[2], - 0); // [n_context, N, hidden_size] - context_attn = ggml_cont(ctx, ggml_permute(ctx, context_attn, 0, 2, 1, 3)); // [N, n_context, hidden_size] - auto x_attn = ggml_view_3d(ctx, + 0); // [n_context, N, hidden_size] + context_attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, context_attn, 0, 2, 1, 3)); // [N, n_context, hidden_size] + auto x_attn = ggml_view_3d(ctx->ggml_ctx, attn, attn->ne[0], attn->ne[1], x->ne[1], attn->nb[1], attn->nb[2], - attn->nb[2] * context->ne[1]); // [n_token, N, hidden_size] - x_attn = ggml_cont(ctx, ggml_permute(ctx, x_attn, 0, 2, 1, 3)); // [N, n_token, hidden_size] + attn->nb[2] * context->ne[1]); // [n_token, N, hidden_size] + x_attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x_attn, 0, 2, 1, 3)); // [N, n_token, hidden_size] if (!context_block->pre_only) { context = context_block->post_attention(ctx, @@ -538,7 +535,7 @@ block_mixing(struct ggml_context* ctx, } if (x_block->self_attn) { - auto attn2 = ggml_ext_attention_ext(ctx, backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads); // [N, n_token, hidden_size] + auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads); // [N, n_token, hidden_size] x = x_block->post_attention_x(ctx, x_attn, @@ -579,15 +576,14 @@ struct JointBlock : public GGMLBlock { blocks["x_block"] = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x, flash_attn)); } - std::pair forward(struct ggml_context* ctx, - ggml_backend_t backend, + std::pair forward(GGMLRunnerContext* ctx, struct ggml_tensor* context, struct ggml_tensor* x, struct ggml_tensor* c) { auto context_block = std::dynamic_pointer_cast(blocks["context_block"]); auto x_block = std::dynamic_pointer_cast(blocks["x_block"]); - return block_mixing(ctx, backend, flash_attn, context, x, c, context_block, x_block); + return block_mixing(ctx, flash_attn, context, x, c, context_block, x_block); } }; @@ -603,7 +599,7 @@ struct FinalLayer : 
public GGMLBlock { blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, 2 * hidden_size)); } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* c) { // x: [N, n_token, hidden_size] @@ -613,15 +609,15 @@ struct FinalLayer : public GGMLBlock { auto linear = std::dynamic_pointer_cast(blocks["linear"]); auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); - auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size] - m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size] - m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size] + auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size] + m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size] + m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size] int64_t offset = m->nb[1] * m->ne[1]; - auto shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] - auto scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + auto shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] + auto scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] - x = modulate(ctx, norm_final->forward(ctx, x), shift, scale); + x = modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale); x = linear->forward(ctx, x); return x; @@ -791,8 +787,7 @@ struct MMDiT : public GGMLBlock { return x; } - struct ggml_tensor* forward_core_with_concat(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward_core_with_concat(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* c_mod, struct ggml_tensor* context, @@ -811,7 +806,7 @@ struct MMDiT : public GGMLBlock { auto block = std::dynamic_pointer_cast(blocks["joint_blocks." 
+ std::to_string(i)]); - auto context_x = block->forward(ctx, backend, context, x, c_mod); + auto context_x = block->forward(ctx, context, x, c_mod); context = context_x.first; x = context_x.second; } @@ -821,8 +816,7 @@ struct MMDiT : public GGMLBlock { return x; } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* t, struct ggml_tensor* y = nullptr, @@ -840,16 +834,16 @@ struct MMDiT : public GGMLBlock { int64_t w = x->ne[0]; int64_t h = x->ne[1]; - auto patch_embed = x_embedder->forward(ctx, x); // [N, H*W, hidden_size] - auto pos_embed = cropped_pos_embed(ctx, h, w); // [1, H*W, hidden_size] - x = ggml_add(ctx, patch_embed, pos_embed); // [N, H*W, hidden_size] + auto patch_embed = x_embedder->forward(ctx, x); // [N, H*W, hidden_size] + auto pos_embed = cropped_pos_embed(ctx->ggml_ctx, h, w); // [1, H*W, hidden_size] + x = ggml_add(ctx->ggml_ctx, patch_embed, pos_embed); // [N, H*W, hidden_size] auto c = t_embedder->forward(ctx, t); // [N, hidden_size] if (y != nullptr && adm_in_channels != -1) { auto y_embedder = std::dynamic_pointer_cast(blocks["y_embedder"]); y = y_embedder->forward(ctx, y); // [N, hidden_size] - c = ggml_add(ctx, c, y); + c = ggml_add(ctx->ggml_ctx, c, y); } if (context != nullptr) { @@ -858,9 +852,9 @@ struct MMDiT : public GGMLBlock { context = context_embedder->forward(ctx, context); // [N, L, D] aka [N, L, 1536] } - x = forward_core_with_concat(ctx, backend, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels) + x = forward_core_with_concat(ctx, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels) - x = unpatchify(ctx, x, h, w); // [N, C, H, W] + x = unpatchify(ctx->ggml_ctx, x, h, w); // [N, C, H, W] return x; } @@ -897,8 +891,8 @@ struct MMDiTRunner : public GGMLRunner { y = to_backend(y); timesteps = to_backend(timesteps); - struct ggml_tensor* out = mmdit.forward(compute_ctx, - runtime_backend, + auto runner_ctx = get_context(); + struct ggml_tensor* out = mmdit.forward(&runner_ctx, x, timesteps, y, diff --git a/pmid.hpp b/pmid.hpp index 3d737bc02..ea7c3989d 100644 --- a/pmid.hpp +++ b/pmid.hpp @@ -21,7 +21,7 @@ struct FuseBlock : public GGMLBlock { blocks["layernorm"] = std::shared_ptr(new LayerNorm(in_dim)); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { // x: [N, channels, h, w] auto fc1 = std::dynamic_pointer_cast(blocks["fc1"]); @@ -33,11 +33,11 @@ struct FuseBlock : public GGMLBlock { x = layer_norm->forward(ctx, x); // x = ggml_add(ctx, ggml_mul_mat(ctx, fc1_w, x), fc1_b); x = fc1->forward(ctx, x); - x = ggml_gelu_inplace(ctx, x); + x = ggml_gelu_inplace(ctx->ggml_ctx, x); x = fc2->forward(ctx, x); // x = ggml_add(ctx, ggml_mul_mat(ctx, fc2_w, x), fc2_b); if (use_residue) - x = ggml_add(ctx, x, r); + x = ggml_add(ctx->ggml_ctx, x, r); return x; } }; @@ -54,7 +54,7 @@ struct PMFeedForward : public GGMLBlock { blocks["1"] = std::shared_ptr(new Mlp(dim, inner_dim, dim, false)); } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { auto norm = std::dynamic_pointer_cast(blocks["0"]); auto ff = std::dynamic_pointer_cast(blocks["1"]); @@ -100,7 +100,7 @@ struct PerceiverAttention : public GGMLBlock { ggml_cont(ctx, tli)}; } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct 
ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* latents) { // x (torch.Tensor): image features @@ -118,33 +118,33 @@ struct PerceiverAttention : public GGMLBlock { auto to_q = std::dynamic_pointer_cast(blocks["to_q"]); auto q = to_q->forward(ctx, latents); - auto kv_input = ggml_concat(ctx, x, latents, 1); + auto kv_input = ggml_concat(ctx->ggml_ctx, x, latents, 1); auto to_kv = std::dynamic_pointer_cast(blocks["to_kv"]); auto kv = to_kv->forward(ctx, kv_input); - auto k = ggml_view_4d(ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, 0); - auto v = ggml_view_4d(ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, kv->nb[0] * (kv->ne[0] / 2)); - k = ggml_cont(ctx, k); - v = ggml_cont(ctx, v); - q = reshape_tensor(ctx, q, heads); - k = reshape_tensor(ctx, k, heads); - v = reshape_tensor(ctx, v, heads); + auto k = ggml_view_4d(ctx->ggml_ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, 0); + auto v = ggml_view_4d(ctx->ggml_ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, kv->nb[0] * (kv->ne[0] / 2)); + k = ggml_cont(ctx->ggml_ctx, k); + v = ggml_cont(ctx->ggml_ctx, v); + q = reshape_tensor(ctx->ggml_ctx, q, heads); + k = reshape_tensor(ctx->ggml_ctx, k, heads); + v = reshape_tensor(ctx->ggml_ctx, v, heads); scale = 1.f / sqrt(sqrt((float)dim_head)); - k = ggml_scale_inplace(ctx, k, scale); - q = ggml_scale_inplace(ctx, q, scale); + k = ggml_scale_inplace(ctx->ggml_ctx, k, scale); + q = ggml_scale_inplace(ctx->ggml_ctx, q, scale); // auto weight = ggml_mul_mat(ctx, q, k); - auto weight = ggml_mul_mat(ctx, k, q); // NOTE order of mul is opposite to pytorch + auto weight = ggml_mul_mat(ctx->ggml_ctx, k, q); // NOTE order of mul is opposite to pytorch // GGML's softmax() is equivalent to pytorch's softmax(x, dim=-1) // in this case, dimension along which Softmax will be computed is the last dim // in torch and the first dim in GGML, consistent with the convention that pytorch's // last dimension (varying most rapidly) corresponds to GGML's first (varying most rapidly). // weight = ggml_soft_max(ctx, weight); - weight = ggml_soft_max_inplace(ctx, weight); - v = ggml_cont(ctx, ggml_transpose(ctx, v)); + weight = ggml_soft_max_inplace(ctx->ggml_ctx, weight); + v = ggml_cont(ctx->ggml_ctx, ggml_transpose(ctx->ggml_ctx, v)); // auto out = ggml_mul_mat(ctx, weight, v); - auto out = ggml_mul_mat(ctx, v, weight); // NOTE order of mul is opposite to pytorch - out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); - out = ggml_reshape_3d(ctx, out, ne[0], ne[1], ggml_nelements(out) / (ne[0] * ne[1])); + auto out = ggml_mul_mat(ctx->ggml_ctx, v, weight); // NOTE order of mul is opposite to pytorch + out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); + out = ggml_reshape_3d(ctx->ggml_ctx, out, ne[0], ne[1], ggml_nelements(out) / (ne[0] * ne[1])); auto to_out = std::dynamic_pointer_cast(blocks["to_out"]); out = to_out->forward(ctx, out); return out; @@ -176,7 +176,7 @@ struct FacePerceiverResampler : public GGMLBlock { } } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* latents, struct ggml_tensor* x) { // x: [N, channels, h, w] @@ -191,9 +191,9 @@ struct FacePerceiverResampler : public GGMLBlock { name = "layers." 
+ std::to_string(i) + ".1"; auto ff = std::dynamic_pointer_cast(blocks[name]); auto t = attn->forward(ctx, x, latents); - latents = ggml_add(ctx, t, latents); + latents = ggml_add(ctx->ggml_ctx, t, latents); t = ff->forward(ctx, latents); - latents = ggml_add(ctx, t, latents); + latents = ggml_add(ctx->ggml_ctx, t, latents); } latents = proj_out->forward(ctx, latents); latents = norm_out->forward(ctx, latents); @@ -225,7 +225,7 @@ struct QFormerPerceiver : public GGMLBlock { 4)); } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* last_hidden_state) { // x: [N, channels, h, w] @@ -235,11 +235,11 @@ struct QFormerPerceiver : public GGMLBlock { x = token_proj->forward(ctx, x); int64_t nel = ggml_nelements(x); - x = ggml_reshape_3d(ctx, x, cross_attention_dim, num_tokens, nel / (cross_attention_dim * num_tokens)); + x = ggml_reshape_3d(ctx->ggml_ctx, x, cross_attention_dim, num_tokens, nel / (cross_attention_dim * num_tokens)); x = token_norm->forward(ctx, x); struct ggml_tensor* out = perceiver_resampler->forward(ctx, x, last_hidden_state); if (use_residul) - out = ggml_add(ctx, x, out); + out = ggml_add(ctx->ggml_ctx, x, out); return out; } }; @@ -256,24 +256,24 @@ struct FuseModule : public GGMLBlock { blocks["layer_norm"] = std::shared_ptr(new LayerNorm(embed_dim)); } - struct ggml_tensor* fuse_fn(struct ggml_context* ctx, + struct ggml_tensor* fuse_fn(GGMLRunnerContext* ctx, struct ggml_tensor* prompt_embeds, struct ggml_tensor* id_embeds) { auto mlp1 = std::dynamic_pointer_cast(blocks["mlp1"]); auto mlp2 = std::dynamic_pointer_cast(blocks["mlp2"]); auto layer_norm = std::dynamic_pointer_cast(blocks["layer_norm"]); - auto stacked_id_embeds = ggml_concat(ctx, prompt_embeds, id_embeds, 0); + auto stacked_id_embeds = ggml_concat(ctx->ggml_ctx, prompt_embeds, id_embeds, 0); stacked_id_embeds = mlp1->forward(ctx, stacked_id_embeds); - stacked_id_embeds = ggml_add(ctx, stacked_id_embeds, prompt_embeds); + stacked_id_embeds = ggml_add(ctx->ggml_ctx, stacked_id_embeds, prompt_embeds); stacked_id_embeds = mlp2->forward(ctx, stacked_id_embeds); stacked_id_embeds = layer_norm->forward(ctx, stacked_id_embeds); return stacked_id_embeds; } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* prompt_embeds, struct ggml_tensor* id_embeds, struct ggml_tensor* class_tokens_mask, @@ -286,25 +286,25 @@ struct FuseModule : public GGMLBlock { // # slice out the image token embeddings ggml_set_name(class_tokens_mask_pos, "class_tokens_mask_pos"); ggml_set_name(prompt_embeds, "prompt_embeds"); - struct ggml_tensor* image_token_embeds = ggml_get_rows(ctx, prompt_embeds, class_tokens_mask_pos); + struct ggml_tensor* image_token_embeds = ggml_get_rows(ctx->ggml_ctx, prompt_embeds, class_tokens_mask_pos); ggml_set_name(image_token_embeds, "image_token_embeds"); - valid_id_embeds = ggml_reshape_2d(ctx, valid_id_embeds, valid_id_embeds->ne[0], + valid_id_embeds = ggml_reshape_2d(ctx->ggml_ctx, valid_id_embeds, valid_id_embeds->ne[0], ggml_nelements(valid_id_embeds) / valid_id_embeds->ne[0]); struct ggml_tensor* stacked_id_embeds = fuse_fn(ctx, image_token_embeds, valid_id_embeds); if (left && right) { - stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 1); - stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 1); + stacked_id_embeds = ggml_concat(ctx->ggml_ctx, left, stacked_id_embeds, 1); + stacked_id_embeds = 
ggml_concat(ctx->ggml_ctx, stacked_id_embeds, right, 1); } else if (left) { - stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 1); + stacked_id_embeds = ggml_concat(ctx->ggml_ctx, left, stacked_id_embeds, 1); } else if (right) { - stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 1); + stacked_id_embeds = ggml_concat(ctx->ggml_ctx, stacked_id_embeds, right, 1); } - class_tokens_mask = ggml_cont(ctx, ggml_transpose(ctx, class_tokens_mask)); - class_tokens_mask = ggml_repeat(ctx, class_tokens_mask, prompt_embeds); - prompt_embeds = ggml_mul(ctx, prompt_embeds, class_tokens_mask); - struct ggml_tensor* updated_prompt_embeds = ggml_add(ctx, prompt_embeds, stacked_id_embeds); + class_tokens_mask = ggml_cont(ctx->ggml_ctx, ggml_transpose(ctx->ggml_ctx, class_tokens_mask)); + class_tokens_mask = ggml_repeat(ctx->ggml_ctx, class_tokens_mask, prompt_embeds); + prompt_embeds = ggml_mul(ctx->ggml_ctx, prompt_embeds, class_tokens_mask); + struct ggml_tensor* updated_prompt_embeds = ggml_add(ctx->ggml_ctx, prompt_embeds, stacked_id_embeds); ggml_set_name(updated_prompt_embeds, "updated_prompt_embeds"); return updated_prompt_embeds; } @@ -317,8 +317,7 @@ struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection { blocks["fuse_module"] = std::shared_ptr(new FuseModule(2048)); } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* id_pixel_values, struct ggml_tensor* prompt_embeds, struct ggml_tensor* class_tokens_mask, @@ -331,15 +330,15 @@ struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection { auto visual_projection_2 = std::dynamic_pointer_cast(blocks["visual_projection_2"]); auto fuse_module = std::dynamic_pointer_cast(blocks["fuse_module"]); - struct ggml_tensor* shared_id_embeds = vision_model->forward(ctx, backend, id_pixel_values); // [N, hidden_size] - struct ggml_tensor* id_embeds = visual_projection->forward(ctx, shared_id_embeds); // [N, proj_dim(768)] - struct ggml_tensor* id_embeds_2 = visual_projection_2->forward(ctx, shared_id_embeds); // [N, 1280] + struct ggml_tensor* shared_id_embeds = vision_model->forward(ctx, id_pixel_values); // [N, hidden_size] + struct ggml_tensor* id_embeds = visual_projection->forward(ctx, shared_id_embeds); // [N, proj_dim(768)] + struct ggml_tensor* id_embeds_2 = visual_projection_2->forward(ctx, shared_id_embeds); // [N, 1280] - id_embeds = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 2, 0, 1, 3)); - id_embeds_2 = ggml_cont(ctx, ggml_permute(ctx, id_embeds_2, 2, 0, 1, 3)); + id_embeds = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, id_embeds, 2, 0, 1, 3)); + id_embeds_2 = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, id_embeds_2, 2, 0, 1, 3)); - id_embeds = ggml_concat(ctx, id_embeds, id_embeds_2, 2); // [batch_size, seq_length, 1, 2048] check whether concat at dim 2 is right - id_embeds = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 1, 2, 0, 3)); + id_embeds = ggml_concat(ctx->ggml_ctx, id_embeds, id_embeds_2, 2); // [batch_size, seq_length, 1, 2048] check whether concat at dim 2 is right + id_embeds = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, id_embeds, 1, 2, 0, 3)); struct ggml_tensor* updated_prompt_embeds = fuse_module->forward(ctx, prompt_embeds, @@ -366,8 +365,7 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo num_tokens)); } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* 
forward(GGMLRunnerContext* ctx, struct ggml_tensor* id_pixel_values, struct ggml_tensor* prompt_embeds, struct ggml_tensor* class_tokens_mask, @@ -381,7 +379,7 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo auto qformer_perceiver = std::dynamic_pointer_cast(blocks["qformer_perceiver"]); // struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values); // [N, hidden_size] - struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, backend, id_pixel_values, false); // [N, hidden_size] + struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values, false); // [N, hidden_size] id_embeds = qformer_perceiver->forward(ctx, id_embeds, last_hidden_state); struct ggml_tensor* updated_prompt_embeds = fuse_module->forward(ctx, @@ -458,7 +456,7 @@ struct PhotoMakerIDEncoder : public GGMLRunner { zeros_right.clear(); zeros_right_16.clear(); - ggml_context* ctx0 = compute_ctx; + auto runner_ctx = get_context(); struct ggml_cgraph* gf = ggml_new_graph(compute_ctx); @@ -466,7 +464,7 @@ struct PhotoMakerIDEncoder : public GGMLRunner { int64_t seq_length = prompt_embeds->ne[1]; ggml_type type = GGML_TYPE_F32; - struct ggml_tensor* class_tokens_mask_d = ggml_new_tensor_1d(ctx0, type, class_tokens_mask.size()); + struct ggml_tensor* class_tokens_mask_d = ggml_new_tensor_1d(runner_ctx.ggml_ctx, type, class_tokens_mask.size()); struct ggml_tensor* id_pixel_values_d = to_backend(id_pixel_values); struct ggml_tensor* prompt_embeds_d = to_backend(prompt_embeds); @@ -488,16 +486,16 @@ struct PhotoMakerIDEncoder : public GGMLRunner { } // printf("\n"); if (ctmpos[0] > 0) { - // left = ggml_new_tensor_3d(ctx0, type, hidden_size, 1, ctmpos[0]); - left = ggml_new_tensor_3d(ctx0, type, hidden_size, ctmpos[0], 1); + // left = ggml_new_tensor_3d(runner_ctx.ggml_ctx, type, hidden_size, 1, ctmpos[0]); + left = ggml_new_tensor_3d(runner_ctx.ggml_ctx, type, hidden_size, ctmpos[0], 1); } if (ctmpos[ctmpos.size() - 1] < seq_length - 1) { - // right = ggml_new_tensor_3d(ctx0, type, + // right = ggml_new_tensor_3d(runner_ctx.ggml_ctx, type, // hidden_size, 1, seq_length - ctmpos[ctmpos.size() - 1] - 1); - right = ggml_new_tensor_3d(ctx0, type, + right = ggml_new_tensor_3d(runner_ctx.ggml_ctx, type, hidden_size, seq_length - ctmpos[ctmpos.size() - 1] - 1, 1); } - struct ggml_tensor* class_tokens_mask_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ctmpos.size()); + struct ggml_tensor* class_tokens_mask_pos = ggml_new_tensor_1d(runner_ctx.ggml_ctx, GGML_TYPE_I32, ctmpos.size()); { if (type == GGML_TYPE_F16) @@ -530,16 +528,14 @@ struct PhotoMakerIDEncoder : public GGMLRunner { } struct ggml_tensor* updated_prompt_embeds = nullptr; if (pm_version == PM_VERSION_1) - updated_prompt_embeds = id_encoder.forward(ctx0, - runtime_backend, + updated_prompt_embeds = id_encoder.forward(&runner_ctx, id_pixel_values_d, prompt_embeds_d, class_tokens_mask_d, class_tokens_mask_pos, left, right); else if (pm_version == PM_VERSION_2) - updated_prompt_embeds = id_encoder2.forward(ctx0, - runtime_backend, + updated_prompt_embeds = id_encoder2.forward(&runner_ctx, id_pixel_values_d, prompt_embeds_d, class_tokens_mask_d, diff --git a/qwen_image.hpp b/qwen_image.hpp index 94248cea9..e442cfda4 100644 --- a/qwen_image.hpp +++ b/qwen_image.hpp @@ -27,18 +27,18 @@ namespace Qwen { blocks["linear_2"] = std::shared_ptr(new Linear(time_embed_dim, out_dim, sample_proj_bias)); } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* 
forward(GGMLRunnerContext* ctx, struct ggml_tensor* sample, struct ggml_tensor* condition = nullptr) { if (condition != nullptr) { auto cond_proj = std::dynamic_pointer_cast(blocks["cond_proj"]); - sample = ggml_add(ctx, sample, cond_proj->forward(ctx, condition)); + sample = ggml_add(ctx->ggml_ctx, sample, cond_proj->forward(ctx, condition)); } auto linear_1 = std::dynamic_pointer_cast(blocks["linear_1"]); auto linear_2 = std::dynamic_pointer_cast(blocks["linear_2"]); sample = linear_1->forward(ctx, sample); - sample = ggml_silu_inplace(ctx, sample); + sample = ggml_silu_inplace(ctx->ggml_ctx, sample); sample = linear_2->forward(ctx, sample); return sample; } @@ -50,13 +50,13 @@ namespace Qwen { blocks["timestep_embedder"] = std::shared_ptr(new TimestepEmbedding(256, embedding_dim)); } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* timesteps) { // timesteps: [N,] // return: [N, embedding_dim] auto timestep_embedder = std::dynamic_pointer_cast(blocks["timestep_embedder"]); - auto timesteps_proj = ggml_ext_timestep_embedding(ctx, timesteps, 256, 10000, 1.f); + auto timesteps_proj = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, 256, 10000, 1.f); auto timesteps_emb = timestep_embedder->forward(ctx, timesteps_proj); return timesteps_emb; } @@ -105,8 +105,7 @@ namespace Qwen { blocks["to_add_out"] = std::shared_ptr(new Linear(inner_dim, out_context_dim, out_bias, false, false, scale)); } - std::pair forward(struct ggml_context* ctx, - ggml_backend_t backend, + std::pair forward(GGMLRunnerContext* ctx, struct ggml_tensor* img, struct ggml_tensor* txt, struct ggml_tensor* pe, @@ -138,49 +137,49 @@ namespace Qwen { auto img_q = to_q->forward(ctx, img); int64_t num_heads = img_q->ne[0] / dim_head; - img_q = ggml_reshape_4d(ctx, img_q, dim_head, num_heads, n_img_token, N); // [N, n_img_token, n_head, d_head] + img_q = ggml_reshape_4d(ctx->ggml_ctx, img_q, dim_head, num_heads, n_img_token, N); // [N, n_img_token, n_head, d_head] auto img_k = to_k->forward(ctx, img); - img_k = ggml_reshape_4d(ctx, img_k, dim_head, num_heads, n_img_token, N); // [N, n_img_token, n_head, d_head] + img_k = ggml_reshape_4d(ctx->ggml_ctx, img_k, dim_head, num_heads, n_img_token, N); // [N, n_img_token, n_head, d_head] auto img_v = to_v->forward(ctx, img); - img_v = ggml_reshape_4d(ctx, img_v, dim_head, num_heads, n_img_token, N); // [N, n_img_token, n_head, d_head] + img_v = ggml_reshape_4d(ctx->ggml_ctx, img_v, dim_head, num_heads, n_img_token, N); // [N, n_img_token, n_head, d_head] img_q = norm_q->forward(ctx, img_q); img_k = norm_k->forward(ctx, img_k); auto txt_q = add_q_proj->forward(ctx, txt); - txt_q = ggml_reshape_4d(ctx, txt_q, dim_head, num_heads, n_txt_token, N); // [N, n_txt_token, n_head, d_head] + txt_q = ggml_reshape_4d(ctx->ggml_ctx, txt_q, dim_head, num_heads, n_txt_token, N); // [N, n_txt_token, n_head, d_head] auto txt_k = add_k_proj->forward(ctx, txt); - txt_k = ggml_reshape_4d(ctx, txt_k, dim_head, num_heads, n_txt_token, N); // [N, n_txt_token, n_head, d_head] + txt_k = ggml_reshape_4d(ctx->ggml_ctx, txt_k, dim_head, num_heads, n_txt_token, N); // [N, n_txt_token, n_head, d_head] auto txt_v = add_v_proj->forward(ctx, txt); - txt_v = ggml_reshape_4d(ctx, txt_v, dim_head, num_heads, n_txt_token, N); // [N, n_txt_token, n_head, d_head] + txt_v = ggml_reshape_4d(ctx->ggml_ctx, txt_v, dim_head, num_heads, n_txt_token, N); // [N, n_txt_token, n_head, d_head] txt_q = norm_added_q->forward(ctx, txt_q); txt_k = 
norm_added_k->forward(ctx, txt_k); - auto q = ggml_concat(ctx, txt_q, img_q, 2); // [N, n_txt_token + n_img_token, n_head, d_head] - auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] - auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] + auto q = ggml_concat(ctx->ggml_ctx, txt_q, img_q, 2); // [N, n_txt_token + n_img_token, n_head, d_head] + auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] + auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] - auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head] - attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] - auto txt_attn_out = ggml_view_3d(ctx, + auto attn = Rope::attention(ctx, q, k, v, pe, mask, flash_attn, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head] + attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] + auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx, attn, attn->ne[0], attn->ne[1], txt->ne[1], attn->nb[1], attn->nb[2], - 0); // [n_txt_token, N, hidden_size] - txt_attn_out = ggml_cont(ctx, ggml_permute(ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size] - auto img_attn_out = ggml_view_3d(ctx, + 0); // [n_txt_token, N, hidden_size] + txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size] + auto img_attn_out = ggml_view_3d(ctx->ggml_ctx, attn, attn->ne[0], attn->ne[1], img->ne[1], attn->nb[1], attn->nb[2], - attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size] - img_attn_out = ggml_cont(ctx, ggml_permute(ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] + attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size] + img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] img_attn_out = to_out_0->forward(ctx, img_attn_out); txt_attn_out = to_add_out->forward(ctx, txt_attn_out); @@ -221,8 +220,7 @@ namespace Qwen { flash_attn)); } - virtual std::pair forward(struct ggml_context* ctx, - ggml_backend_t backend, + virtual std::pair forward(GGMLRunnerContext* ctx, struct ggml_tensor* img, struct ggml_tensor* txt, struct ggml_tensor* t_emb, @@ -244,40 +242,40 @@ namespace Qwen { auto attn = std::dynamic_pointer_cast(blocks["attn"]); - auto img_mod_params = ggml_silu(ctx, t_emb); + auto img_mod_params = ggml_silu(ctx->ggml_ctx, t_emb); img_mod_params = img_mod_1->forward(ctx, img_mod_params); - auto img_mod_param_vec = ggml_ext_chunk(ctx, img_mod_params, 6, 0); + auto img_mod_param_vec = ggml_ext_chunk(ctx->ggml_ctx, img_mod_params, 6, 0); - auto txt_mod_params = ggml_silu(ctx, t_emb); + auto txt_mod_params = ggml_silu(ctx->ggml_ctx, t_emb); txt_mod_params = txt_mod_1->forward(ctx, txt_mod_params); - auto txt_mod_param_vec = ggml_ext_chunk(ctx, txt_mod_params, 6, 0); + auto txt_mod_param_vec = ggml_ext_chunk(ctx->ggml_ctx, txt_mod_params, 6, 0); auto img_normed = img_norm1->forward(ctx, img); - auto img_modulated = Flux::modulate(ctx, img_normed, img_mod_param_vec[0], img_mod_param_vec[1]); + auto img_modulated = Flux::modulate(ctx->ggml_ctx, img_normed, img_mod_param_vec[0], img_mod_param_vec[1]); auto img_gate1 = img_mod_param_vec[2]; auto 
txt_normed = txt_norm1->forward(ctx, txt); - auto txt_modulated = Flux::modulate(ctx, txt_normed, txt_mod_param_vec[0], txt_mod_param_vec[1]); + auto txt_modulated = Flux::modulate(ctx->ggml_ctx, txt_normed, txt_mod_param_vec[0], txt_mod_param_vec[1]); auto txt_gate1 = txt_mod_param_vec[2]; - auto [img_attn_output, txt_attn_output] = attn->forward(ctx, backend, img_modulated, txt_modulated, pe); + auto [img_attn_output, txt_attn_output] = attn->forward(ctx, img_modulated, txt_modulated, pe); - img = ggml_add(ctx, img, ggml_mul(ctx, img_attn_output, img_gate1)); - txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_attn_output, txt_gate1)); + img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn_output, img_gate1)); + txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_attn_output, txt_gate1)); auto img_normed2 = img_norm2->forward(ctx, img); - auto img_modulated2 = Flux::modulate(ctx, img_normed2, img_mod_param_vec[3], img_mod_param_vec[4]); + auto img_modulated2 = Flux::modulate(ctx->ggml_ctx, img_normed2, img_mod_param_vec[3], img_mod_param_vec[4]); auto img_gate2 = img_mod_param_vec[5]; auto txt_normed2 = txt_norm2->forward(ctx, txt); - auto txt_modulated2 = Flux::modulate(ctx, txt_normed2, txt_mod_param_vec[3], txt_mod_param_vec[4]); + auto txt_modulated2 = Flux::modulate(ctx->ggml_ctx, txt_normed2, txt_mod_param_vec[3], txt_mod_param_vec[4]); auto txt_gate2 = txt_mod_param_vec[5]; auto img_mlp_out = img_mlp->forward(ctx, img_modulated2); auto txt_mlp_out = txt_mlp->forward(ctx, txt_modulated2); - img = ggml_add(ctx, img, ggml_mul(ctx, img_mlp_out, img_gate2)); - txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_mlp_out, txt_gate2)); + img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_mlp_out, img_gate2)); + txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_mlp_out, txt_gate2)); return {img, txt}; } @@ -294,7 +292,7 @@ namespace Qwen { blocks["linear"] = std::shared_ptr(new Linear(conditioning_embedding_dim, embedding_dim * 2, bias)); } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* c) { // x: [N, n_token, hidden_size] @@ -304,13 +302,13 @@ namespace Qwen { auto norm = std::dynamic_pointer_cast(blocks["norm"]); auto linear = std::dynamic_pointer_cast(blocks["linear"]); - auto emb = linear->forward(ctx, ggml_silu(ctx, c)); - auto mods = ggml_ext_chunk(ctx, emb, 2, 0); + auto emb = linear->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); + auto mods = ggml_ext_chunk(ctx->ggml_ctx, emb, 2, 0); auto scale = mods[0]; auto shift = mods[1]; x = norm->forward(ctx, x); - x = Flux::modulate(ctx, x, shift, scale); + x = Flux::modulate(ctx->ggml_ctx, x, shift, scale); return x; } @@ -421,8 +419,7 @@ namespace Qwen { return x; } - struct ggml_tensor* forward_orig(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* timestep, struct ggml_tensor* context, @@ -442,7 +439,7 @@ namespace Qwen { for (int i = 0; i < params.num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["transformer_blocks." 
+ std::to_string(i)]); - auto result = block->forward(ctx, backend, img, txt, t_emb, pe); + auto result = block->forward(ctx, img, txt, t_emb, pe); img = result.first; txt = result.second; } @@ -453,8 +450,7 @@ namespace Qwen { return img; } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* timestep, struct ggml_tensor* context, @@ -472,32 +468,32 @@ namespace Qwen { int64_t C = x->ne[2]; int64_t N = x->ne[3]; - auto img = process_img(ctx, x); + auto img = process_img(ctx->ggml_ctx, x); uint64_t img_tokens = img->ne[1]; if (ref_latents.size() > 0) { for (ggml_tensor* ref : ref_latents) { - ref = process_img(ctx, ref); - img = ggml_concat(ctx, img, ref, 1); + ref = process_img(ctx->ggml_ctx, ref); + img = ggml_concat(ctx->ggml_ctx, img, ref, 1); } } int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size); int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size); - auto out = forward_orig(ctx, backend, img, timestep, context, pe); // [N, h_len*w_len, ph*pw*C] + auto out = forward_orig(ctx, img, timestep, context, pe); // [N, h_len*w_len, ph*pw*C] if (out->ne[1] > img_tokens) { - out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size] - out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0); - out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size] + out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size] + out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0); + out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size] } - out = unpatchify(ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w] + out = unpatchify(ctx->ggml_ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w] // slice - out = ggml_ext_slice(ctx, out, 1, 0, H); // [N, C, H, W + pad_w] - out = ggml_ext_slice(ctx, out, 0, 0, W); // [N, C, H, W] + out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w] + out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W); // [N, C, H, W] return out; } @@ -582,8 +578,9 @@ namespace Qwen { // pe->data = nullptr; set_backend_tensor_data(pe, pe_vec.data()); - struct ggml_tensor* out = qwen_image.forward(compute_ctx, - runtime_backend, + auto runner_ctx = get_context(); + + struct ggml_tensor* out = qwen_image.forward(&runner_ctx, x, timesteps, context, diff --git a/qwenvl.hpp b/qwenvl.hpp index c75459ee4..d430742f6 100644 --- a/qwenvl.hpp +++ b/qwenvl.hpp @@ -349,15 +349,15 @@ namespace Qwen { blocks["down_proj"] = std::shared_ptr(new Linear(intermediate_size, hidden_size, bias)); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { // x: [N, n_token, hidden_size] auto gate_proj = std::dynamic_pointer_cast(blocks["gate_proj"]); auto up_proj = std::dynamic_pointer_cast(blocks["up_proj"]); auto down_proj = std::dynamic_pointer_cast(blocks["down_proj"]); auto h = gate_proj->forward(ctx, x); - h = ggml_silu_inplace(ctx, h); - h = ggml_mul_inplace(ctx, h, up_proj->forward(ctx, x)); + h = ggml_silu_inplace(ctx->ggml_ctx, h); + h = ggml_mul_inplace(ctx->ggml_ctx, h, up_proj->forward(ctx, x)); h = down_proj->forward(ctx, h); 
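// Note (editorial): this is the standard SwiGLU feed-forward, down_proj(silu(gate_proj(x)) * up_proj(x)); after this patch, only raw ggml ops unwrap ctx->ggml_ctx, while nested block forwards keep taking the GGMLRunnerContext*.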
return h; } @@ -409,10 +409,10 @@ namespace Qwen { } } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { // x: [N*grid_t*grid_h*grid_w, in_channels, temporal_patch_size*patch_size*patch_size] // return: [N*grid_t*grid_h*grid_w, embed_dim] - x = ggml_reshape_4d(ctx, + x = ggml_reshape_4d(ctx->ggml_ctx, x, patch_size, patch_size, @@ -423,22 +423,22 @@ namespace Qwen { auto proj_0 = std::dynamic_pointer_cast(blocks["proj.0"]); auto proj_1 = std::dynamic_pointer_cast(blocks["proj.1"]); - auto x0 = ggml_ext_slice(ctx, x, 2, 0, 1); - x0 = ggml_reshape_4d(ctx, x0, x0->ne[0], x0->ne[1], in_channels, x0->ne[3] / in_channels); + auto x0 = ggml_ext_slice(ctx->ggml_ctx, x, 2, 0, 1); + x0 = ggml_reshape_4d(ctx->ggml_ctx, x0, x0->ne[0], x0->ne[1], in_channels, x0->ne[3] / in_channels); x0 = proj_0->forward(ctx, x0); - auto x1 = ggml_ext_slice(ctx, x, 2, 1, 2); - x1 = ggml_reshape_4d(ctx, x1, x1->ne[0], x1->ne[1], in_channels, x1->ne[3] / in_channels); + auto x1 = ggml_ext_slice(ctx->ggml_ctx, x, 2, 1, 2); + x1 = ggml_reshape_4d(ctx->ggml_ctx, x1, x1->ne[0], x1->ne[1], in_channels, x1->ne[3] / in_channels); x1 = proj_1->forward(ctx, x1); - x = ggml_add(ctx, x0, x1); + x = ggml_add(ctx->ggml_ctx, x0, x1); } else { auto proj = std::dynamic_pointer_cast(blocks["proj"]); x = proj->forward(ctx, x); } - x = ggml_reshape_2d(ctx, x, embed_dim, ggml_nelements(x) / embed_dim); + x = ggml_reshape_2d(ctx->ggml_ctx, x, embed_dim, ggml_nelements(x) / embed_dim); return x; } }; @@ -458,15 +458,15 @@ namespace Qwen { blocks["mlp.2"] = std::shared_ptr(new Linear(hidden_size, dim)); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { auto ln_q = std::dynamic_pointer_cast(blocks["ln_q"]); auto mlp_0 = std::dynamic_pointer_cast(blocks["mlp.0"]); auto mlp_2 = std::dynamic_pointer_cast(blocks["mlp.2"]); x = ln_q->forward(ctx, x); - x = ggml_reshape_2d(ctx, x, hidden_size, ggml_nelements(x) / hidden_size); + x = ggml_reshape_2d(ctx->ggml_ctx, x, hidden_size, ggml_nelements(x) / hidden_size); x = mlp_0->forward(ctx, x); - x = ggml_gelu(ctx, x); + x = ggml_gelu(ctx->ggml_ctx, x); x = mlp_2->forward(ctx, x); return x; } @@ -495,8 +495,7 @@ namespace Qwen { blocks["proj"] = std::shared_ptr(new Linear(hidden_size, hidden_size)); } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* pe, struct ggml_tensor* mask = nullptr) { @@ -519,14 +518,14 @@ namespace Qwen { } else { auto qkv_proj = std::dynamic_pointer_cast(blocks["qkv"]); auto qkv = qkv_proj->forward(ctx, x); - qkv_vec = split_qkv(ctx, qkv); + qkv_vec = split_qkv(ctx->ggml_ctx, qkv); } - auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head] - auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head] - auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head] + auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head] + auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], 
qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head] + auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head] - x = Rope::attention(ctx, backend, q, k, v, pe, mask, false, 1.f, false); // [N, n_token, hidden_size] + x = Rope::attention(ctx, q, k, v, pe, mask, false, 1.f, false); // [N, n_token, hidden_size] x = proj->forward(ctx, x); // [N, n_token, hidden_size] return x; @@ -546,8 +545,7 @@ namespace Qwen { blocks["norm2"] = std::shared_ptr(new RMSNorm(hidden_size, eps)); } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* pe, struct ggml_tensor* mask = nullptr) { @@ -559,13 +557,13 @@ namespace Qwen { auto residual = x; x = norm1->forward(ctx, x); - x = attn->forward(ctx, backend, x, pe, mask); - x = ggml_add_inplace(ctx, x, residual); + x = attn->forward(ctx, x, pe, mask); + x = ggml_add_inplace(ctx->ggml_ctx, x, residual); residual = x; x = norm2->forward(ctx, x); x = mlp->forward(ctx, x); - x = ggml_add_inplace(ctx, x, residual); + x = ggml_add_inplace(ctx->ggml_ctx, x, residual); return x; } @@ -607,8 +605,7 @@ namespace Qwen { blocks["merger"] = std::shared_ptr(new Qwen2_5_VLPatchMerger(out_hidden_size, hidden_size, spatial_merge_size)); } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* pixel_values, struct ggml_tensor* pe, struct ggml_tensor* window_index, @@ -623,9 +620,9 @@ namespace Qwen { auto x = patch_embed->forward(ctx, pixel_values); - x = ggml_reshape_4d(ctx, x, x->ne[0] * spatial_merge_size * spatial_merge_size, x->ne[1] / spatial_merge_size / spatial_merge_size, x->ne[2], x->ne[3]); - x = ggml_get_rows(ctx, x, window_index); - x = ggml_reshape_4d(ctx, x, x->ne[0] / spatial_merge_size / spatial_merge_size, x->ne[1] * spatial_merge_size * spatial_merge_size, x->ne[2], x->ne[3]); + x = ggml_reshape_4d(ctx->ggml_ctx, x, x->ne[0] * spatial_merge_size * spatial_merge_size, x->ne[1] / spatial_merge_size / spatial_merge_size, x->ne[2], x->ne[3]); + x = ggml_get_rows(ctx->ggml_ctx, x, window_index); + x = ggml_reshape_4d(ctx->ggml_ctx, x, x->ne[0] / spatial_merge_size / spatial_merge_size, x->ne[1] * spatial_merge_size * spatial_merge_size, x->ne[2], x->ne[3]); for (int i = 0; i < num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["blocks." 
+ std::to_string(i)]); @@ -634,12 +631,12 @@ namespace Qwen { if (fullatt_block_indexes.find(i) != fullatt_block_indexes.end()) { mask = nullptr; } - x = block->forward(ctx, backend, x, pe, mask); + x = block->forward(ctx, x, pe, mask); } x = merger->forward(ctx, x); - x = ggml_get_rows(ctx, x, window_inverse_index); + x = ggml_get_rows(ctx->ggml_ctx, x, window_inverse_index); return x; } @@ -664,8 +661,7 @@ namespace Qwen { blocks["o_proj"] = std::shared_ptr(new Linear(num_heads * head_dim, hidden_size, false)); } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* input_pos) { // x: [N, n_token, hidden_size] @@ -680,21 +676,21 @@ namespace Qwen { auto k = k_proj->forward(ctx, x); // [N, n_token, num_kv_heads*head_dim] auto v = v_proj->forward(ctx, x); // [N, n_token, num_kv_heads*head_dim] - q = ggml_reshape_4d(ctx, q, head_dim, num_heads, n_token, N); // [N, n_token, num_heads, head_dim] - k = ggml_reshape_4d(ctx, k, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim] - v = ggml_reshape_4d(ctx, v, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim] + q = ggml_reshape_4d(ctx->ggml_ctx, q, head_dim, num_heads, n_token, N); // [N, n_token, num_heads, head_dim] + k = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim] + v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim] int sections[4] = {16, 24, 24, 0}; - q = ggml_rope_multi(ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); - k = ggml_rope_multi(ctx, k, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); + q = ggml_rope_multi(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); + k = ggml_rope_multi(ctx->ggml_ctx, k, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); - q = ggml_cont(ctx, ggml_ext_torch_permute(ctx, q, 0, 2, 1, 3)); // [N, num_heads, n_token, head_dim] - q = ggml_reshape_3d(ctx, q, q->ne[0], q->ne[1], q->ne[2] * q->ne[3]); // [N*num_heads, n_token, head_dim] + q = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, q, 0, 2, 1, 3)); // [N, num_heads, n_token, head_dim] + q = ggml_reshape_3d(ctx->ggml_ctx, q, q->ne[0], q->ne[1], q->ne[2] * q->ne[3]); // [N*num_heads, n_token, head_dim] - k = ggml_cont(ctx, ggml_ext_torch_permute(ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim] - k = ggml_reshape_3d(ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim] + k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim] + k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim] - x = ggml_ext_attention_ext(ctx, backend, q, k, v, num_heads, nullptr, true, true, false); // [N, n_token, hidden_size] + x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, true, true, false); // [N, n_token, hidden_size] x = out_proj->forward(ctx, x); // [N, n_token, hidden_size] return x; @@ -714,8 +710,7 @@ namespace Qwen { blocks["post_attention_layernorm"] = std::shared_ptr(new RMSNorm(hidden_size, 
eps)); } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* input_pos) { // x: [N, n_token, hidden_size] @@ -726,13 +721,13 @@ namespace Qwen { auto residual = x; x = input_layernorm->forward(ctx, x); - x = self_attn->forward(ctx, backend, x, input_pos); - x = ggml_add_inplace(ctx, x, residual); + x = self_attn->forward(ctx, x, input_pos); + x = ggml_add_inplace(ctx->ggml_ctx, x, residual); residual = x; x = post_attention_layernorm->forward(ctx, x); x = mlp->forward(ctx, x); - x = ggml_add_inplace(ctx, x, residual); + x = ggml_add_inplace(ctx->ggml_ctx, x, residual); return x; } @@ -761,8 +756,7 @@ namespace Qwen { blocks["norm"] = std::shared_ptr(new RMSNorm(hidden_size, eps)); } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* input_ids, struct ggml_tensor* input_pos, std::vector> image_embeds) { @@ -777,7 +771,7 @@ namespace Qwen { if (image_embeds.size() > 0) { GGML_ASSERT(x->ne[2] == 1); // N == 1 - auto raw_x = ggml_cast(ctx, x, image_embeds[0].second->type); + auto raw_x = ggml_cast(ctx->ggml_ctx, x, image_embeds[0].second->type); int64_t txt_token_start = 0; int64_t txt_token_end = 0; @@ -791,23 +785,23 @@ namespace Qwen { } txt_token_end = image_embeds[i].first; - auto txt_embed = ggml_ext_slice(ctx, raw_x, 1, txt_token_start, txt_token_end); + auto txt_embed = ggml_ext_slice(ctx->ggml_ctx, raw_x, 1, txt_token_start, txt_token_end); if (input_embed == nullptr) { input_embed = txt_embed; } else { - input_embed = ggml_concat(ctx, input_embed, txt_embed, 1); + input_embed = ggml_concat(ctx->ggml_ctx, input_embed, txt_embed, 1); } auto image_embed = image_embeds[i].second; - input_embed = ggml_concat(ctx, input_embed, image_embed, 1); + input_embed = ggml_concat(ctx->ggml_ctx, input_embed, image_embed, 1); } txt_token_start = image_embeds[image_embeds.size() - 1].first + image_embeds[image_embeds.size() - 1].second->ne[1]; txt_token_end = raw_x->ne[1]; - auto final_txt_embed = ggml_ext_slice(ctx, raw_x, 1, txt_token_start, txt_token_end); + auto final_txt_embed = ggml_ext_slice(ctx->ggml_ctx, raw_x, 1, txt_token_start, txt_token_end); - input_embed = ggml_concat(ctx, input_embed, final_txt_embed, 1); + input_embed = ggml_concat(ctx->ggml_ctx, input_embed, final_txt_embed, 1); GGML_ASSERT(raw_x->ne[1] == input_embed->ne[1]); x = input_embed; @@ -816,7 +810,7 @@ namespace Qwen { for (int i = 0; i < num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["layers." 
+ std::to_string(i)]); - x = block->forward(ctx, backend, x, input_pos); + x = block->forward(ctx, x, input_pos); } x = norm->forward(ctx, x); @@ -880,20 +874,18 @@ namespace Qwen { } } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* input_ids, struct ggml_tensor* input_pos, std::vector> image_embeds) { // input_ids: [N, n_token] auto model = std::dynamic_pointer_cast(blocks["model"]); - auto x = model->forward(ctx, backend, input_ids, input_pos, image_embeds); + auto x = model->forward(ctx, input_ids, input_pos, image_embeds); return x; } - struct ggml_tensor* vision_forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* vision_forward(GGMLRunnerContext* ctx, struct ggml_tensor* pixel_values, struct ggml_tensor* pe, struct ggml_tensor* window_index, @@ -901,7 +893,7 @@ namespace Qwen { struct ggml_tensor* window_mask) { GGML_ASSERT(enable_vision); auto vision_model = std::dynamic_pointer_cast(blocks["visual"]); - return vision_model->forward(ctx, backend, pixel_values, pe, window_index, window_inverse_index, window_mask); + return vision_model->forward(ctx, pixel_values, pe, window_index, window_inverse_index, window_mask); } }; @@ -959,23 +951,21 @@ namespace Qwen { model.get_param_tensors(tensors, prefix); } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* input_ids, struct ggml_tensor* input_pos, std::vector> image_embeds) { - auto hidden_states = model.forward(ctx, backend, input_ids, input_pos, image_embeds); // [N, n_token, hidden_size] + auto hidden_states = model.forward(ctx, input_ids, input_pos, image_embeds); // [N, n_token, hidden_size] return hidden_states; } - struct ggml_tensor* vision_forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* vision_forward(GGMLRunnerContext* ctx, struct ggml_tensor* pixel_values, struct ggml_tensor* input_pos, struct ggml_tensor* window_index, struct ggml_tensor* window_inverse_index, struct ggml_tensor* window_mask) { - auto hidden_states = model.vision_forward(ctx, backend, pixel_values, input_pos, window_index, window_inverse_index, window_mask); + auto hidden_states = model.vision_forward(ctx, pixel_values, input_pos, window_index, window_inverse_index, window_mask); return hidden_states; } @@ -1002,7 +992,9 @@ namespace Qwen { n_tokens * 4); set_backend_tensor_data(input_pos, input_pos_vec.data()); - struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, input_pos, image_embeds); + auto runner_ctx = get_context(); + + struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, image_embeds); ggml_build_forward_expand(gf, hidden_states); @@ -1167,8 +1159,8 @@ namespace Qwen { // pe->data = nullptr; set_backend_tensor_data(pe, pe_vec.data()); - struct ggml_tensor* hidden_states = vision_forward(compute_ctx, - runtime_backend, + auto runner_ctx = get_context(); + struct ggml_tensor* hidden_states = vision_forward(&runner_ctx, pixel_values, pe, window_index, diff --git a/rope.hpp b/rope.hpp index b738dc529..05f04d0d5 100644 --- a/rope.hpp +++ b/rope.hpp @@ -386,8 +386,7 @@ namespace Rope { return x_out; } - __STATIC_INLINE__ struct ggml_tensor* attention(struct ggml_context* ctx, - ggml_backend_t backend, + __STATIC_INLINE__ struct ggml_tensor* attention(GGMLRunnerContext* ctx, struct ggml_tensor* q, struct ggml_tensor* k, 
struct ggml_tensor* v, @@ -399,10 +398,10 @@ namespace Rope { // q,k,v: [N, L, n_head, d_head] // pe: [L, d_head/2, 2, 2] // return: [N, L, n_head*d_head] - q = apply_rope(ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head] - k = apply_rope(ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head] + q = apply_rope(ctx->ggml_ctx, q, pe, rope_interleaved); // [N*n_head, L, d_head] + k = apply_rope(ctx->ggml_ctx, k, pe, rope_interleaved); // [N*n_head, L, d_head] - auto x = ggml_ext_attention_ext(ctx, backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale); // [N, L, n_head*d_head] + auto x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale); // [N, L, n_head*d_head] return x; } }; // namespace Rope diff --git a/t5.hpp b/t5.hpp index a8dce60a9..1f6341f8d 100644 --- a/t5.hpp +++ b/t5.hpp @@ -472,10 +472,10 @@ class T5LayerNorm : public UnaryBlock { : hidden_size(hidden_size), eps(eps) {} - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { struct ggml_tensor* w = params["weight"]; - x = ggml_rms_norm(ctx, x, eps); - x = ggml_mul(ctx, x, w); + x = ggml_rms_norm(ctx->ggml_ctx, x, eps); + x = ggml_mul(ctx->ggml_ctx, x, w); return x; } }; @@ -487,13 +487,13 @@ struct T5DenseActDense : public UnaryBlock { blocks["wo"] = std::shared_ptr(new Linear(ff_dim, model_dim, false)); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { // x: [N, n_token, model_dim] auto wi = std::dynamic_pointer_cast(blocks["wi"]); auto wo = std::dynamic_pointer_cast(blocks["wo"]); x = wi->forward(ctx, x); - x = ggml_relu_inplace(ctx, x); + x = ggml_relu_inplace(ctx->ggml_ctx, x); x = wo->forward(ctx, x); return x; } @@ -509,15 +509,15 @@ struct T5DenseGatedActDense : public UnaryBlock { blocks["wo"] = std::shared_ptr(new Linear(ff_dim, model_dim, false, false, false, scale)); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { // x: [N, n_token, model_dim] auto wi_0 = std::dynamic_pointer_cast(blocks["wi_0"]); auto wi_1 = std::dynamic_pointer_cast(blocks["wi_1"]); auto wo = std::dynamic_pointer_cast(blocks["wo"]); - auto hidden_gelu = ggml_gelu_inplace(ctx, wi_0->forward(ctx, x)); + auto hidden_gelu = ggml_gelu_inplace(ctx->ggml_ctx, wi_0->forward(ctx, x)); auto hidden_linear = wi_1->forward(ctx, x); - x = ggml_mul_inplace(ctx, hidden_gelu, hidden_linear); + x = ggml_mul_inplace(ctx->ggml_ctx, hidden_gelu, hidden_linear); x = wo->forward(ctx, x); return x; } @@ -530,14 +530,14 @@ struct T5LayerFF : public UnaryBlock { blocks["layer_norm"] = std::shared_ptr(new T5LayerNorm(model_dim)); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { // x: [N, n_token, model_dim] auto DenseReluDense = std::dynamic_pointer_cast(blocks["DenseReluDense"]); auto layer_norm = std::dynamic_pointer_cast(blocks["layer_norm"]); auto forwarded_states = layer_norm->forward(ctx, x); forwarded_states = DenseReluDense->forward(ctx, forwarded_states); - x = ggml_add_inplace(ctx, forwarded_states, x); + x = ggml_add_inplace(ctx->ggml_ctx, forwarded_states, x); return x; } }; @@ 
-569,18 +569,17 @@ class T5Attention : public GGMLBlock { } } - struct ggml_tensor* compute_bias(struct ggml_context* ctx, + struct ggml_tensor* compute_bias(GGMLRunnerContext* ctx, struct ggml_tensor* relative_position_bucket) { auto relative_attention_bias = std::dynamic_pointer_cast(blocks["relative_attention_bias"]); - auto values = relative_attention_bias->forward(ctx, relative_position_bucket); // shape (query_length, key_length, num_heads) - values = ggml_cont(ctx, ggml_permute(ctx, values, 2, 0, 1, 3)); // shape (1, num_heads, query_length, key_length) + auto values = relative_attention_bias->forward(ctx, relative_position_bucket); // shape (query_length, key_length, num_heads) + values = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, values, 2, 0, 1, 3)); // shape (1, num_heads, query_length, key_length) return values; } // x: [N, n_token, model_dim] - std::pair forward(struct ggml_context* ctx, - ggml_backend_t backend, + std::pair forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* past_bias = nullptr, struct ggml_tensor* mask = nullptr, @@ -602,16 +601,16 @@ class T5Attention : public GGMLBlock { } if (past_bias != nullptr) { if (mask != nullptr) { - mask = ggml_repeat(ctx, mask, past_bias); - mask = ggml_add(ctx, mask, past_bias); + mask = ggml_repeat(ctx->ggml_ctx, mask, past_bias); + mask = ggml_add(ctx->ggml_ctx, mask, past_bias); } else { mask = past_bias; } } - k = ggml_scale_inplace(ctx, k, sqrt(d_head)); + k = ggml_scale_inplace(ctx->ggml_ctx, k, sqrt(d_head)); - x = ggml_ext_attention_ext(ctx, backend, q, k, v, num_heads, mask); // [N, n_token, d_head * n_head] + x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, mask); // [N, n_token, d_head * n_head] x = out_proj->forward(ctx, x); // [N, n_token, model_dim] return {x, past_bias}; @@ -629,8 +628,7 @@ struct T5LayerSelfAttention : public GGMLBlock { blocks["layer_norm"] = std::shared_ptr(new T5LayerNorm(model_dim)); } - std::pair forward(struct ggml_context* ctx, - ggml_backend_t backend, + std::pair forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* past_bias = nullptr, struct ggml_tensor* mask = nullptr, @@ -640,11 +638,11 @@ struct T5LayerSelfAttention : public GGMLBlock { auto layer_norm = std::dynamic_pointer_cast(blocks["layer_norm"]); auto normed_hidden_state = layer_norm->forward(ctx, x); - auto ret = SelfAttention->forward(ctx, backend, normed_hidden_state, past_bias, mask, relative_position_bucket); + auto ret = SelfAttention->forward(ctx, normed_hidden_state, past_bias, mask, relative_position_bucket); auto output = ret.first; past_bias = ret.second; - x = ggml_add_inplace(ctx, output, x); + x = ggml_add_inplace(ctx->ggml_ctx, output, x); return {x, past_bias}; } }; @@ -656,8 +654,7 @@ struct T5Block : public GGMLBlock { blocks["layer.1"] = std::shared_ptr(new T5LayerFF(model_dim, ff_dim)); } - std::pair forward(struct ggml_context* ctx, - ggml_backend_t backend, + std::pair forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* past_bias = nullptr, struct ggml_tensor* mask = nullptr, @@ -666,7 +663,7 @@ struct T5Block : public GGMLBlock { auto layer_0 = std::dynamic_pointer_cast(blocks["layer.0"]); auto layer_1 = std::dynamic_pointer_cast(blocks["layer.1"]); - auto ret = layer_0->forward(ctx, backend, x, past_bias, mask, relative_position_bucket); + auto ret = layer_0->forward(ctx, x, past_bias, mask, relative_position_bucket); x = ret.first; past_bias = ret.second; x = layer_1->forward(ctx, x); @@ -692,8 
+689,7 @@ struct T5Stack : public GGMLBlock { blocks["final_layer_norm"] = std::shared_ptr(new T5LayerNorm(model_dim)); } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* past_bias = nullptr, struct ggml_tensor* attention_mask = nullptr, @@ -702,7 +698,7 @@ struct T5Stack : public GGMLBlock { for (int i = 0; i < num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["block." + std::to_string(i)]); - auto ret = block->forward(ctx, backend, x, past_bias, attention_mask, relative_position_bucket); + auto ret = block->forward(ctx, x, past_bias, attention_mask, relative_position_bucket); x = ret.first; past_bias = ret.second; } @@ -740,8 +736,7 @@ struct T5 : public GGMLBlock { params.model_dim)); } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* input_ids, struct ggml_tensor* past_bias = nullptr, struct ggml_tensor* attention_mask = nullptr, @@ -752,7 +747,7 @@ struct T5 : public GGMLBlock { auto encoder = std::dynamic_pointer_cast(blocks["encoder"]); auto x = shared->forward(ctx, input_ids); - x = encoder->forward(ctx, backend, x, past_bias, attention_mask, relative_position_bucket); + x = encoder->forward(ctx, x, past_bias, attention_mask, relative_position_bucket); return x; } }; @@ -784,15 +779,14 @@ struct T5Runner : public GGMLRunner { model.get_param_tensors(tensors, prefix); } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* input_ids, struct ggml_tensor* relative_position_bucket, struct ggml_tensor* attention_mask = nullptr) { size_t N = input_ids->ne[1]; size_t n_token = input_ids->ne[0]; - auto hidden_states = model.forward(ctx, backend, input_ids, nullptr, attention_mask, relative_position_bucket); // [N, n_token, model_dim] + auto hidden_states = model.forward(ctx, input_ids, nullptr, attention_mask, relative_position_bucket); // [N, n_token, model_dim] return hidden_states; } @@ -818,7 +812,8 @@ struct T5Runner : public GGMLRunner { input_ids->ne[0]); set_backend_tensor_data(relative_position_bucket, relative_position_bucket_vec.data()); - struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, relative_position_bucket, attention_mask); + auto runner_ctx = get_context(); + struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, relative_position_bucket, attention_mask); ggml_build_forward_expand(gf, hidden_states); diff --git a/tae.hpp b/tae.hpp index d630325de..8069de0f2 100644 --- a/tae.hpp +++ b/tae.hpp @@ -29,7 +29,7 @@ class TAEBlock : public UnaryBlock { } } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { // x: [n, n_in, h, w] // return: [n, n_out, h, w] @@ -38,9 +38,9 @@ class TAEBlock : public UnaryBlock { auto conv_4 = std::dynamic_pointer_cast(blocks["conv.4"]); auto h = conv_0->forward(ctx, x); - h = ggml_relu_inplace(ctx, h); + h = ggml_relu_inplace(ctx->ggml_ctx, h); h = conv_2->forward(ctx, h); - h = ggml_relu_inplace(ctx, h); + h = ggml_relu_inplace(ctx->ggml_ctx, h); h = conv_4->forward(ctx, h); if (n_in != n_out) { @@ -49,8 +49,8 @@ class TAEBlock : public UnaryBlock { x = skip->forward(ctx, x); } - h = ggml_add(ctx, h, x); - h = 
ggml_relu_inplace(ctx, h); + h = ggml_add(ctx->ggml_ctx, h, x); + h = ggml_relu_inplace(ctx->ggml_ctx, h); return h; } }; @@ -86,7 +86,7 @@ class TinyEncoder : public UnaryBlock { blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1})); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { // x: [n, in_channels, h, w] // return: [n, z_channels, h/8, w/8] @@ -136,20 +136,20 @@ class TinyDecoder : public UnaryBlock { blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1})); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* z) override { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override { // z: [n, z_channels, h, w] // return: [n, out_channels, h*8, w*8] - auto h = ggml_scale(ctx, z, 1.0f / 3.0f); - h = ggml_tanh_inplace(ctx, h); - h = ggml_scale(ctx, h, 3.0f); + auto h = ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f); + h = ggml_tanh_inplace(ctx->ggml_ctx, h); + h = ggml_scale(ctx->ggml_ctx, h, 3.0f); for (int i = 0; i < num_blocks * 3 + 10; i++) { if (blocks.find(std::to_string(i)) == blocks.end()) { if (i == 1) { - h = ggml_relu_inplace(ctx, h); + h = ggml_relu_inplace(ctx->ggml_ctx, h); } else { - h = ggml_upscale(ctx, h, 2, GGML_SCALE_MODE_NEAREST); + h = ggml_upscale(ctx->ggml_ctx, h, 2, GGML_SCALE_MODE_NEAREST); } continue; } @@ -180,12 +180,12 @@ class TAESD : public GGMLBlock { } } - struct ggml_tensor* decode(struct ggml_context* ctx, struct ggml_tensor* z) { + struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) { auto decoder = std::dynamic_pointer_cast(blocks["decoder.layers"]); return decoder->forward(ctx, z); } - struct ggml_tensor* encode(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) { auto encoder = std::dynamic_pointer_cast(blocks["encoder.layers"]); return encoder->forward(ctx, x); } @@ -252,7 +252,8 @@ struct TinyAutoEncoder : public GGMLRunner { struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) { struct ggml_cgraph* gf = ggml_new_graph(compute_ctx); z = to_backend(z); - struct ggml_tensor* out = decode_graph ? taesd.decode(compute_ctx, z) : taesd.encode(compute_ctx, z); + auto runner_ctx = get_context(); + struct ggml_tensor* out = decode_graph ? 
taesd.decode(&runner_ctx, z) : taesd.encode(&runner_ctx, z); ggml_build_forward_expand(gf, out); return gf; } diff --git a/unet.hpp b/unet.hpp index 522a10fbd..80baa4b45 100644 --- a/unet.hpp +++ b/unet.hpp @@ -60,8 +60,7 @@ class SpatialVideoTransformer : public SpatialTransformer { blocks["time_mixer"] = std::shared_ptr(new AlphaBlender()); } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* context, int timesteps) { @@ -92,7 +91,7 @@ class SpatialVideoTransformer : public SpatialTransformer { auto time_context = context; // [b*t, n_context, context_dim] auto spatial_context = context; // time_context_first_timestep = time_context[::timesteps] - auto time_context_first_timestep = ggml_view_3d(ctx, + auto time_context_first_timestep = ggml_view_3d(ctx->ggml_ctx, time_context, time_context->ne[0], time_context->ne[1], @@ -100,26 +99,26 @@ class SpatialVideoTransformer : public SpatialTransformer { time_context->nb[1], time_context->nb[2], 0); // [b, n_context, context_dim] - time_context = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, + time_context = ggml_new_tensor_3d(ctx->ggml_ctx, GGML_TYPE_F32, time_context_first_timestep->ne[0], time_context_first_timestep->ne[1], time_context_first_timestep->ne[2] * h * w); - time_context = ggml_repeat(ctx, time_context_first_timestep, time_context); // [b*h*w, n_context, context_dim] + time_context = ggml_repeat(ctx->ggml_ctx, time_context_first_timestep, time_context); // [b*h*w, n_context, context_dim] x = norm->forward(ctx, x); x = proj_in->forward(ctx, x); // [N, inner_dim, h, w] - x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim] - x = ggml_reshape_3d(ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim] + x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim] + x = ggml_reshape_3d(ctx->ggml_ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim] - auto num_frames = ggml_arange(ctx, 0, timesteps, 1); + auto num_frames = ggml_arange(ctx->ggml_ctx, 0, timesteps, 1); // since b is 1, no need to do repeat - auto t_emb = ggml_ext_timestep_embedding(ctx, num_frames, in_channels, max_time_embed_period); // [N, in_channels] + auto t_emb = ggml_ext_timestep_embedding(ctx->ggml_ctx, num_frames, in_channels, max_time_embed_period); // [N, in_channels] auto emb = time_pos_embed_0->forward(ctx, t_emb); - emb = ggml_silu_inplace(ctx, emb); - emb = time_pos_embed_2->forward(ctx, emb); // [N, in_channels] - emb = ggml_reshape_3d(ctx, emb, emb->ne[0], 1, emb->ne[1]); // [N, 1, in_channels] + emb = ggml_silu_inplace(ctx->ggml_ctx, emb); + emb = time_pos_embed_2->forward(ctx, emb); // [N, in_channels] + emb = ggml_reshape_3d(ctx->ggml_ctx, emb, emb->ne[0], 1, emb->ne[1]); // [N, 1, in_channels] for (int i = 0; i < depth; i++) { std::string transformer_name = "transformer_blocks." 
+ std::to_string(i); @@ -128,11 +127,11 @@ class SpatialVideoTransformer : public SpatialTransformer { auto block = std::dynamic_pointer_cast(blocks[transformer_name]); auto mix_block = std::dynamic_pointer_cast(blocks[time_stack_name]); - x = block->forward(ctx, backend, x, spatial_context); // [N, h * w, inner_dim] + x = block->forward(ctx, x, spatial_context); // [N, h * w, inner_dim] // in_channels == inner_dim auto x_mix = x; - x_mix = ggml_add(ctx, x_mix, emb); // [N, h * w, inner_dim] + x_mix = ggml_add(ctx->ggml_ctx, x_mix, emb); // [N, h * w, inner_dim] int64_t N = x_mix->ne[2]; int64_t T = timesteps; @@ -140,26 +139,26 @@ class SpatialVideoTransformer : public SpatialTransformer { int64_t S = x_mix->ne[1]; int64_t C = x_mix->ne[0]; - x_mix = ggml_reshape_4d(ctx, x_mix, C, S, T, B); // (b t) s c -> b t s c - x_mix = ggml_cont(ctx, ggml_permute(ctx, x_mix, 0, 2, 1, 3)); // b t s c -> b s t c - x_mix = ggml_reshape_3d(ctx, x_mix, C, T, S * B); // b s t c -> (b s) t c + x_mix = ggml_reshape_4d(ctx->ggml_ctx, x_mix, C, S, T, B); // (b t) s c -> b t s c + x_mix = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x_mix, 0, 2, 1, 3)); // b t s c -> b s t c + x_mix = ggml_reshape_3d(ctx->ggml_ctx, x_mix, C, T, S * B); // b s t c -> (b s) t c - x_mix = mix_block->forward(ctx, backend, x_mix, time_context); // [B * h * w, T, inner_dim] + x_mix = mix_block->forward(ctx, x_mix, time_context); // [B * h * w, T, inner_dim] - x_mix = ggml_reshape_4d(ctx, x_mix, C, T, S, B); // (b s) t c -> b s t c - x_mix = ggml_cont(ctx, ggml_permute(ctx, x_mix, 0, 2, 1, 3)); // b s t c -> b t s c - x_mix = ggml_reshape_3d(ctx, x_mix, C, S, T * B); // b t s c -> (b t) s c + x_mix = ggml_reshape_4d(ctx->ggml_ctx, x_mix, C, T, S, B); // (b s) t c -> b s t c + x_mix = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x_mix, 0, 2, 1, 3)); // b s t c -> b t s c + x_mix = ggml_reshape_3d(ctx->ggml_ctx, x_mix, C, S, T * B); // b t s c -> (b t) s c x = time_mixer->forward(ctx, x, x_mix); // [N, h * w, inner_dim] } - x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w] - x = ggml_reshape_4d(ctx, x, w, h, inner_dim, n); // [N, inner_dim, h, w] + x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w] + x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, inner_dim, n); // [N, inner_dim, h, w] // proj_out x = proj_out->forward(ctx, x); // [N, in_channels, h, w] - x = ggml_add(ctx, x, x_in); + x = ggml_add(ctx->ggml_ctx, x, x_in); return x; } }; @@ -377,7 +376,7 @@ class UnetModelBlock : public GGMLBlock { } struct ggml_tensor* resblock_forward(std::string name, - struct ggml_context* ctx, + GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* emb, int num_video_frames) { @@ -393,24 +392,22 @@ class UnetModelBlock : public GGMLBlock { } struct ggml_tensor* attention_layer_forward(std::string name, - struct ggml_context* ctx, - ggml_backend_t backend, + GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* context, int timesteps) { if (version == VERSION_SVD) { auto block = std::dynamic_pointer_cast(blocks[name]); - return block->forward(ctx, backend, x, context, timesteps); + return block->forward(ctx, x, context, timesteps); } else { auto block = std::dynamic_pointer_cast(blocks[name]); - return block->forward(ctx, backend, x, context); + return block->forward(ctx, x, context); } } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct 
ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, @@ -427,20 +424,20 @@ class UnetModelBlock : public GGMLBlock { // return: [N, out_channels, h, w] if (context != nullptr) { if (context->ne[2] != x->ne[3]) { - context = ggml_repeat(ctx, context, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, context->ne[0], context->ne[1], x->ne[3])); + context = ggml_repeat(ctx->ggml_ctx, context, ggml_new_tensor_3d(ctx->ggml_ctx, GGML_TYPE_F32, context->ne[0], context->ne[1], x->ne[3])); } } if (c_concat != nullptr) { if (c_concat->ne[3] != x->ne[3]) { - c_concat = ggml_repeat(ctx, c_concat, x); + c_concat = ggml_repeat(ctx->ggml_ctx, c_concat, x); } - x = ggml_concat(ctx, x, c_concat, 2); + x = ggml_concat(ctx->ggml_ctx, x, c_concat, 2); } if (y != nullptr) { if (y->ne[1] != x->ne[3]) { - y = ggml_repeat(ctx, y, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, y->ne[0], x->ne[3])); + y = ggml_repeat(ctx->ggml_ctx, y, ggml_new_tensor_2d(ctx->ggml_ctx, GGML_TYPE_F32, y->ne[0], x->ne[3])); } } @@ -451,10 +448,10 @@ class UnetModelBlock : public GGMLBlock { auto out_0 = std::dynamic_pointer_cast(blocks["out.0"]); auto out_2 = std::dynamic_pointer_cast(blocks["out.2"]); - auto t_emb = ggml_ext_timestep_embedding(ctx, timesteps, model_channels); // [N, model_channels] + auto t_emb = ggml_ext_timestep_embedding(ctx->ggml_ctx, timesteps, model_channels); // [N, model_channels] auto emb = time_embed_0->forward(ctx, t_emb); - emb = ggml_silu_inplace(ctx, emb); + emb = ggml_silu_inplace(ctx->ggml_ctx, emb); emb = time_embed_2->forward(ctx, emb); // [N, time_embed_dim] // SDXL/SVD @@ -463,10 +460,10 @@ class UnetModelBlock : public GGMLBlock { auto label_embed_2 = std::dynamic_pointer_cast(blocks["label_emb.0.2"]); auto label_emb = label_embed_0->forward(ctx, y); - label_emb = ggml_silu_inplace(ctx, label_emb); + label_emb = ggml_silu_inplace(ctx->ggml_ctx, label_emb); label_emb = label_embed_2->forward(ctx, label_emb); // [N, time_embed_dim] - emb = ggml_add(ctx, emb, label_emb); // [N, time_embed_dim] + emb = ggml_add(ctx->ggml_ctx, emb, label_emb); // [N, time_embed_dim] } // input_blocks @@ -489,7 +486,7 @@ class UnetModelBlock : public GGMLBlock { h = resblock_forward(name, ctx, h, emb, num_video_frames); // [N, mult*model_channels, h, w] if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { std::string name = "input_blocks." 
+ std::to_string(input_block_idx) + ".1"; - h = attention_layer_forward(name, ctx, backend, h, context, num_video_frames); // [N, mult*model_channels, h, w] + h = attention_layer_forward(name, ctx, h, context, num_video_frames); // [N, mult*model_channels, h, w] } hs.push_back(h); } @@ -513,13 +510,13 @@ class UnetModelBlock : public GGMLBlock { if (version != VERSION_SD1_TINY_UNET) { h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] if (version != VERSION_SDXL_SSD1B) { - h = attention_layer_forward("middle_block.1", ctx, backend, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8] - h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] + h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8] + h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] } } if (controls.size() > 0) { - auto cs = ggml_scale_inplace(ctx, controls[controls.size() - 1], control_strength); - h = ggml_add(ctx, h, cs); // middle control + auto cs = ggml_scale_inplace(ctx->ggml_ctx, controls[controls.size() - 1], control_strength); + h = ggml_add(ctx->ggml_ctx, h, cs); // middle control } int control_offset = controls.size() - 2; @@ -531,12 +528,12 @@ class UnetModelBlock : public GGMLBlock { hs.pop_back(); if (controls.size() > 0) { - auto cs = ggml_scale_inplace(ctx, controls[control_offset], control_strength); - h_skip = ggml_add(ctx, h_skip, cs); // control net condition + auto cs = ggml_scale_inplace(ctx->ggml_ctx, controls[control_offset], control_strength); + h_skip = ggml_add(ctx->ggml_ctx, h_skip, cs); // control net condition control_offset--; } - h = ggml_concat(ctx, h, h_skip, 2); + h = ggml_concat(ctx->ggml_ctx, h, h_skip, 2); std::string name = "output_blocks." + std::to_string(output_block_idx) + ".0"; @@ -546,7 +543,7 @@ class UnetModelBlock : public GGMLBlock { if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { std::string name = "output_blocks." 
+ std::to_string(output_block_idx) + ".1"; - h = attention_layer_forward(name, ctx, backend, h, context, num_video_frames); + h = attention_layer_forward(name, ctx, h, context, num_video_frames); up_sample_idx++; } @@ -572,7 +569,7 @@ class UnetModelBlock : public GGMLBlock { // out h = out_0->forward(ctx, h); - h = ggml_silu_inplace(ctx, h); + h = ggml_silu_inplace(ctx->ggml_ctx, h); h = out_2->forward(ctx, h); ggml_set_name(h, "bench-end"); return h; // [N, out_channels, h, w] @@ -636,8 +633,9 @@ struct UNetModelRunner : public GGMLRunner { controls[i] = to_backend(controls[i]); } - struct ggml_tensor* out = unet.forward(compute_ctx, - runtime_backend, + auto runner_ctx = get_context(); + + struct ggml_tensor* out = unet.forward(&runner_ctx, x, timesteps, context, diff --git a/vae.hpp b/vae.hpp index e55bdd38b..cad5961de 100644 --- a/vae.hpp +++ b/vae.hpp @@ -30,7 +30,7 @@ class ResnetBlock : public UnaryBlock { } } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { // x: [N, in_channels, h, w] // t_emb is always None auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); @@ -40,12 +40,12 @@ class ResnetBlock : public UnaryBlock { auto h = x; h = norm1->forward(ctx, h); - h = ggml_silu_inplace(ctx, h); // swish + h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish h = conv1->forward(ctx, h); // return h; h = norm2->forward(ctx, h); - h = ggml_silu_inplace(ctx, h); // swish + h = ggml_silu_inplace(ctx->ggml_ctx, h); // swish // dropout, skip for inference h = conv2->forward(ctx, h); @@ -56,7 +56,7 @@ class ResnetBlock : public UnaryBlock { x = nin_shortcut->forward(ctx, x); // [N, out_channels, h, w] } - h = ggml_add(ctx, h, x); + h = ggml_add(ctx->ggml_ctx, h, x); return h; // [N, out_channels, h, w] } }; @@ -76,7 +76,7 @@ class AttnBlock : public UnaryBlock { blocks["proj_out"] = std::shared_ptr(new Conv2d(in_channels, in_channels, {1, 1})); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { // x: [N, in_channels, h, w] auto norm = std::dynamic_pointer_cast(blocks["norm"]); auto q_proj = std::dynamic_pointer_cast(blocks["q"]); @@ -91,25 +91,25 @@ class AttnBlock : public UnaryBlock { const int64_t h = h_->ne[1]; const int64_t w = h_->ne[0]; - auto q = q_proj->forward(ctx, h_); // [N, in_channels, h, w] - q = ggml_cont(ctx, ggml_permute(ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels] - q = ggml_reshape_3d(ctx, q, c, h * w, n); // [N, h * w, in_channels] + auto q = q_proj->forward(ctx, h_); // [N, in_channels, h, w] + q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels] + q = ggml_reshape_3d(ctx->ggml_ctx, q, c, h * w, n); // [N, h * w, in_channels] - auto k = k_proj->forward(ctx, h_); // [N, in_channels, h, w] - k = ggml_cont(ctx, ggml_permute(ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels] - k = ggml_reshape_3d(ctx, k, c, h * w, n); // [N, h * w, in_channels] + auto k = k_proj->forward(ctx, h_); // [N, in_channels, h, w] + k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels] + k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [N, h * w, in_channels] - auto v = v_proj->forward(ctx, h_); // [N, in_channels, h, w] - v = ggml_reshape_3d(ctx, v, h * w, c, n); // [N, in_channels, h * w] + auto v = v_proj->forward(ctx, h_); // [N, in_channels, 
h, w] + v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [N, in_channels, h * w] - h_ = ggml_ext_attention(ctx, q, k, v, false); // [N, h * w, in_channels] + h_ = ggml_ext_attention(ctx->ggml_ctx, q, k, v, false); // [N, h * w, in_channels] - h_ = ggml_cont(ctx, ggml_permute(ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w] - h_ = ggml_reshape_4d(ctx, h_, w, h, c, n); // [N, in_channels, h, w] + h_ = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, h_, 1, 0, 2, 3)); // [N, in_channels, h * w] + h_ = ggml_reshape_4d(ctx->ggml_ctx, h_, w, h, c, n); // [N, in_channels, h, w] h_ = proj_out->forward(ctx, h_); // [N, in_channels, h, w] - h_ = ggml_add(ctx, h_, x); + h_ = ggml_add(ctx->ggml_ctx, h_, x); return h_; } }; @@ -133,7 +133,7 @@ class AE3DConv : public Conv2d { kernel_padding)); } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { // timesteps always None // skip_video always False @@ -152,12 +152,12 @@ class AE3DConv : public Conv2d { int64_t H = x->ne[1]; int64_t W = x->ne[0]; - x = ggml_reshape_4d(ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w) - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w) - x = time_mix_conv->forward(ctx, x); // [B, OC, T, OH * OW] - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w) - x = ggml_reshape_4d(ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w - return x; // [B*T, OC, OH, OW] + x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w) + x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w) + x = time_mix_conv->forward(ctx, x); // [B, OC, T, OH * OW] + x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w) + x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w + return x; // [B*T, OC, OH, OW] } }; @@ -182,7 +182,7 @@ class VideoResnetBlock : public ResnetBlock { blocks["time_stack"] = std::shared_ptr(new ResBlock(out_channels, 0, out_channels, {video_kernel_size, 1}, 3, false, true)); } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { // x: [N, in_channels, h, w] aka [b*t, in_channels, h, w] // return: [N, out_channels, h, w] aka [b*t, out_channels, h, w] // t_emb is always None @@ -199,19 +199,19 @@ class VideoResnetBlock : public ResnetBlock { int64_t H = x->ne[1]; int64_t W = x->ne[0]; - x = ggml_reshape_4d(ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w) - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w) + x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w) + x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w) auto x_mix = x; x = time_stack->forward(ctx, x); // b t c (h w) float alpha = get_alpha(); - x = ggml_add(ctx, - ggml_scale(ctx, x, alpha), - ggml_scale(ctx, x_mix, 1.0f - alpha)); + x = ggml_add(ctx->ggml_ctx, + ggml_scale(ctx->ggml_ctx, x, alpha), + ggml_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha)); - x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w) - x = ggml_reshape_4d(ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w + x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c 
(h w) + x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w return x; } @@ -271,7 +271,7 @@ class Encoder : public GGMLBlock { blocks["conv_out"] = std::shared_ptr(new Conv2d(block_in, double_z ? z_channels * 2 : z_channels, {3, 3}, {1, 1}, {1, 1})); } - virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) { // x: [N, in_channels, h, w] auto conv_in = std::dynamic_pointer_cast(blocks["conv_in"]); @@ -307,8 +307,8 @@ class Encoder : public GGMLBlock { // end h = norm_out->forward(ctx, h); - h = ggml_silu_inplace(ctx, h); // nonlinearity/swish - h = conv_out->forward(ctx, h); // [N, z_channels*2, h, w] + h = ggml_silu_inplace(ctx->ggml_ctx, h); // nonlinearity/swish + h = conv_out->forward(ctx, h); // [N, z_channels*2, h, w] return h; } }; @@ -388,7 +388,7 @@ class Decoder : public GGMLBlock { blocks["conv_out"] = get_conv_out(block_in, out_ch, {3, 3}, {1, 1}, {1, 1}); } - virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* z) { + virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) { // z: [N, z_channels, h, w] // alpha is always 0 // merge_strategy is always learned @@ -429,8 +429,8 @@ class Decoder : public GGMLBlock { } h = norm_out->forward(ctx, h); - h = ggml_silu_inplace(ctx, h); // nonlinearity/swish - h = conv_out->forward(ctx, h); // [N, out_ch, h*8, w*8] + h = ggml_silu_inplace(ctx->ggml_ctx, h); // nonlinearity/swish + h = conv_out->forward(ctx, h); // [N, out_ch, h*8, w*8] return h; } }; @@ -493,7 +493,7 @@ class AutoencodingEngine : public GGMLBlock { } } - struct ggml_tensor* decode(struct ggml_context* ctx, struct ggml_tensor* z) { + struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z) { // z: [N, z_channels, h, w] if (use_quant) { auto post_quant_conv = std::dynamic_pointer_cast(blocks["post_quant_conv"]); @@ -507,7 +507,7 @@ class AutoencodingEngine : public GGMLBlock { return h; } - struct ggml_tensor* encode(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) { // x: [N, in_channels, h, w] auto encoder = std::dynamic_pointer_cast(blocks["encoder"]); @@ -607,7 +607,9 @@ struct AutoEncoderKL : public VAE { z = to_backend(z); - struct ggml_tensor* out = decode_graph ? ae.decode(compute_ctx, z) : ae.encode(compute_ctx, z); + auto runner_ctx = get_context(); + + struct ggml_tensor* out = decode_graph ? 
ae.decode(&runner_ctx, z) : ae.encode(&runner_ctx, z); ggml_build_forward_expand(gf, out); diff --git a/wan.hpp b/wan.hpp index 672e6b4ea..2815ee62b 100644 --- a/wan.hpp +++ b/wan.hpp @@ -54,7 +54,7 @@ namespace WAN { dilation(std::move(dilation)), bias(bias) {} - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* cache_x = nullptr) { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* cache_x = nullptr) { // x: [N*IC, ID, IH, IW] // result: x: [N*OC, ID, IH, IW] struct ggml_tensor* w = params["weight"]; @@ -71,12 +71,12 @@ namespace WAN { int rp2 = 0; if (cache_x != nullptr && lp2 > 0) { - x = ggml_concat(ctx, cache_x, x, 2); + x = ggml_concat(ctx->ggml_ctx, cache_x, x, 2); lp2 -= (int)cache_x->ne[2]; } - x = ggml_pad_ext(ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0); - return ggml_ext_conv_3d(ctx, x, w, b, in_channels, + x = ggml_pad_ext(ctx->ggml_ctx, x, lp0, rp0, lp1, rp1, lp2, rp2, 0, 0); + return ggml_ext_conv_3d(ctx->ggml_ctx, x, w, b, in_channels, std::get<2>(stride), std::get<1>(stride), std::get<0>(stride), 0, 0, 0, std::get<2>(dilation), std::get<1>(dilation), std::get<0>(dilation)); @@ -96,15 +96,15 @@ namespace WAN { RMS_norm(int64_t dim) : dim(dim) {} - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) override { + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override { // x: [N*IC, ID, IH, IW], IC == dim // assert N == 1 struct ggml_tensor* w = params["gamma"]; - auto h = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 3, 0, 1, 2)); // [ID, IH, IW, N*IC] - h = ggml_rms_norm(ctx, h, 1e-12); - h = ggml_mul(ctx, h, w); - h = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, h, 1, 2, 3, 0)); + auto h = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 3, 0, 1, 2)); // [ID, IH, IW, N*IC] + h = ggml_rms_norm(ctx->ggml_ctx, h, 1e-12); + h = ggml_mul(ctx->ggml_ctx, h, w); + h = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, h, 1, 2, 3, 0)); return h; } @@ -143,7 +143,7 @@ namespace WAN { } } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, int64_t b, std::vector& feat_cache, @@ -165,16 +165,16 @@ namespace WAN { } else { auto time_conv = std::dynamic_pointer_cast(blocks["time_conv"]); - auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]); + auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]); if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { // chunk_idx >= 2 // cache last frame of last two chunk - cache_x = ggml_concat(ctx, - ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), + cache_x = ggml_concat(ctx->ggml_ctx, + ggml_ext_slice(ctx->ggml_ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), cache_x, 2); } if (chunk_idx == 1 && cache_x->ne[2] < 2) { // Rep - cache_x = ggml_pad_ext(ctx, cache_x, 0, 0, 0, 0, (int)cache_x->ne[2], 0, 0, 0); + cache_x = ggml_pad_ext(ctx->ggml_ctx, cache_x, 0, 0, 0, 0, (int)cache_x->ne[2], 0, 0, 0); // aka cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device),cache_x],dim=2) } if (chunk_idx == 1) { @@ -183,9 +183,9 @@ namespace WAN { x = time_conv->forward(ctx, x, feat_cache[idx]); } feat_cache[idx] = cache_x; - x = ggml_reshape_4d(ctx, x, w * h, t, c, 2); // (2, c, t, h*w) - x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 3, 1, 2)); // (c, t, 2, h*w) - x = ggml_reshape_4d(ctx, x, w, h, 2 * t, c); // (c, t*2, h, w) + 
x = ggml_reshape_4d(ctx->ggml_ctx, x, w * h, t, c, 2); // (2, c, t, h*w) + x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 3, 1, 2)); // (c, t, 2, h*w) + x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, 2 * t, c); // (c, t*2, h, w) } } } @@ -194,18 +194,18 @@ namespace WAN { if (mode != "none") { auto resample_1 = std::dynamic_pointer_cast(blocks["resample.1"]); - x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w) + x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // (t, c, h, w) if (mode == "upsample2d") { - x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST); + x = ggml_upscale(ctx->ggml_ctx, x, 2, GGML_SCALE_MODE_NEAREST); } else if (mode == "upsample3d") { - x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST); + x = ggml_upscale(ctx->ggml_ctx, x, 2, GGML_SCALE_MODE_NEAREST); } else if (mode == "downsample2d") { - x = ggml_pad(ctx, x, 1, 1, 0, 0); + x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0); } else if (mode == "downsample3d") { - x = ggml_pad(ctx, x, 1, 1, 0, 0); + x = ggml_pad(ctx->ggml_ctx, x, 1, 1, 0, 0); } x = resample_1->forward(ctx, x); - x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w) + x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // (c, t, h, w) } if (mode == "downsample3d") { @@ -217,9 +217,9 @@ namespace WAN { } else { auto time_conv = std::dynamic_pointer_cast(blocks["time_conv"]); - auto cache_x = ggml_ext_slice(ctx, x, 2, -1, x->ne[2]); - x = ggml_concat(ctx, - ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), + auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -1, x->ne[2]); + x = ggml_concat(ctx->ggml_ctx, + ggml_ext_slice(ctx->ggml_ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), x, 2); x = time_conv->forward(ctx, x); @@ -249,7 +249,7 @@ namespace WAN { GGML_ASSERT(in_channels * factor % out_channels == 0); group_size = in_channels * factor / out_channels; } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, int64_t B = 1) { // x: [B*IC, T, H, W] @@ -262,20 +262,20 @@ namespace WAN { int64_t pad_t = (factor_t - T % factor_t) % factor_t; - x = ggml_pad_ext(ctx, x, 0, 0, 0, 0, pad_t, 0, 0, 0); + x = ggml_pad_ext(ctx->ggml_ctx, x, 0, 0, 0, 0, pad_t, 0, 0, 0); T = x->ne[2]; - x = ggml_reshape_4d(ctx, x, W * H, factor_t, T / factor_t, C); // [C, T/factor_t, factor_t, H*W] - x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [C, factor_t, T/factor_t, H*W] - x = ggml_reshape_4d(ctx, x, W, factor_s, (H / factor_s) * (T / factor_t), factor_t * C); // [C*factor_t, T/factor_t*H/factor_s, factor_s, W] - x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [C*factor_t, factor_s, T/factor_t*H/factor_s, W] - x = ggml_reshape_4d(ctx, x, factor_s, W / factor_s, (H / factor_s) * (T / factor_t), factor_s * factor_t * C); // [C*factor_t*factor_s, T/factor_t*H/factor_s, W/factor_s, factor_s] - x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 2, 0, 3)); // [C*factor_t*factor_s, factor_s, T/factor_t*H/factor_s, W/factor_s] - x = ggml_reshape_3d(ctx, x, (W / factor_s) * (H / factor_s) * (T / factor_t), group_size, out_channels); // [out_channels, group_size, T/factor_t*H/factor_s*W/factor_s] - - x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 0, 2, 3)); // [out_channels, T/factor_t*H/factor_s*W/factor_s, group_size] - x = ggml_mean(ctx, x); // [out_channels, 
T/factor_t*H/factor_s*W/factor_s, 1] - x = ggml_reshape_4d(ctx, x, W / factor_s, H / factor_s, T / factor_t, out_channels); + x = ggml_reshape_4d(ctx->ggml_ctx, x, W * H, factor_t, T / factor_t, C); // [C, T/factor_t, factor_t, H*W] + x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // [C, factor_t, T/factor_t, H*W] + x = ggml_reshape_4d(ctx->ggml_ctx, x, W, factor_s, (H / factor_s) * (T / factor_t), factor_t * C); // [C*factor_t, T/factor_t*H/factor_s, factor_s, W] + x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // [C*factor_t, factor_s, T/factor_t*H/factor_s, W] + x = ggml_reshape_4d(ctx->ggml_ctx, x, factor_s, W / factor_s, (H / factor_s) * (T / factor_t), factor_s * factor_t * C); // [C*factor_t*factor_s, T/factor_t*H/factor_s, W/factor_s, factor_s] + x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 2, 0, 3)); // [C*factor_t*factor_s, factor_s, T/factor_t*H/factor_s, W/factor_s] + x = ggml_reshape_3d(ctx->ggml_ctx, x, (W / factor_s) * (H / factor_s) * (T / factor_t), group_size, out_channels); // [out_channels, group_size, T/factor_t*H/factor_s*W/factor_s] + + x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [out_channels, T/factor_t*H/factor_s*W/factor_s, group_size] + x = ggml_mean(ctx->ggml_ctx, x); // [out_channels, T/factor_t*H/factor_s*W/factor_s, 1] + x = ggml_reshape_4d(ctx->ggml_ctx, x, W / factor_s, H / factor_s, T / factor_t, out_channels); return x; } }; @@ -296,7 +296,7 @@ namespace WAN { GGML_ASSERT(out_channels * factor % in_channels == 0); repeats = out_channels * factor / in_channels; } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, bool first_chunk = false, int64_t B = 1) { @@ -310,21 +310,21 @@ namespace WAN { auto x_ = x; for (int64_t i = 1; i < repeats; i++) { - x = ggml_concat(ctx, x, x_, 2); + x = ggml_concat(ctx->ggml_ctx, x, x_, 2); } C = out_channels; - x = ggml_reshape_4d(ctx, x, W, H * T, factor_s, factor_s * factor_t * C); // [C*factor_t*factor_s, factor_s, T*H, W] - x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 2, 0, 1, 3)); // [C*factor_t*factor_s, T*H, W, factor_s] - x = ggml_reshape_4d(ctx, x, factor_s * W, H * T, factor_s, factor_t * C); // [C*factor_t, factor_s, T*H, W*factor_s] - x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [C*factor_t, T*H, factor_s, W*factor_s] - x = ggml_reshape_4d(ctx, x, factor_s * W * factor_s * H, T, factor_t, C); // [C, factor_t, T, H*factor_s*W*factor_s] - x = ggml_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 2, 1, 3)); // [C, T, factor_t, H*factor_s*W*factor_s] - x = ggml_reshape_4d(ctx, x, factor_s * W, factor_s * H, factor_t * T, C); // [C, T*factor_t, H*factor_s, W*factor_s] + x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H * T, factor_s, factor_s * factor_t * C); // [C*factor_t*factor_s, factor_s, T*H, W] + x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 2, 0, 1, 3)); // [C*factor_t*factor_s, T*H, W, factor_s] + x = ggml_reshape_4d(ctx->ggml_ctx, x, factor_s * W, H * T, factor_s, factor_t * C); // [C*factor_t, factor_s, T*H, W*factor_s] + x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // [C*factor_t, T*H, factor_s, W*factor_s] + x = ggml_reshape_4d(ctx->ggml_ctx, x, factor_s * W * factor_s * H, T, factor_t, C); // [C, factor_t, T, H*factor_s*W*factor_s] + x = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 
0, 2, 1, 3)); // [C, T, factor_t, H*factor_s*W*factor_s] + x = ggml_reshape_4d(ctx->ggml_ctx, x, factor_s * W, factor_s * H, factor_t * T, C); // [C, T*factor_t, H*factor_s, W*factor_s] if (first_chunk) { - x = ggml_ext_slice(ctx, x, 2, factor_t - 1, x->ne[2]); + x = ggml_ext_slice(ctx->ggml_ctx, x, 2, factor_t - 1, x->ne[2]); } return x; @@ -351,7 +351,7 @@ namespace WAN { } } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, int64_t b, std::vector& feat_cache, @@ -374,11 +374,11 @@ namespace WAN { if (feat_cache.size() > 0) { int idx = feat_idx; - auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]); + auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]); if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { // cache last frame of last two chunk - cache_x = ggml_concat(ctx, - ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), + cache_x = ggml_concat(ctx->ggml_ctx, + ggml_ext_slice(ctx->ggml_ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), cache_x, 2); } @@ -388,13 +388,13 @@ namespace WAN { feat_idx += 1; } } else if (i == 1 || i == 4) { - x = ggml_silu(ctx, x); + x = ggml_silu(ctx->ggml_ctx, x); } else { // i == 5 // nn.Dropout(), ignore } } - x = ggml_add(ctx, x, h); + x = ggml_add(ctx->ggml_ctx, x, h); return x; } }; @@ -425,7 +425,7 @@ namespace WAN { } } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, int64_t b, std::vector& feat_cache, @@ -453,7 +453,7 @@ namespace WAN { auto shortcut = avg_shortcut->forward(ctx, x_copy, b); - x = ggml_add(ctx, x, shortcut); + x = ggml_add(ctx->ggml_ctx, x, shortcut); return x; } @@ -487,7 +487,7 @@ namespace WAN { } } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, int64_t b, std::vector& feat_cache, @@ -513,7 +513,7 @@ namespace WAN { auto avg_shortcut = std::dynamic_pointer_cast(blocks["avg_shortcut"]); auto shortcut = avg_shortcut->forward(ctx, x_copy, chunk_idx == 0, b); - x = ggml_add(ctx, x, shortcut); + x = ggml_add(ctx->ggml_ctx, x, shortcut); } return x; @@ -532,7 +532,7 @@ namespace WAN { blocks["proj"] = std::shared_ptr(new Conv2d(dim, dim, {1, 1})); } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, int64_t b) { // x: [b*c, t, h, w] @@ -545,7 +545,7 @@ namespace WAN { x = norm->forward(ctx, x); - x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // (t, c, h, w) + x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // (t, c, h, w) const int64_t n = x->ne[3]; const int64_t c = x->ne[2]; @@ -553,31 +553,31 @@ namespace WAN { const int64_t w = x->ne[0]; auto qkv = to_qkv->forward(ctx, x); - auto qkv_vec = split_image_qkv(ctx, qkv); + auto qkv_vec = split_image_qkv(ctx->ggml_ctx, qkv); auto q = qkv_vec[0]; - q = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, q, 2, 0, 1, 3)); // [t, h, w, c] - q = ggml_reshape_3d(ctx, q, c, h * w, n); // [t, h * w, c] + q = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, q, 2, 0, 1, 3)); // [t, h, w, c] + q = ggml_reshape_3d(ctx->ggml_ctx, q, c, h * w, n); // [t, h * w, c] auto k = qkv_vec[1]; - k = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, k, 2, 0, 1, 3)); // [t, h, w, c] - k = ggml_reshape_3d(ctx, k, c, h * w, n); // [t, h * w, c] + k = 
ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 2, 0, 1, 3)); // [t, h, w, c] + k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [t, h * w, c] auto v = qkv_vec[2]; - v = ggml_reshape_3d(ctx, v, h * w, c, n); // [t, c, h * w] + v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w] - x = ggml_ext_attention(ctx, q, k, v, false); // [t, h * w, c] + x = ggml_ext_attention(ctx->ggml_ctx, q, k, v, false); // [t, h * w, c] // v = ggml_cont(ctx, ggml_ext_torch_permute(ctx, v, 1, 0, 2, 3)); // [t, h * w, c] // x = ggml_ext_attention_ext(ctx, q, k, v, q->ne[2], nullptr, false, false, true); - x = ggml_ext_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [t, c, h * w] - x = ggml_reshape_4d(ctx, x, w, h, c, n); // [t, c, h, w] + x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w] + x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w] x = proj->forward(ctx, x); - x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 0, 1, 3, 2)); // (c, t, h, w) + x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 0, 1, 3, 2)); // (c, t, h, w) - x = ggml_add(ctx, x, identity); + x = ggml_add(ctx->ggml_ctx, x, identity); return x; } }; @@ -655,7 +655,7 @@ namespace WAN { blocks["head.2"] = std::shared_ptr(new CausalConv3d(out_dim, z_dim, {3, 3, 3}, {1, 1, 1}, {1, 1, 1})); } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, int64_t b, std::vector& feat_cache, @@ -673,11 +673,11 @@ namespace WAN { // conv1 if (feat_cache.size() > 0) { int idx = feat_idx; - auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]); + auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]); if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { // cache last frame of last two chunk - cache_x = ggml_concat(ctx, - ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), + cache_x = ggml_concat(ctx->ggml_ctx, + ggml_ext_slice(ctx->ggml_ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), cache_x, 2); } @@ -722,14 +722,14 @@ namespace WAN { // head x = head_0->forward(ctx, x); - x = ggml_silu(ctx, x); + x = ggml_silu(ctx->ggml_ctx, x); if (feat_cache.size() > 0) { int idx = feat_idx; - auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]); + auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]); if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { // cache last frame of last two chunk - cache_x = ggml_concat(ctx, - ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), + cache_x = ggml_concat(ctx->ggml_ctx, + ggml_ext_slice(ctx->ggml_ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), cache_x, 2); } @@ -826,7 +826,7 @@ namespace WAN { } } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, int64_t b, std::vector& feat_cache, @@ -844,11 +844,11 @@ namespace WAN { // conv1 if (feat_cache.size() > 0) { int idx = feat_idx; - auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]); + auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]); if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { // cache last frame of last two chunk - cache_x = ggml_concat(ctx, - ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), + cache_x = ggml_concat(ctx->ggml_ctx, + ggml_ext_slice(ctx->ggml_ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), cache_x, 2); } @@ -893,14 
+893,14 @@ namespace WAN { // head x = head_0->forward(ctx, x); - x = ggml_silu(ctx, x); + x = ggml_silu(ctx->ggml_ctx, x); if (feat_cache.size() > 0) { int idx = feat_idx; - auto cache_x = ggml_ext_slice(ctx, x, 2, -CACHE_T, x->ne[2]); + auto cache_x = ggml_ext_slice(ctx->ggml_ctx, x, 2, -CACHE_T, x->ne[2]); if (cache_x->ne[2] < 2 && feat_cache[idx] != nullptr) { // cache last frame of last two chunk - cache_x = ggml_concat(ctx, - ggml_ext_slice(ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), + cache_x = ggml_concat(ctx->ggml_ctx, + ggml_ext_slice(ctx->ggml_ctx, feat_cache[idx], 2, -1, feat_cache[idx]->ne[2]), cache_x, 2); } @@ -1015,7 +1015,7 @@ namespace WAN { return x; } - struct ggml_tensor* encode(struct ggml_context* ctx, + struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x, int64_t b = 1) { // x: [b*c, t, h, w] @@ -1025,7 +1025,7 @@ namespace WAN { clear_cache(); if (wan2_2) { - x = patchify(ctx, x, 2, b); + x = patchify(ctx->ggml_ctx, x, 2, b); } auto encoder = std::dynamic_pointer_cast(blocks["encoder"]); @@ -1037,21 +1037,21 @@ namespace WAN { for (int i = 0; i < iter_; i++) { _enc_conv_idx = 0; if (i == 0) { - auto in = ggml_ext_slice(ctx, x, 2, 0, 1); // [b*c, 1, h, w] + auto in = ggml_ext_slice(ctx->ggml_ctx, x, 2, 0, 1); // [b*c, 1, h, w] out = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i); } else { - auto in = ggml_ext_slice(ctx, x, 2, 1 + 4 * (i - 1), 1 + 4 * i); // [b*c, 4, h, w] + auto in = ggml_ext_slice(ctx->ggml_ctx, x, 2, 1 + 4 * (i - 1), 1 + 4 * i); // [b*c, 4, h, w] auto out_ = encoder->forward(ctx, in, b, _enc_feat_map, _enc_conv_idx, i); - out = ggml_concat(ctx, out, out_, 2); + out = ggml_concat(ctx->ggml_ctx, out, out_, 2); } } out = conv1->forward(ctx, out); - auto mu = ggml_ext_chunk(ctx, out, 2, 3)[0]; + auto mu = ggml_ext_chunk(ctx->ggml_ctx, out, 2, 3)[0]; clear_cache(); return mu; } - struct ggml_tensor* decode(struct ggml_context* ctx, + struct ggml_tensor* decode(GGMLRunnerContext* ctx, struct ggml_tensor* z, int64_t b = 1) { // z: [b*c, t, h, w] @@ -1068,22 +1068,22 @@ namespace WAN { for (int64_t i = 0; i < iter_; i++) { _conv_idx = 0; if (i == 0) { - auto in = ggml_ext_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w] + auto in = ggml_ext_slice(ctx->ggml_ctx, x, 2, i, i + 1); // [b*c, 1, h, w] out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i); } else { - auto in = ggml_ext_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w] + auto in = ggml_ext_slice(ctx->ggml_ctx, x, 2, i, i + 1); // [b*c, 1, h, w] auto out_ = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i); - out = ggml_concat(ctx, out, out_, 2); + out = ggml_concat(ctx->ggml_ctx, out, out_, 2); } } if (wan2_2) { - out = unpatchify(ctx, out, 2, b); + out = unpatchify(ctx->ggml_ctx, out, 2, b); } clear_cache(); return out; } - struct ggml_tensor* decode_partial(struct ggml_context* ctx, + struct ggml_tensor* decode_partial(GGMLRunnerContext* ctx, struct ggml_tensor* z, int64_t i, int64_t b = 1) { @@ -1094,11 +1094,11 @@ namespace WAN { auto conv2 = std::dynamic_pointer_cast(blocks["conv2"]); auto x = conv2->forward(ctx, z); - auto in = ggml_ext_slice(ctx, x, 2, i, i + 1); // [b*c, 1, h, w] + auto in = ggml_ext_slice(ctx->ggml_ctx, x, 2, i, i + 1); // [b*c, 1, h, w] _conv_idx = 0; auto out = decoder->forward(ctx, in, b, _feat_map, _conv_idx, i); if (wan2_2) { - out = unpatchify(ctx, out, 2, b); + out = unpatchify(ctx->ggml_ctx, out, 2, b); } return out; } @@ -1131,7 +1131,9 @@ namespace WAN { z = to_backend(z); - struct ggml_tensor* out = decode_graph 
? ae.decode(compute_ctx, z) : ae.encode(compute_ctx, z); + auto runner_ctx = get_context(); + + struct ggml_tensor* out = decode_graph ? ae.decode(&runner_ctx, z) : ae.encode(&runner_ctx, z); ggml_build_forward_expand(gf, out); @@ -1150,7 +1152,9 @@ namespace WAN { z = to_backend(z); - struct ggml_tensor* out = decode_graph ? ae.decode_partial(compute_ctx, z, i) : ae.encode(compute_ctx, z); + auto runner_ctx = get_context(); + + struct ggml_tensor* out = decode_graph ? ae.decode_partial(&runner_ctx, z, i) : ae.encode(&runner_ctx, z); for (int64_t feat_idx = 0; feat_idx < ae._feat_map.size(); feat_idx++) { ggml_tensor* feat_cache = ae._feat_map[feat_idx]; @@ -1307,8 +1311,7 @@ namespace WAN { } } - virtual struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* pe, struct ggml_tensor* mask = nullptr) { @@ -1331,11 +1334,11 @@ namespace WAN { k = norm_k->forward(ctx, k); auto v = v_proj->forward(ctx, x); // [N, n_token, n_head*d_head] - q = ggml_reshape_4d(ctx, q, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head] - k = ggml_reshape_4d(ctx, k, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head] - v = ggml_reshape_4d(ctx, v, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head] + q = ggml_reshape_4d(ctx->ggml_ctx, q, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head] + k = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head] + v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_heads, n_token, N); // [N, n_token, n_head, d_head] - x = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn); // [N, n_token, dim] + x = Rope::attention(ctx, q, k, v, pe, mask, flash_attn); // [N, n_token, dim] x = o_proj->forward(ctx, x); // [N, n_token, dim] return x; @@ -1350,8 +1353,7 @@ namespace WAN { float eps = 1e-6, bool flash_attn = false) : WanSelfAttention(dim, num_heads, qk_norm, eps, flash_attn) {} - virtual struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* context, int64_t context_img_len) = 0; @@ -1365,8 +1367,7 @@ namespace WAN { float eps = 1e-6, bool flash_attn = false) : WanCrossAttention(dim, num_heads, qk_norm, eps, flash_attn) {} - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* context, int64_t context_img_len) override { @@ -1390,7 +1391,7 @@ namespace WAN { k = norm_k->forward(ctx, k); auto v = v_proj->forward(ctx, context); // [N, n_context, dim] - x = ggml_ext_attention_ext(ctx, backend, q, k, v, num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim] + x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim] x = o_proj->forward(ctx, x); // [N, n_token, dim] return x; @@ -1415,8 +1416,7 @@ namespace WAN { } } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* context, int64_t context_img_len) override { @@ -1441,11 +1441,11 @@ namespace WAN { int64_t dim = x->ne[0]; int64_t context_txt_len = context->ne[1] - context_img_len; - context = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, 
context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim] - auto context_img = ggml_view_3d(ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0); - auto context_txt = ggml_view_3d(ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_img_len * context->nb[2]); - context_img = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim] - context_txt = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim] + context = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim] + auto context_img = ggml_view_3d(ctx->ggml_ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0); + auto context_txt = ggml_view_3d(ctx->ggml_ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_img_len * context->nb[2]); + context_img = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim] + context_txt = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim] auto q = q_proj->forward(ctx, x); q = norm_q->forward(ctx, q); @@ -1457,10 +1457,10 @@ namespace WAN { k_img = norm_k_img->forward(ctx, k_img); auto v_img = v_img_proj->forward(ctx, context_img); // [N, context_img_len, dim] - auto img_x = ggml_ext_attention_ext(ctx, backend, q, k_img, v_img, num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim] - x = ggml_ext_attention_ext(ctx, backend, q, k, v, num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim] + auto img_x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k_img, v_img, num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim] + x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, flash_attn); // [N, n_token, dim] - x = ggml_add(ctx, x, img_x); + x = ggml_add(ctx->ggml_ctx, x, img_x); x = o_proj->forward(ctx, x); // [N, n_token, dim] return x; @@ -1534,8 +1534,7 @@ namespace WAN { blocks["ffn.2"] = std::shared_ptr(new Linear(ffn_dim, dim)); } - virtual struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* e, struct ggml_tensor* pe, @@ -1547,8 +1546,8 @@ namespace WAN { // return [N, n_token, dim] auto modulation = params["modulation"]; - e = ggml_add(ctx, e, modulation); // [N, 6, dim] or [N, T, 6, dim] - auto es = ggml_ext_chunk(ctx, e, 6, 1); // ([N, 1, dim], ...) or [N, T, 1, dim] + e = ggml_add(ctx->ggml_ctx, e, modulation); // [N, 6, dim] or [N, T, 6, dim] + auto es = ggml_ext_chunk(ctx->ggml_ctx, e, 6, 1); // ([N, 1, dim], ...) 
or [N, T, 1, dim] auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); auto self_attn = std::dynamic_pointer_cast(blocks["self_attn"]); @@ -1560,27 +1559,27 @@ namespace WAN { // self-attention auto y = norm1->forward(ctx, x); - y = ggml_add(ctx, y, modulate_mul(ctx, y, es[1])); - y = modulate_add(ctx, y, es[0]); - y = self_attn->forward(ctx, backend, y, pe); + y = ggml_add(ctx->ggml_ctx, y, modulate_mul(ctx->ggml_ctx, y, es[1])); + y = modulate_add(ctx->ggml_ctx, y, es[0]); + y = self_attn->forward(ctx, y, pe); - x = ggml_add(ctx, x, modulate_mul(ctx, y, es[2])); + x = ggml_add(ctx->ggml_ctx, x, modulate_mul(ctx->ggml_ctx, y, es[2])); // cross-attention - x = ggml_add(ctx, + x = ggml_add(ctx->ggml_ctx, x, - cross_attn->forward(ctx, backend, norm3->forward(ctx, x), context, context_img_len)); + cross_attn->forward(ctx, norm3->forward(ctx, x), context, context_img_len)); // ffn y = norm2->forward(ctx, x); - y = ggml_add(ctx, y, modulate_mul(ctx, y, es[4])); - y = modulate_add(ctx, y, es[3]); + y = ggml_add(ctx->ggml_ctx, y, modulate_mul(ctx->ggml_ctx, y, es[4])); + y = modulate_add(ctx->ggml_ctx, y, es[3]); y = ffn_0->forward(ctx, y); - y = ggml_gelu_inplace(ctx, y); + y = ggml_gelu_inplace(ctx->ggml_ctx, y); y = ffn_2->forward(ctx, y); - x = ggml_add(ctx, x, modulate_mul(ctx, y, es[5])); + x = ggml_add(ctx->ggml_ctx, x, modulate_mul(ctx->ggml_ctx, y, es[5])); return x; } @@ -1611,8 +1610,7 @@ namespace WAN { blocks["after_proj"] = std::shared_ptr(new Linear(dim, dim)); } - std::pair forward(struct ggml_context* ctx, - ggml_backend_t backend, + std::pair forward(GGMLRunnerContext* ctx, struct ggml_tensor* c, struct ggml_tensor* x, struct ggml_tensor* e, @@ -1627,12 +1625,12 @@ namespace WAN { auto before_proj = std::dynamic_pointer_cast(blocks["before_proj"]); c = before_proj->forward(ctx, c); - c = ggml_add(ctx, c, x); + c = ggml_add(ctx->ggml_ctx, c, x); } auto after_proj = std::dynamic_pointer_cast(blocks["after_proj"]); - c = WanAttentionBlock::forward(ctx, backend, c, e, pe, context, context_img_len); + c = WanAttentionBlock::forward(ctx, c, e, pe, context, context_img_len); auto c_skip = after_proj->forward(ctx, c); return {c_skip, c}; @@ -1660,7 +1658,7 @@ namespace WAN { blocks["head"] = std::shared_ptr(new Linear(dim, out_dim)); } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* e) { // x: [N, n_token, dim] @@ -1668,18 +1666,18 @@ namespace WAN { // return [N, n_token, out_dim] auto modulation = params["modulation"]; - e = ggml_reshape_4d(ctx, e, e->ne[0], 1, e->ne[1], e->ne[2]); // [N, 1, dim] or [N, T, 1, dim] - e = ggml_repeat_4d(ctx, e, e->ne[0], 2, e->ne[2], e->ne[3]); // [N, 2, dim] or [N, T, 2, dim] + e = ggml_reshape_4d(ctx->ggml_ctx, e, e->ne[0], 1, e->ne[1], e->ne[2]); // [N, 1, dim] or [N, T, 1, dim] + e = ggml_repeat_4d(ctx->ggml_ctx, e, e->ne[0], 2, e->ne[2], e->ne[3]); // [N, 2, dim] or [N, T, 2, dim] - e = ggml_add(ctx, e, modulation); // [N, 2, dim] or [N, T, 2, dim] - auto es = ggml_ext_chunk(ctx, e, 2, 1); // ([N, 1, dim], ...) or ([N, T, 1, dim], ...) + e = ggml_add(ctx->ggml_ctx, e, modulation); // [N, 2, dim] or [N, T, 2, dim] + auto es = ggml_ext_chunk(ctx->ggml_ctx, e, 2, 1); // ([N, 1, dim], ...) or ([N, T, 1, dim], ...) 
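Both chunk calls above implement the same adaLN-style modulation: a conditioning tensor is split into per-use rows (six for an attention block, i.e. shift/scale/gate for self-attention and again for the FFN; two for the head, i.e. shift and scale), and each shift/scale pair is applied as x * (1 + scale) + shift. A minimal self-contained sketch of that pattern, assuming modulate_mul and modulate_add reduce to ggml's broadcast multiply and add; adaln_modulate is a hypothetical helper for illustration, not part of this patch:

    #include "ggml.h"

    // x:     [N, n_token, dim] activations (comment convention as in the hunks above)
    // shift: [N, 1, dim], scale: [N, 1, dim] -- one chunk row each, e.g. es[0]/es[1]
    static struct ggml_tensor* adaln_modulate(struct ggml_context* ctx,
                                              struct ggml_tensor* x,
                                              struct ggml_tensor* shift,
                                              struct ggml_tensor* scale) {
        // x * (1 + scale), written as x + x * scale; ggml_mul/ggml_add broadcast
        // the singleton token dimension of scale/shift across all n_token rows.
        x = ggml_add(ctx, x, ggml_mul(ctx, x, scale));
        return ggml_add(ctx, x, shift);
    }

Under that reading, the head computation in the lines that follow is head(adaln_modulate(norm(x), es[0], es[1])), and the attention block earlier applies the same transform with es[0]/es[1] before self-attention and es[3]/es[4] before the FFN, gating the residual adds with es[2] and es[5].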
auto norm = std::dynamic_pointer_cast(blocks["norm"]); auto head = std::dynamic_pointer_cast(blocks["head"]); x = norm->forward(ctx, x); - x = ggml_add(ctx, x, modulate_mul(ctx, x, es[1])); - x = modulate_add(ctx, x, es[0]); + x = ggml_add(ctx->ggml_ctx, x, modulate_mul(ctx->ggml_ctx, x, es[1])); + x = modulate_add(ctx->ggml_ctx, x, es[0]); x = head->forward(ctx, x); return x; } @@ -1708,15 +1706,15 @@ namespace WAN { blocks["proj.4"] = std::shared_ptr(new LayerNorm(out_dim)); } - struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* image_embeds) { if (flf_pos_embed_token_number > 0) { auto emb_pos = params["emb_pos"]; - auto a = ggml_ext_slice(ctx, image_embeds, 1, 0, emb_pos->ne[1]); - auto b = ggml_ext_slice(ctx, emb_pos, 1, 0, image_embeds->ne[1]); + auto a = ggml_ext_slice(ctx->ggml_ctx, image_embeds, 1, 0, emb_pos->ne[1]); + auto b = ggml_ext_slice(ctx->ggml_ctx, emb_pos, 1, 0, image_embeds->ne[1]); - image_embeds = ggml_add(ctx, a, b); + image_embeds = ggml_add(ctx->ggml_ctx, a, b); } auto proj_0 = std::dynamic_pointer_cast(blocks["proj.0"]); @@ -1726,7 +1724,7 @@ namespace WAN { auto x = proj_0->forward(ctx, image_embeds); x = proj_1->forward(ctx, x); - x = ggml_gelu_inplace(ctx, x); + x = ggml_gelu_inplace(ctx->ggml_ctx, x); x = proj_3->forward(ctx, x); x = proj_4->forward(ctx, x); @@ -1872,8 +1870,7 @@ namespace WAN { return x; } - struct ggml_tensor* forward_orig(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward_orig(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* timestep, struct ggml_tensor* context, @@ -1902,31 +1899,31 @@ namespace WAN { auto head = std::dynamic_pointer_cast(blocks["head"]); // patch_embedding - x = patch_embedding->forward(ctx, x); // [N*dim, t_len, h_len, w_len] - x = ggml_reshape_3d(ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3] / N, N); // [N, dim, t_len*h_len*w_len] - x = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, x, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim] + x = patch_embedding->forward(ctx, x); // [N*dim, t_len, h_len, w_len] + x = ggml_reshape_3d(ctx->ggml_ctx, x, x->ne[0] * x->ne[1] * x->ne[2], x->ne[3] / N, N); // [N, dim, t_len*h_len*w_len] + x = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim] // time_embedding - auto e = ggml_ext_timestep_embedding(ctx, timestep, params.freq_dim); + auto e = ggml_ext_timestep_embedding(ctx->ggml_ctx, timestep, params.freq_dim); e = time_embedding_0->forward(ctx, e); - e = ggml_silu_inplace(ctx, e); + e = ggml_silu_inplace(ctx->ggml_ctx, e); e = time_embedding_2->forward(ctx, e); // [N, dim] or [N, T, dim] // time_projection - auto e0 = ggml_silu(ctx, e); + auto e0 = ggml_silu(ctx->ggml_ctx, e); e0 = time_projection_1->forward(ctx, e0); - e0 = ggml_reshape_4d(ctx, e0, e0->ne[0] / 6, 6, e0->ne[1], e0->ne[2]); // [N, 6, dim] or [N, T, 6, dim] + e0 = ggml_reshape_4d(ctx->ggml_ctx, e0, e0->ne[0] / 6, 6, e0->ne[1], e0->ne[2]); // [N, 6, dim] or [N, T, 6, dim] context = text_embedding_0->forward(ctx, context); - context = ggml_gelu(ctx, context); + context = ggml_gelu(ctx->ggml_ctx, context); context = text_embedding_2->forward(ctx, context); // [N, context_txt_len, dim] int64_t context_img_len = 0; if (clip_fea != nullptr) { if (params.model_type == "i2v") { auto img_emb = std::dynamic_pointer_cast(blocks["img_emb"]); - auto context_img = img_emb->forward(ctx, clip_fea); // [N, context_img_len, dim] - context = 
ggml_concat(ctx, context_img, context, 1); // [N, context_img_len + context_txt_len, dim] + auto context_img = img_emb->forward(ctx, clip_fea); // [N, context_img_len, dim] + context = ggml_concat(ctx->ggml_ctx, context_img, context, 1); // [N, context_img_len + context_txt_len, dim] } context_img_len = clip_fea->ne[1]; // 257 } @@ -1936,9 +1933,9 @@ namespace WAN { if (params.vace_layers > 0) { auto vace_patch_embedding = std::dynamic_pointer_cast(blocks["vace_patch_embedding"]); - c = vace_patch_embedding->forward(ctx, vace_context); // [N*dim, t_len, h_len, w_len] - c = ggml_reshape_3d(ctx, c, c->ne[0] * c->ne[1] * c->ne[2], c->ne[3] / N, N); // [N, dim, t_len*h_len*w_len] - c = ggml_ext_cont(ctx, ggml_ext_torch_permute(ctx, c, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim] + c = vace_patch_embedding->forward(ctx, vace_context); // [N*dim, t_len, h_len, w_len] + c = ggml_reshape_3d(ctx->ggml_ctx, c, c->ne[0] * c->ne[1] * c->ne[2], c->ne[3] / N, N); // [N, dim, t_len*h_len*w_len] + c = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, c, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim] } auto x_orig = x; @@ -1946,7 +1943,7 @@ namespace WAN { for (int i = 0; i < params.num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["blocks." + std::to_string(i)]); - x = block->forward(ctx, backend, x, e0, pe, context, context_img_len); + x = block->forward(ctx, x, e0, pe, context, context_img_len); auto iter = params.vace_layers_mapping.find(i); if (iter != params.vace_layers_mapping.end()) { @@ -1954,11 +1951,11 @@ namespace WAN { auto vace_block = std::dynamic_pointer_cast(blocks["vace_blocks." + std::to_string(n)]); - auto result = vace_block->forward(ctx, backend, c, x_orig, e0, pe, context, context_img_len); + auto result = vace_block->forward(ctx, c, x_orig, e0, pe, context, context_img_len); auto c_skip = result.first; c = result.second; - c_skip = ggml_scale(ctx, c_skip, vace_strength); - x = ggml_add(ctx, x, c_skip); + c_skip = ggml_scale(ctx->ggml_ctx, c_skip, vace_strength); + x = ggml_add(ctx->ggml_ctx, x, c_skip); } } @@ -1967,8 +1964,7 @@ namespace WAN { return x; } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, + struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, struct ggml_tensor* timestep, struct ggml_tensor* context, @@ -1993,27 +1989,27 @@ namespace WAN { int64_t T = x->ne[2]; int64_t C = x->ne[3]; - x = pad_to_patch_size(ctx, x); + x = pad_to_patch_size(ctx->ggml_ctx, x); int64_t t_len = ((T + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size)); int64_t h_len = ((H + (std::get<1>(params.patch_size) / 2)) / std::get<1>(params.patch_size)); int64_t w_len = ((W + (std::get<2>(params.patch_size) / 2)) / std::get<2>(params.patch_size)); if (time_dim_concat != nullptr) { - time_dim_concat = pad_to_patch_size(ctx, time_dim_concat); - x = ggml_concat(ctx, x, time_dim_concat, 2); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w] + time_dim_concat = pad_to_patch_size(ctx->ggml_ctx, time_dim_concat); + x = ggml_concat(ctx->ggml_ctx, x, time_dim_concat, 2); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w] t_len = ((x->ne[2] + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size)); } - auto out = forward_orig(ctx, backend, x, timestep, context, pe, clip_fea, vace_context, vace_strength, N); // [N, t_len*h_len*w_len, pt*ph*pw*C] + auto out = forward_orig(ctx, x, timestep, context, pe, clip_fea, vace_context, vace_strength, N); // [N, t_len*h_len*w_len, 
pt*ph*pw*C] - out = unpatchify(ctx, out, t_len, h_len, w_len); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w] + out = unpatchify(ctx->ggml_ctx, out, t_len, h_len, w_len); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w] // slice - out = ggml_ext_slice(ctx, out, 2, 0, T); // [N*C, T, H + pad_h, W + pad_w] - out = ggml_ext_slice(ctx, out, 1, 0, H); // [N*C, T, H, W + pad_w] - out = ggml_ext_slice(ctx, out, 0, 0, W); // [N*C, T, H, W] + out = ggml_ext_slice(ctx->ggml_ctx, out, 2, 0, T); // [N*C, T, H + pad_h, W + pad_w] + out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N*C, T, H, W + pad_w] + out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W); // [N*C, T, H, W] return out; } @@ -2183,8 +2179,9 @@ namespace WAN { x = ggml_concat(compute_ctx, x, c_concat, 3); } - struct ggml_tensor* out = wan.forward(compute_ctx, - runtime_backend, + auto runner_ctx = get_context(); + + struct ggml_tensor* out = wan.forward(&runner_ctx, x, timesteps, context, From 6ea46b4ee5492d7c039ce749aa46b4e97034f91e Mon Sep 17 00:00:00 2001 From: leejet Date: Sat, 1 Nov 2025 00:30:25 +0800 Subject: [PATCH 2/4] add Flash Attention enable control through GGMLRunnerContext --- common.hpp | 21 +++++++--------- diffusion_model.hpp | 42 ++++++++++++++++++++++---------- flux.hpp | 30 +++++++---------------- ggml_extend.hpp | 12 ++++++++-- mmdit.hpp | 51 +++++++++++++++------------------------ qwen_image.hpp | 25 +++++++------------ qwenvl.hpp | 2 +- rope.hpp | 3 +-- stable-diffusion.cpp | 22 +++++++---------- unet.hpp | 9 ++++--- wan.hpp | 57 +++++++++++++++++--------------------------- 11 files changed, 120 insertions(+), 154 deletions(-) diff --git a/common.hpp b/common.hpp index cb07e8dc0..45d3de396 100644 --- a/common.hpp +++ b/common.hpp @@ -281,19 +281,16 @@ class CrossAttention : public GGMLBlock { int64_t context_dim; int64_t n_head; int64_t d_head; - bool flash_attn; public: CrossAttention(int64_t query_dim, int64_t context_dim, int64_t n_head, - int64_t d_head, - bool flash_attn = false) + int64_t d_head) : n_head(n_head), d_head(d_head), query_dim(query_dim), - context_dim(context_dim), - flash_attn(flash_attn) { + context_dim(context_dim) { int64_t inner_dim = d_head * n_head; blocks["to_q"] = std::shared_ptr(new Linear(query_dim, inner_dim, false)); @@ -324,7 +321,7 @@ class CrossAttention : public GGMLBlock { auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim] auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim] - x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, false, flash_attn); // [N, n_token, inner_dim] + x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_token, inner_dim] x = to_out_0->forward(ctx, x); // [N, n_token, query_dim] return x; @@ -342,16 +339,15 @@ class BasicTransformerBlock : public GGMLBlock { int64_t n_head, int64_t d_head, int64_t context_dim, - bool ff_in = false, - bool flash_attn = false) + bool ff_in = false) : n_head(n_head), d_head(d_head), ff_in(ff_in) { // disable_self_attn is always False // disable_temporal_crossattention is always False // switch_temporal_ca_to_sa is always False // inner_dim is always None or equal to dim // gated_ff is always True - blocks["attn1"] = std::shared_ptr(new CrossAttention(dim, dim, n_head, d_head, flash_attn)); - blocks["attn2"] = std::shared_ptr(new CrossAttention(dim, context_dim, n_head, d_head, flash_attn)); + blocks["attn1"] = std::shared_ptr(new CrossAttention(dim, dim, n_head, 
 common.hpp           | 21 +++++++---------
 diffusion_model.hpp  | 42 ++++++++++++++++++++++----------
 flux.hpp             | 30 +++++++----------------
 ggml_extend.hpp      | 12 ++++++++--
 mmdit.hpp            | 51 +++++++++++++++------------------
 qwen_image.hpp       | 25 +++++++------------
 qwenvl.hpp           |  2 +-
 rope.hpp             |  3 +--
 stable-diffusion.cpp | 22 +++++++----------
 unet.hpp             |  9 ++++---
 wan.hpp              | 57 +++++++++++++++++---------------------
 11 files changed, 120 insertions(+), 154 deletions(-)

diff --git a/common.hpp b/common.hpp
index cb07e8dc0..45d3de396 100644
--- a/common.hpp
+++ b/common.hpp
@@ -281,19 +281,16 @@ class CrossAttention : public GGMLBlock {
     int64_t context_dim;
     int64_t n_head;
     int64_t d_head;
-    bool flash_attn;

 public:
     CrossAttention(int64_t query_dim,
                    int64_t context_dim,
                    int64_t n_head,
-                   int64_t d_head,
-                   bool flash_attn = false)
+                   int64_t d_head)
         : n_head(n_head),
           d_head(d_head),
           query_dim(query_dim),
-          context_dim(context_dim),
-          flash_attn(flash_attn) {
+          context_dim(context_dim) {
         int64_t inner_dim = d_head * n_head;

         blocks["to_q"] = std::shared_ptr(new Linear(query_dim, inner_dim, false));
@@ -324,7 +321,7 @@
         auto k = to_k->forward(ctx, context);  // [N, n_context, inner_dim]
         auto v = to_v->forward(ctx, context);  // [N, n_context, inner_dim]

-        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, false, flash_attn);  // [N, n_token, inner_dim]
+        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, false, false, ctx->flash_attn_enabled);  // [N, n_token, inner_dim]

         x = to_out_0->forward(ctx, x);  // [N, n_token, query_dim]
         return x;
@@ -342,16 +339,15 @@ class BasicTransformerBlock : public GGMLBlock {
                           int64_t n_head,
                           int64_t d_head,
                           int64_t context_dim,
-                          bool ff_in      = false,
-                          bool flash_attn = false)
+                          bool ff_in      = false)
         : n_head(n_head), d_head(d_head), ff_in(ff_in) {
         // disable_self_attn is always False
         // disable_temporal_crossattention is always False
         // switch_temporal_ca_to_sa is always False
         // inner_dim is always None or equal to dim
         // gated_ff is always True
-        blocks["attn1"] = std::shared_ptr(new CrossAttention(dim, dim, n_head, d_head, flash_attn));
-        blocks["attn2"] = std::shared_ptr(new CrossAttention(dim, context_dim, n_head, d_head, flash_attn));
+        blocks["attn1"] = std::shared_ptr(new CrossAttention(dim, dim, n_head, d_head));
+        blocks["attn2"] = std::shared_ptr(new CrossAttention(dim, context_dim, n_head, d_head));
         blocks["ff"]    = std::shared_ptr(new FeedForward(dim, dim));
         blocks["norm1"] = std::shared_ptr(new LayerNorm(dim));
         blocks["norm2"] = std::shared_ptr(new LayerNorm(dim));
@@ -418,8 +414,7 @@ class SpatialTransformer : public GGMLBlock {
                        int64_t n_head,
                        int64_t d_head,
                        int64_t depth,
-                       int64_t context_dim,
-                       bool flash_attn = false)
+                       int64_t context_dim)
         : in_channels(in_channels),
           n_head(n_head),
           d_head(d_head),
@@ -433,7 +428,7 @@

         for (int i = 0; i < depth; i++) {
             std::string name = "transformer_blocks." + std::to_string(i);
-            blocks[name]     = std::shared_ptr(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false, flash_attn));
+            blocks[name]     = std::shared_ptr(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false));
         }

         blocks["proj_out"] = std::shared_ptr(new Conv2d(inner_dim, in_channels, {1, 1}));
diff --git a/diffusion_model.hpp b/diffusion_model.hpp
index 94b29bf11..fd63050e6 100644
--- a/diffusion_model.hpp
+++ b/diffusion_model.hpp
@@ -36,6 +36,7 @@ struct DiffusionModel {
     virtual void get_param_tensors(std::map& tensors)                           = 0;
     virtual size_t get_params_buffer_size()                                     = 0;
     virtual int64_t get_adm_in_channels()                                       = 0;
+    virtual void set_flash_attn_enabled(bool enabled) = 0;
 };

 struct UNetModel : public DiffusionModel {
@@ -44,9 +45,8 @@ struct UNetModel : public DiffusionModel {
     UNetModel(ggml_backend_t backend,
               bool offload_params_to_cpu,
               const String2GGMLType& tensor_types = {},
-              SDVersion version                   = VERSION_SD1,
-              bool flash_attn                     = false)
-        : unet(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn) {
+              SDVersion version                   = VERSION_SD1)
+        : unet(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version) {
     }

     std::string get_desc() override {
@@ -77,6 +77,10 @@ struct UNetModel : public DiffusionModel {
         return unet.unet.adm_in_channels;
     }

+    void set_flash_attn_enabled(bool enabled) {
+        unet.set_flash_attention_enabled(enabled);
+    }
+
     void compute(int n_threads,
                  DiffusionParams diffusion_params,
                  struct ggml_tensor** output = nullptr,
@@ -98,9 +102,8 @@ struct MMDiTModel : public DiffusionModel {

     MMDiTModel(ggml_backend_t backend,
                bool offload_params_to_cpu,
-               bool flash_attn                     = false,
                const String2GGMLType& tensor_types = {})
-        : mmdit(backend, offload_params_to_cpu, flash_attn, tensor_types, "model.diffusion_model") {
+        : mmdit(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model") {
     }

     std::string get_desc() override {
@@ -131,6 +134,10 @@ struct MMDiTModel : public DiffusionModel {
         return 768 + 1280;
     }

+    void set_flash_attn_enabled(bool enabled) {
+        mmdit.set_flash_attention_enabled(enabled);
+    }
+
     void compute(int n_threads,
                  DiffusionParams diffusion_params,
                  struct ggml_tensor** output = nullptr,
@@ -153,9 +160,8 @@ struct FluxModel : public DiffusionModel {
               bool offload_params_to_cpu,
               const String2GGMLType& tensor_types = {},
               SDVersion version                   = VERSION_FLUX,
-              bool flash_attn                     = false,
               bool use_mask                       = false)
-        : flux(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, flash_attn, use_mask) {
+        : flux(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model", version, use_mask) {
     }

     std::string get_desc() override {
@@ -186,6 +192,10 @@ struct FluxModel : public DiffusionModel {
         return 768;
     }

+    void set_flash_attn_enabled(bool enabled) {
+        flux.set_flash_attention_enabled(enabled);
+    }
+
     void compute(int n_threads,
                  DiffusionParams diffusion_params,
                  struct ggml_tensor** output = nullptr,
@@ -213,9 +223,8 @@ struct WanModel : public DiffusionModel {
              bool offload_params_to_cpu,
              const String2GGMLType& tensor_types = {},
              const std::string prefix            = "model.diffusion_model",
-             SDVersion version                   = VERSION_WAN2,
-             bool flash_attn                     = false)
-        : prefix(prefix), wan(backend, offload_params_to_cpu, tensor_types, prefix, version, flash_attn) {
+             SDVersion version                   = VERSION_WAN2)
+        : prefix(prefix), wan(backend, offload_params_to_cpu, tensor_types, prefix, version) {
     }

     std::string get_desc() override {
@@ -246,6 +255,10 @@ struct WanModel : public DiffusionModel {
         return 768;
     }

+    void set_flash_attn_enabled(bool enabled) {
+        wan.set_flash_attention_enabled(enabled);
+    }
+
     void compute(int n_threads,
                  DiffusionParams diffusion_params,
                  struct ggml_tensor** output = nullptr,
@@ -272,9 +285,8 @@ struct QwenImageModel : public DiffusionModel {
                    bool offload_params_to_cpu,
                    const String2GGMLType& tensor_types = {},
                    const std::string prefix            = "model.diffusion_model",
-                   SDVersion version                   = VERSION_QWEN_IMAGE,
-                   bool flash_attn                     = false)
-        : prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_types, prefix, version, flash_attn) {
+                   SDVersion version                   = VERSION_QWEN_IMAGE)
+        : prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_types, prefix, version) {
     }

     std::string get_desc() override {
@@ -305,6 +317,10 @@ struct QwenImageModel : public DiffusionModel {
         return 768;
     }

+    void set_flash_attn_enabled(bool enabled) {
+        qwen_image.set_flash_attention_enabled(enabled);
+    }
+
     void compute(int n_threads,
                  DiffusionParams diffusion_params,
                  struct ggml_tensor** output = nullptr,
diff --git a/flux.hpp b/flux.hpp
index ce0cf4763..eb149256f 100644
--- a/flux.hpp
+++ b/flux.hpp
@@ -85,13 +85,11 @@ namespace Flux {
     struct SelfAttention : public GGMLBlock {
     public:
        int64_t num_heads;
-        bool flash_attn;

     public:
         SelfAttention(int64_t dim,
                       int64_t num_heads = 8,
-                      bool qkv_bias     = false,
-                      bool flash_attn   = false)
+                      bool qkv_bias     = false)
             : num_heads(num_heads) {
             int64_t head_dim = dim / num_heads;
             blocks["qkv"]    = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias));
@@ -129,7 +127,7 @@ namespace Flux {
             // pe: [n_token, d_head/2, 2, 2]
             // return [N, n_token, dim]
             auto qkv = pre_attention(ctx, x);                                    // q,k,v: [N, n_token, n_head, d_head]
-            x        = Rope::attention(ctx, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn);  // [N, n_token, dim]
+            x        = Rope::attention(ctx, qkv[0], qkv[1], qkv[2], pe, mask);  // [N, n_token, dim]
             x        = post_attention(ctx, x);  // [N, n_token, dim]
             return x;
         }
@@ -198,7 +196,6 @@ namespace Flux {
     }

     struct DoubleStreamBlock : public GGMLBlock {
-        bool flash_attn;
         bool prune_mod;
         int idx = 0;

@@ -208,15 +205,14 @@ namespace Flux {
                          float mlp_ratio,
                          int idx          = 0,
                          bool qkv_bias    = false,
-                          bool flash_attn  = false,
                          bool prune_mod   = false)
-            : idx(idx), flash_attn(flash_attn), prune_mod(prune_mod) {
+            : idx(idx), prune_mod(prune_mod) {
             int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
             if (!prune_mod) {
                 blocks["img_mod"] = std::shared_ptr(new Modulation(hidden_size, true));
             }
             blocks["img_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false));
-            blocks["img_attn"]  = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn));
+            blocks["img_attn"]  = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias));

             blocks["img_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false));
             blocks["img_mlp.0"] = std::shared_ptr(new Linear(hidden_size, mlp_hidden_dim));
@@ -227,7 +223,7 @@ namespace Flux {
                 blocks["txt_mod"] = std::shared_ptr(new Modulation(hidden_size, true));
             }
             blocks["txt_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false));
-            blocks["txt_attn"]  = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn));
+            blocks["txt_attn"]  = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias));

             blocks["txt_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false));
             blocks["txt_mlp.0"] = std::shared_ptr(new Linear(hidden_size, mlp_hidden_dim));
@@ -317,7 +313,7 @@ namespace Flux {
             auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
             auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]

-            auto attn = Rope::attention(ctx, q, k, v, pe, mask, flash_attn);  // [N, n_txt_token + n_img_token, n_head*d_head]
+            auto attn = Rope::attention(ctx, q, k, v, pe, mask);  // [N, n_txt_token + n_img_token, n_head*d_head]
             attn      = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3));  // [n_txt_token + n_img_token, N, hidden_size]
             auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
                                              attn,
@@ -365,7 +361,6 @@ namespace Flux {
         int64_t num_heads;
         int64_t hidden_size;
         int64_t mlp_hidden_dim;
-        bool flash_attn;
         bool prune_mod;
         int idx = 0;

@@ -375,9 +370,8 @@ namespace Flux {
                          float mlp_ratio = 4.0f,
                          int idx         = 0,
                          float qk_scale  = 0.f,
-                          bool flash_attn = false,
                          bool prune_mod  = false)
-            : hidden_size(hidden_size), num_heads(num_heads), idx(idx), flash_attn(flash_attn), prune_mod(prune_mod) {
+            : hidden_size(hidden_size), num_heads(num_heads), idx(idx), prune_mod(prune_mod) {
             int64_t head_dim = hidden_size / num_heads;
             float scale      = qk_scale;
             if (scale <= 0.f) {
@@ -451,7 +445,7 @@ namespace Flux {
             auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]);  // [N, n_token, n_head, d_head]
             q      = norm->query_norm(ctx, q);
             k      = norm->key_norm(ctx, k);
-            auto attn     = Rope::attention(ctx, q, k, v, pe, mask, flash_attn);  // [N, n_token, hidden_size]
+            auto attn     = Rope::attention(ctx, q, k, v, pe, mask);  // [N, n_token, hidden_size]
             auto attn_mlp = ggml_concat(ctx->ggml_ctx, attn, ggml_gelu_inplace(ctx->ggml_ctx, mlp), 0);  // [N, n_token, hidden_size + mlp_hidden_dim]
             auto output   = linear2->forward(ctx, attn_mlp);  // [N, n_token, hidden_size]
@@ -689,7 +683,6 @@ namespace Flux {
         int theta                  = 10000;
         bool qkv_bias              = true;
         bool guidance_embed        = true;
-        bool flash_attn            = true;
         int64_t in_dim             = 64;
         ChromaRadianceParams chroma_radiance_params;
     };
@@ -728,7 +721,6 @@ namespace Flux {
                                                                params.mlp_ratio,
                                                                i,
                                                                params.qkv_bias,
-                                                               params.flash_attn,
                                                                params.is_chroma);
             }

@@ -738,7 +730,6 @@ namespace Flux {
                                                                params.mlp_ratio,
                                                                i,
                                                                0.f,
-                                                               params.flash_attn,
                                                                params.is_chroma);
             }

@@ -1127,11 +1118,9 @@ namespace Flux {
                    const String2GGMLType& tensor_types = {},
                    const std::string prefix            = "",
                    SDVersion version                   = VERSION_FLUX,
-                   bool flash_attn                     = false,
                    bool use_mask                       = false)
            : GGMLRunner(backend, offload_params_to_cpu), version(version), use_mask(use_mask) {
             flux_params.version        = version;
-            flux_params.flash_attn     = flash_attn;
             flux_params.guidance_embed = false;
             flux_params.depth          = 0;
             flux_params.depth_single_blocks = 0;
@@ -1427,8 +1416,7 @@ namespace Flux {
                                              tensor_types,
                                              "model.diffusion_model",
                                              VERSION_CHROMA_RADIANCE,
-                                             false,
-                                             true);
+                                             false);
         flux->alloc_params_buffer();
         std::map tensors;
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 13c5e2be0..20ed824b9 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -1157,8 +1157,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
                                                              struct ggml_tensor* mask = nullptr,
                                                              bool diag_mask_inf       = false,
                                                              bool skip_reshape        = false,
-                                                             bool flash_attn          = false,  // avoid overflow
-                                                             float kv_scale           = 1.0f) {
+                                                             bool flash_attn          = false,
+                                                             float kv_scale           = 1.0f ) {  // avoid overflow
     int64_t L_q;
     int64_t L_k;
     int64_t C;
@@ -1465,6 +1465,7 @@ typedef std::map String2GGMLType;
 struct GGMLRunnerContext {
     ggml_backend_t backend = nullptr;
     ggml_context* ggml_ctx = nullptr;
+    bool flash_attn_enabled = false;
 };

 struct GGMLRunner {
@@ -1493,6 +1494,8 @@ struct GGMLRunner {
     std::map cache_tensor_map;  // name -> tensor
     const std::string final_result_name = "ggml_runner_final_result_tensor";

+    bool flash_attn_enabled = false;
+
     void alloc_params_ctx() {
         struct ggml_init_params params;
         params.mem_size = static_cast(MAX_PARAMS_TENSOR_NUM * ggml_tensor_overhead());
@@ -1753,6 +1756,7 @@ struct GGMLRunner {
         GGMLRunnerContext runner_ctx;
         runner_ctx.ggml_ctx = compute_ctx;
         runner_ctx.backend  = runtime_backend;
+        runner_ctx.flash_attn_enabled = flash_attn_enabled;
         return runner_ctx;
     }

@@ -1876,6 +1880,10 @@ struct GGMLRunner {
             free_compute_buffer();
         }
     }
+
+    void set_flash_attention_enabled(bool enabled) {
+        flash_attn_enabled = enabled;
+    }
 };

 class GGMLBlock {
diff --git a/mmdit.hpp b/mmdit.hpp
index 2165c1483..18e921d49 100644
--- a/mmdit.hpp
+++ b/mmdit.hpp
@@ -149,16 +149,14 @@ class SelfAttention : public GGMLBlock {
     int64_t num_heads;
     bool pre_only;
     std::string qk_norm;
-    bool flash_attn;

 public:
     SelfAttention(int64_t dim,
                   int64_t num_heads   = 8,
                   std::string qk_norm = "",
                   bool qkv_bias       = false,
-                  bool pre_only       = false,
-                  bool flash_attn     = false)
-        : num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm), flash_attn(flash_attn) {
+                  bool pre_only       = false)
+        : num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm) {
         int64_t d_head = dim / num_heads;
         blocks["qkv"]  = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias));
         if (!pre_only) {
@@ -209,7 +207,7 @@
     struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                 struct ggml_tensor* x) {
         auto qkv = pre_attention(ctx, x);
-        x        = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, true);  // [N, n_token, dim]
+        x        = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled);  // [N, n_token, dim]
         x        = post_attention(ctx, x);  // [N, n_token, dim]
         return x;
     }
@@ -235,7 +233,6 @@ struct DismantledBlock : public GGMLBlock {
     int64_t num_heads;
     bool pre_only;
     bool self_attn;
-    bool flash_attn;

 public:
     DismantledBlock(int64_t hidden_size,
@@ -244,17 +241,16 @@ struct DismantledBlock : public GGMLBlock {
                     std::string qk_norm = "",
                     bool qkv_bias       = false,
                     bool pre_only       = false,
-                    bool self_attn      = false,
-                    bool flash_attn     = false)
+                    bool self_attn      = false)
         : num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) {
         // rmsnorm is always Flase
         // scale_mod_only is always Flase
         // swiglu is always Flase
         blocks["norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false));
-        blocks["attn"]  = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only, flash_attn));
+        blocks["attn"]  = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only));
         if (self_attn) {
-            blocks["attn2"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false, flash_attn));
+            blocks["attn2"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false));
         }

         if (!pre_only) {
@@ -439,8 +435,8 @@ struct DismantledBlock : public GGMLBlock {
         auto qkv2          = std::get<1>(qkv_intermediates);
         auto intermediates = std::get<2>(qkv_intermediates);

-        auto attn_out  = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, flash_attn);      // [N, n_token, dim]
-        auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, false, flash_attn);  // [N, n_token, dim]
+        auto attn_out  = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled);      // [N, n_token, dim]
+        auto attn2_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv2[0], qkv2[1], qkv2[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled);  // [N, n_token, dim]
         x = post_attention_x(ctx,
                              attn_out,
                              attn2_out,
@@ -456,7 +452,7 @@ struct DismantledBlock : public GGMLBlock {
         auto qkv           = qkv_intermediates.first;
         auto intermediates = qkv_intermediates.second;

-        auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, flash_attn);  // [N, n_token, dim]
+        auto attn_out = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled);  // [N, n_token, dim]
         x = post_attention(ctx,
                            attn_out,
                            intermediates[0],
@@ -471,7 +467,6 @@

 __STATIC_INLINE__ std::pair
 block_mixing(GGMLRunnerContext* ctx,
-             bool flash_attn,
              struct ggml_tensor* context,
              struct ggml_tensor* x,
              struct ggml_tensor* c,
@@ -501,7 +496,7 @@ block_mixing(GGMLRunnerContext* ctx,
         qkv.push_back(ggml_concat(ctx->ggml_ctx, context_qkv[i], x_qkv[i], 1));
     }

-    auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, flash_attn);  // [N, n_context + n_token, hidden_size]
+    auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled);  // [N, n_context + n_token, hidden_size]
     attn      = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3));  // [n_context + n_token, N, hidden_size]
     auto context_attn = ggml_view_3d(ctx->ggml_ctx,
                                      attn,
@@ -535,7 +530,7 @@ block_mixing(GGMLRunnerContext* ctx,
     }

     if (x_block->self_attn) {
-        auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads);  // [N, n_token, hidden_size]
+        auto attn2 = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled);  // [N, n_token, hidden_size]

         x = x_block->post_attention_x(ctx,
                                       x_attn,
@@ -560,8 +555,6 @@ block_mixing(GGMLRunnerContext* ctx,
 }

 struct JointBlock : public GGMLBlock {
-    bool flash_attn;
-
 public:
     JointBlock(int64_t hidden_size,
                int64_t num_heads,
@@ -569,11 +562,9 @@ struct JointBlock : public GGMLBlock {
                std::string qk_norm = "",
                bool qkv_bias       = false,
                bool pre_only       = false,
-               bool self_attn_x    = false,
-               bool flash_attn     = false)
-        : flash_attn(flash_attn) {
-        blocks["context_block"] = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only, false, flash_attn));
-        blocks["x_block"]       = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x, flash_attn));
+               bool self_attn_x    = false) {
+        blocks["context_block"] = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only, false));
+        blocks["x_block"]       = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x));
     }

     std::pair forward(GGMLRunnerContext* ctx,
@@ -583,7 +574,7 @@ struct JointBlock : public GGMLBlock {
         auto context_block = std::dynamic_pointer_cast(blocks["context_block"]);
         auto x_block       = std::dynamic_pointer_cast(blocks["x_block"]);

-        return block_mixing(ctx, flash_attn, context, x, c, context_block, x_block);
+        return block_mixing(ctx, context, x, c, context_block, x_block);
     }
 };
@@ -641,7 +632,6 @@ struct MMDiT : public GGMLBlock {
     int64_t context_embedder_out_dim = 1536;
     int64_t hidden_size;
     std::string qk_norm;
-    bool flash_attn = false;

     void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") override {
         enum ggml_type wtype = GGML_TYPE_F32;
@@ -649,8 +639,7 @@ struct MMDiT : public GGMLBlock {
     }

 public:
-    MMDiT(bool flash_attn = false, const String2GGMLType& tensor_types = {})
-        : flash_attn(flash_attn) {
+    MMDiT(const String2GGMLType& tensor_types = {}) {
         // input_size is always None
         // learn_sigma is always False
         // register_length is alwalys 0
@@ -718,8 +707,7 @@ struct MMDiT : public GGMLBlock {
                                                                          qk_norm,
                                                                          true,
                                                                          i == depth - 1,
-                                                                         i <= d_self,
-                                                                         flash_attn));
+                                                                         i <= d_self));
         }

         blocks["final_layer"] = std::shared_ptr(new FinalLayer(hidden_size, patch_size, out_channels));
@@ -864,10 +852,9 @@ struct MMDiTRunner : public GGMLRunner {

     MMDiTRunner(ggml_backend_t backend,
                 bool offload_params_to_cpu,
-                bool flash_attn,
                 const String2GGMLType& tensor_types = {},
                 const std::string prefix            = "")
-        : GGMLRunner(backend, offload_params_to_cpu), mmdit(flash_attn, tensor_types) {
+        : GGMLRunner(backend, offload_params_to_cpu), mmdit(tensor_types) {
         mmdit.init(params_ctx, tensor_types, prefix);
     }

@@ -966,7 +953,7 @@ struct MMDiTRunner : public GGMLRunner {
     // ggml_backend_t backend = ggml_backend_cuda_init(0);
     ggml_backend_t backend         = ggml_backend_cpu_init();
     ggml_type model_data_type      = GGML_TYPE_F16;
-    std::shared_ptr mmdit = std::make_shared(backend, false, false);
+    std::shared_ptr mmdit = std::make_shared(backend, false);
     {
         LOG_INFO("loading from '%s'", file_path.c_str());
diff --git a/qwen_image.hpp b/qwen_image.hpp
index e442cfda4..ea5b3d310 100644
--- a/qwen_image.hpp
+++ b/qwen_image.hpp
@@ -65,7 +65,6 @@ namespace Qwen {
     struct QwenImageAttention : public GGMLBlock {
     protected:
         int64_t dim_head;
-        bool flash_attn;

     public:
         QwenImageAttention(int64_t query_dim,
@@ -75,9 +74,8 @@ namespace Qwen {
                            int64_t out_context_dim = 0,
                            bool bias               = true,
                            bool out_bias           = true,
-                           float eps               = 1e-6,
-                           bool flash_attn         = false)
-            : dim_head(dim_head), flash_attn(flash_attn) {
+                           float eps               = 1e-6)
+            : dim_head(dim_head) {
             int64_t inner_dim = out_dim > 0 ? out_dim : dim_head * num_heads;
             out_dim           = out_dim > 0 ? out_dim : query_dim;
             out_context_dim   = out_context_dim > 0 ? out_context_dim : query_dim;
@@ -160,7 +158,7 @@ namespace Qwen {
             auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
             auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]

-            auto attn = Rope::attention(ctx, q, k, v, pe, mask, flash_attn, (1.0f / 128.f));  // [N, n_txt_token + n_img_token, n_head*d_head]
+            auto attn = Rope::attention(ctx, q, k, v, pe, mask, (1.0f / 128.f));  // [N, n_txt_token + n_img_token, n_head*d_head]
             attn      = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3));  // [n_txt_token + n_img_token, N, hidden_size]
             auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
                                              attn,
@@ -193,8 +191,7 @@ namespace Qwen {
         QwenImageTransformerBlock(int64_t dim,
                                   int64_t num_attention_heads,
                                   int64_t attention_head_dim,
-                                  float eps       = 1e-6,
-                                  bool flash_attn = false) {
+                                  float eps       = 1e-6) {
             // img_mod.0 is nn.SiLU()
             blocks["img_mod.1"] = std::shared_ptr(new Linear(dim, 6 * dim, true));
@@ -216,8 +213,7 @@ namespace Qwen {
                                                                           0,      // out_context-dim
                                                                           true,   // bias
                                                                           true,   // out_bias
-                                                                          eps,
-                                                                          flash_attn));
+                                                                          eps));
         }

         virtual std::pair forward(GGMLRunnerContext* ctx,
@@ -325,7 +321,6 @@ namespace Qwen {
         float theta          = 10000;
         std::vector axes_dim = {16, 56, 56};
         int64_t axes_dim_sum = 128;
-        bool flash_attn      = false;
     };

     class QwenImageModel : public GGMLBlock {
@@ -347,8 +342,7 @@ namespace Qwen {
             auto block = std::shared_ptr(new QwenImageTransformerBlock(inner_dim,
                                                                        params.num_attention_heads,
                                                                        params.attention_head_dim,
-                                                                       1e-6f,
-                                                                       params.flash_attn));
+                                                                       1e-6f));
             blocks["transformer_blocks." + std::to_string(i)] = block;
         }

@@ -510,10 +504,8 @@ namespace Qwen {
                        bool offload_params_to_cpu,
                        const String2GGMLType& tensor_types = {},
                        const std::string prefix            = "",
-                       SDVersion version                   = VERSION_QWEN_IMAGE,
-                       bool flash_attn                     = false)
+                       SDVersion version                   = VERSION_QWEN_IMAGE)
             : GGMLRunner(backend, offload_params_to_cpu) {
-            qwen_image_params.flash_attn = flash_attn;
             qwen_image_params.num_layers = 0;
             for (auto pair : tensor_types) {
                 std::string tensor_name = pair.first;
@@ -669,8 +661,7 @@ namespace Qwen {
                                                        false,
                                                        tensor_types,
                                                        "model.diffusion_model",
-                                                       VERSION_QWEN_IMAGE,
-                                                       true);
+                                                       VERSION_QWEN_IMAGE);
         qwen_image->alloc_params_buffer();
         std::map tensors;
diff --git a/qwenvl.hpp b/qwenvl.hpp
index d430742f6..8918978d8 100644
--- a/qwenvl.hpp
+++ b/qwenvl.hpp
@@ -525,7 +525,7 @@ namespace Qwen {
             auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);  // [N, n_token, n_head, d_head]
             auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]);  // [N, n_token, n_head, d_head]

-            x = Rope::attention(ctx, q, k, v, pe, mask, false, 1.f, false);  // [N, n_token, hidden_size]
+            x = Rope::attention(ctx, q, k, v, pe, mask, 1.f, false);  // [N, n_token, hidden_size]

             x = proj->forward(ctx, x);  // [N, n_token, hidden_size]
             return x;
diff --git a/rope.hpp b/rope.hpp
index 05f04d0d5..bd1dfad5d 100644
--- a/rope.hpp
+++ b/rope.hpp
@@ -392,7 +392,6 @@ namespace Rope {
                                        struct ggml_tensor* v,
                                        struct ggml_tensor* pe,
                                        struct ggml_tensor* mask,
-                                       bool flash_attn,
                                        float kv_scale        = 1.0f,
                                        bool rope_interleaved = true) {
         // q,k,v: [N, L, n_head, d_head]
@@ -401,7 +400,7 @@ namespace Rope {
         q = apply_rope(ctx->ggml_ctx, q, pe, rope_interleaved);  // [N*n_head, L, d_head]
         k = apply_rope(ctx->ggml_ctx, k, pe, rope_interleaved);  // [N*n_head, L, d_head]

-        auto x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, v->ne[1], mask, false, true, flash_attn, kv_scale);  // [N, L, n_head*d_head]
+        auto x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, v->ne[1], mask, false, true, ctx->flash_attn_enabled, kv_scale);  // [N, L, n_head*d_head]
         return x;
     }
 };  // namespace Rope
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 9552cfc4d..7c7434522 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -341,16 +341,12 @@ class StableDiffusionGGML {
             LOG_INFO("CLIP: Using CPU backend");
             clip_backend = ggml_backend_cpu_init();
         }
-        if (sd_ctx_params->diffusion_flash_attn) {
-            LOG_INFO("Using flash attention in the diffusion model");
-        }
         if (sd_version_is_sd3(version)) {
             cond_stage_model = std::make_shared(clip_backend,
                                                 offload_params_to_cpu,
                                                 model_loader.tensor_storages_types);
             diffusion_model  = std::make_shared(backend,
                                                 offload_params_to_cpu,
-                                                sd_ctx_params->diffusion_flash_attn,
                                                 model_loader.tensor_storages_types);
         } else if (sd_version_is_flux(version)) {
             bool is_chroma = false;
@@ -384,7 +380,6 @@ class StableDiffusionGGML {
                                                offload_params_to_cpu,
                                                model_loader.tensor_storages_types,
                                                version,
-                                               sd_ctx_params->diffusion_flash_attn,
                                                sd_ctx_params->chroma_use_dit_mask);
         } else if (sd_version_is_wan(version)) {
             cond_stage_model = std::make_shared(clip_backend,
@@ -397,15 +392,13 @@ class StableDiffusionGGML {
                                                offload_params_to_cpu,
                                                model_loader.tensor_storages_types,
                                                "model.diffusion_model",
-                                               version,
-                                               sd_ctx_params->diffusion_flash_attn);
+                                               version);
             if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) {
                 high_noise_diffusion_model = std::make_shared(backend,
                                                               offload_params_to_cpu,
                                                               model_loader.tensor_storages_types,
                                                               "model.high_noise_diffusion_model",
-                                                              version,
-                                                              sd_ctx_params->diffusion_flash_attn);
+                                                              version);
             }
             if (diffusion_model->get_desc() == "Wan2.1-I2V-14B" || diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
                 clip_vision = std::make_shared(backend,
@@ -428,8 +421,7 @@ class StableDiffusionGGML {
                                                offload_params_to_cpu,
                                                model_loader.tensor_storages_types,
                                                "model.diffusion_model",
-                                               version,
-                                               sd_ctx_params->diffusion_flash_attn);
+                                               version);
         } else {  // SD1.x SD2.x SDXL
             if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
                 cond_stage_model = std::make_shared(clip_backend,
@@ -448,14 +440,18 @@ class StableDiffusionGGML {
             diffusion_model = std::make_shared(backend,
                                                offload_params_to_cpu,
                                                model_loader.tensor_storages_types,
-                                               version,
-                                               sd_ctx_params->diffusion_flash_attn);
+                                               version);
             if (sd_ctx_params->diffusion_conv_direct) {
                 LOG_INFO("Using Conv2d direct in the diffusion model");
                 std::dynamic_pointer_cast(diffusion_model)->unet.enable_conv2d_direct();
             }
         }

+        if (sd_ctx_params->diffusion_flash_attn) {
+            LOG_INFO("Using flash attention in the diffusion model");
+            diffusion_model->set_flash_attn_enabled(true);
+        }
+
         cond_stage_model->alloc_params_buffer();
         cond_stage_model->get_param_tensors(tensors);
diff --git a/unet.hpp b/unet.hpp
index 80baa4b45..5153e5ed5 100644
--- a/unet.hpp
+++ b/unet.hpp
@@ -183,7 +183,7 @@ class UnetModelBlock : public GGMLBlock {
     int model_channels  = 320;
     int adm_in_channels = 2816;  // only for VERSION_SDXL/SVD

-    UnetModelBlock(SDVersion version = VERSION_SD1, const String2GGMLType& tensor_types = {}, bool flash_attn = false)
+    UnetModelBlock(SDVersion version = VERSION_SD1, const String2GGMLType& tensor_types = {})
         : version(version) {
         if (sd_version_is_sd2(version)) {
             context_dim = 1024;
@@ -251,7 +251,7 @@ class UnetModelBlock : public GGMLBlock {
             if (version == VERSION_SVD) {
                 return new SpatialVideoTransformer(in_channels, n_head, d_head, depth, context_dim);
             } else {
-                return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, flash_attn);
+                return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim);
             }
         };
@@ -583,9 +583,8 @@ struct UNetModelRunner : public GGMLRunner {
                     bool offload_params_to_cpu,
                     const String2GGMLType& tensor_types,
                     const std::string prefix,
-                    SDVersion version = VERSION_SD1,
-                    bool flash_attn   = false)
-        : GGMLRunner(backend, offload_params_to_cpu), unet(version, tensor_types, flash_attn) {
+                    SDVersion version = VERSION_SD1)
+        : GGMLRunner(backend, offload_params_to_cpu), unet(version, tensor_types) {
         unet.init(params_ctx, tensor_types, prefix);
     }
diff --git a/wan.hpp b/wan.hpp
index 2815ee62b..b29d73a40 100644
--- a/wan.hpp
+++ b/wan.hpp
@@ -1287,15 +1287,13 @@ namespace WAN {
     public:
         int64_t num_heads;
         int64_t head_dim;
-        bool flash_attn;

     public:
         WanSelfAttention(int64_t dim,
                          int64_t num_heads,
                          bool qk_norm = true,
-                         float eps    = 1e-6,
-                         bool flash_attn = false)
-            : num_heads(num_heads), flash_attn(flash_attn) {
+                         float eps    = 1e-6)
+            : num_heads(num_heads) {
             head_dim    = dim / num_heads;
             blocks["q"] = std::shared_ptr(new Linear(dim, dim));
             blocks["k"] = std::shared_ptr(new Linear(dim, dim));
@@ -1338,7 +1336,7 @@ namespace WAN {
             k = ggml_reshape_4d(ctx->ggml_ctx, k, head_dim, num_heads, n_token, N);  // [N, n_token, n_head, d_head]
             v = ggml_reshape_4d(ctx->ggml_ctx, v, head_dim, num_heads, n_token, N);  // [N, n_token, n_head, d_head]

-            x = Rope::attention(ctx, q, k, v, pe, mask, flash_attn);  // [N, n_token, dim]
+            x = Rope::attention(ctx, q, k, v, pe, mask);  // [N, n_token, dim]

             x = o_proj->forward(ctx, x);  // [N, n_token, dim]
             return x;
@@ -1350,9 +1348,8 @@ namespace WAN {
         WanCrossAttention(int64_t dim,
                           int64_t num_heads,
                           bool qk_norm = true,
-                          float eps    = 1e-6,
-                          bool flash_attn = false)
-            : WanSelfAttention(dim, num_heads, qk_norm, eps, flash_attn) {}
+                          float eps    = 1e-6)
+            : WanSelfAttention(dim, num_heads, qk_norm, eps) {}

         virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                             struct ggml_tensor* x,
                                             struct ggml_tensor* context,
@@ -1364,9 +1361,8 @@ namespace WAN {
         WanT2VCrossAttention(int64_t dim,
                              int64_t num_heads,
                              bool qk_norm = true,
-                             float eps    = 1e-6,
-                             bool flash_attn = false)
-            : WanCrossAttention(dim, num_heads, qk_norm, eps, flash_attn) {}
+                             float eps    = 1e-6)
+            : WanCrossAttention(dim, num_heads, qk_norm, eps) {}

         struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                     struct ggml_tensor* x,
                                     struct ggml_tensor* context,
@@ -1391,7 +1387,7 @@ namespace WAN {
             k      = norm_k->forward(ctx, k);
             auto v = v_proj->forward(ctx, context);  // [N, n_context, dim]

-            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, flash_attn);  // [N, n_token, dim]
+            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled);  // [N, n_token, dim]

             x = o_proj->forward(ctx, x);  // [N, n_token, dim]
             return x;
@@ -1403,9 +1399,8 @@ namespace WAN {
         WanI2VCrossAttention(int64_t dim,
                              int64_t num_heads,
                              bool qk_norm = true,
-                             float eps    = 1e-6,
-                             bool flash_attn = false)
-            : WanCrossAttention(dim, num_heads, qk_norm, eps, flash_attn) {
+                             float eps    = 1e-6)
+            : WanCrossAttention(dim, num_heads, qk_norm, eps) {
             blocks["k_img"] = std::shared_ptr(new Linear(dim, dim));
             blocks["v_img"] = std::shared_ptr(new Linear(dim, dim));
@@ -1457,8 +1452,8 @@ namespace WAN {
             k_img      = norm_k_img->forward(ctx, k_img);
             auto v_img = v_img_proj->forward(ctx, context_img);  // [N, context_img_len, dim]

-            auto img_x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k_img, v_img, num_heads, nullptr, false, false, flash_attn);  // [N, n_token, dim]
-            x          = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, flash_attn);          // [N, n_token, dim]
+            auto img_x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k_img, v_img, num_heads, nullptr, false, false, ctx->flash_attn_enabled);  // [N, n_token, dim]
+            x          = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, false, false, ctx->flash_attn_enabled);          // [N, n_token, dim]

             x = ggml_add(ctx->ggml_ctx, x, img_x);
@@ -1511,20 +1506,19 @@ namespace WAN {
                           int64_t num_heads,
                           bool qk_norm         = true,
                           bool cross_attn_norm = false,
-                          float eps            = 1e-6,
-                          bool flash_attn      = false)
+                          float eps            = 1e-6)
             : dim(dim) {
             blocks["norm1"]     = std::shared_ptr(new LayerNorm(dim, eps, false));
-            blocks["self_attn"] = std::shared_ptr(new WanSelfAttention(dim, num_heads, qk_norm, eps, flash_attn));
+            blocks["self_attn"] = std::shared_ptr(new WanSelfAttention(dim, num_heads, qk_norm, eps));
             if (cross_attn_norm) {
                 blocks["norm3"] = std::shared_ptr(new LayerNorm(dim, eps, true));
             } else {
                 blocks["norm3"] = std::shared_ptr(new Identity());
             }
             if (t2v_cross_attn) {
-                blocks["cross_attn"] = std::shared_ptr(new WanT2VCrossAttention(dim, num_heads, qk_norm, eps, flash_attn));
+                blocks["cross_attn"] = std::shared_ptr(new WanT2VCrossAttention(dim, num_heads, qk_norm, eps));
             } else {
-                blocks["cross_attn"] = std::shared_ptr(new WanI2VCrossAttention(dim, num_heads, qk_norm, eps, flash_attn));
+                blocks["cross_attn"] = std::shared_ptr(new WanI2VCrossAttention(dim, num_heads, qk_norm, eps));
             }

             blocks["norm2"] = std::shared_ptr(new LayerNorm(dim, eps, false));
@@ -1601,9 +1595,8 @@ namespace WAN {
                                bool qk_norm         = true,
                                bool cross_attn_norm = false,
                                float eps            = 1e-6,
-                               int block_id         = 0,
-                               bool flash_attn      = false)
-            : WanAttentionBlock(t2v_cross_attn, dim, ffn_dim, num_heads, qk_norm, cross_attn_norm, eps, flash_attn), block_id(block_id) {
+                               int block_id         = 0)
+            : WanAttentionBlock(t2v_cross_attn, dim, ffn_dim, num_heads, qk_norm, cross_attn_norm, eps), block_id(block_id) {
             if (block_id == 0) {
                 blocks["before_proj"] = std::shared_ptr(new Linear(dim, dim));
             }
@@ -1755,7 +1748,6 @@ namespace WAN {
         // wan2.1 1.3B: 1536/12, wan2.1/2.2 14B: 5120/40, wan2.2 5B: 3074/24
         std::vector axes_dim = {44, 42, 42};
         int64_t axes_dim_sum = 128;
-        bool flash_attn      = false;
     };

     class Wan : public GGMLBlock {
@@ -1790,8 +1782,7 @@ namespace WAN {
                                                                  params.num_heads,
                                                                  params.qk_norm,
                                                                  params.cross_attn_norm,
-                                                                 params.eps,
-                                                                 params.flash_attn));
+                                                                 params.eps));
                 blocks["blocks." + std::to_string(i)] = block;
             }

@@ -1813,8 +1804,7 @@ namespace WAN {
                                                                      params.qk_norm,
                                                                      params.cross_attn_norm,
                                                                      params.eps,
-                                                                     i,
-                                                                     params.flash_attn));
+                                                                     i));
                 blocks["vace_blocks." + std::to_string(i)] = block;
             }

@@ -2027,10 +2017,8 @@ namespace WAN {
                   bool offload_params_to_cpu,
                   const String2GGMLType& tensor_types = {},
                   const std::string prefix            = "",
-                  SDVersion version                   = VERSION_WAN2,
-                  bool flash_attn                     = false)
+                  SDVersion version                   = VERSION_WAN2)
             : GGMLRunner(backend, offload_params_to_cpu) {
-            wan_params.flash_attn = flash_attn;
             wan_params.num_layers = 0;
             for (auto pair : tensor_types) {
                 std::string tensor_name = pair.first;
@@ -2278,8 +2266,7 @@ namespace WAN {
                                            false,
                                            tensor_types,
                                            "model.diffusion_model",
-                                           VERSION_WAN2_2_TI2V,
-                                           true);
+                                           VERSION_WAN2_2_TI2V);
         wan->alloc_params_buffer();
         std::map tensors;

From 9e281b0c45487e55aecec46b9bf2bfd7d1f81d6d Mon Sep 17 00:00:00 2001
From: leejet
Date: Sat, 1 Nov 2025 00:40:52 +0800
Subject: [PATCH 3/4] add conv2d_direct enable control through
 GGMLRunnerContext

---
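Note: same pattern as the flash-attention switch in the previous patch. The
per-runner block walkers (enable_conv2d_direct() on ControlNet, ESRGAN,
TinyAutoEncoder, UNetModelRunner and the VAEs) are deleted; callers now set
one flag on the runner, and Conv2d::forward() reads it from the
GGMLRunnerContext. A minimal sketch (illustrative only; `vae` stands for any
runner instance such as an AutoEncoderKL, and the leading arguments of
ggml_ext_conv_2d are abridged — only the tail appears in this diff):

    // Old: vae->enable_conv2d_direct();   (walked every Conv2d block)
    // New: one runner-level setter, inherited from GGMLRunner:
    vae->set_conv2d_direct_enabled(true);

    // Conv2d::forward() then passes the context flag straight through:
    //   ggml_ext_conv_2d(..., padding.first,
    //                    dilation.second, dilation.first,
    //                    ctx->conv2d_direct_enabled, scale);
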
 control.hpp          | 11 -----------
 esrgan.hpp           | 13 -------------
 ggml_extend.hpp      | 14 ++++++++------
 stable-diffusion.cpp |  8 ++++----
 tae.hpp              | 11 -----------
 unet.hpp             | 12 ------------
 upscaler.cpp         |  2 +-
 vae.hpp              | 12 ------------
 8 files changed, 13 insertions(+), 70 deletions(-)

diff --git a/control.hpp b/control.hpp
index 27eee7b27..72886dd02 100644
--- a/control.hpp
+++ b/control.hpp
@@ -324,17 +324,6 @@ struct ControlNet : public GGMLRunner {
         control_net.init(params_ctx, tensor_types, "");
     }

-    void enable_conv2d_direct() {
-        std::vector blocks;
-        control_net.get_all_blocks(blocks);
-        for (auto block : blocks) {
-            if (block->get_desc() == "Conv2d") {
-                auto conv_block = (Conv2d*)block;
-                conv_block->enable_direct();
-            }
-        }
-    }
-
     ~ControlNet() override {
         free_control_ctx();
     }
diff --git a/esrgan.hpp b/esrgan.hpp
index 10fbc0643..5a24436cb 100644
--- a/esrgan.hpp
+++ b/esrgan.hpp
@@ -161,19 +161,6 @@ struct ESRGAN : public GGMLRunner {
         // rrdb_net will be created in load_from_file
     }

-    void enable_conv2d_direct() {
-        if (!rrdb_net)
-            return;
-        std::vector blocks;
-        rrdb_net->get_all_blocks(blocks);
-        for (auto block : blocks) {
-            if (block->get_desc() == "Conv2d") {
-                auto conv_block = (Conv2d*)block;
-                conv_block->enable_direct();
-            }
-        }
-    }
-
     std::string get_desc() override {
         return "esrgan";
     }
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 20ed824b9..61e15c337 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -1466,6 +1466,7 @@ struct GGMLRunnerContext {
     ggml_backend_t backend = nullptr;
     ggml_context* ggml_ctx = nullptr;
     bool flash_attn_enabled = false;
+    bool conv2d_direct_enabled = false;
 };

 struct GGMLRunner {
@@ -1495,6 +1496,7 @@ struct GGMLRunner {
     const std::string final_result_name = "ggml_runner_final_result_tensor";

     bool flash_attn_enabled = false;
+    bool conv2d_direct_enabled = false;

     void alloc_params_ctx() {
         struct ggml_init_params params;
@@ -1757,6 +1759,7 @@ struct GGMLRunner {
         runner_ctx.ggml_ctx = compute_ctx;
         runner_ctx.backend  = runtime_backend;
         runner_ctx.flash_attn_enabled = flash_attn_enabled;
+        runner_ctx.conv2d_direct_enabled = conv2d_direct_enabled;
         return runner_ctx;
     }

@@ -1884,6 +1887,10 @@ struct GGMLRunner {
     void set_flash_attention_enabled(bool enabled) {
         flash_attn_enabled = enabled;
     }
+
+    void set_conv2d_direct_enabled(bool enabled) {
+        conv2d_direct_enabled = enabled;
+    }
 };

 class GGMLBlock {
@@ -2084,7 +2091,6 @@ class Conv2d : public UnaryBlock {
     std::pair padding;
     std::pair dilation;
     bool bias;
-    bool direct = false;
     float scale = 1.f;

     void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") override {
@@ -2112,10 +2118,6 @@ class Conv2d : public UnaryBlock {
           dilation(dilation),
           bias(bias) {}

-    void enable_direct() {
-        direct = true;
-    }
-
     void set_scale(float scale_value) {
         scale = scale_value;
     }
@@ -2140,7 +2142,7 @@ class Conv2d : public UnaryBlock {
                                 padding.first,
                                 dilation.second,
                                 dilation.first,
-                                direct,
+                                ctx->conv2d_direct_enabled,
                                 scale);
     }
 };
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 7c7434522..056825c90 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -443,7 +443,7 @@ class StableDiffusionGGML {
                                                version);
             if (sd_ctx_params->diffusion_conv_direct) {
                 LOG_INFO("Using Conv2d direct in the diffusion model");
-                std::dynamic_pointer_cast(diffusion_model)->unet.enable_conv2d_direct();
+                std::dynamic_pointer_cast(diffusion_model)->unet.set_conv2d_direct_enabled(true);
             }
         }
@@ -496,7 +496,7 @@ class StableDiffusionGGML {
                                                version);
             if (sd_ctx_params->vae_conv_direct) {
                 LOG_INFO("Using Conv2d direct in the vae model");
-                first_stage_model->enable_conv2d_direct();
+                first_stage_model->set_conv2d_direct_enabled(true);
             }

             if (version == VERSION_SDXL && (strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) {
@@ -518,7 +518,7 @@ class StableDiffusionGGML {
                                                version);
             if (sd_ctx_params->vae_conv_direct) {
                 LOG_INFO("Using Conv2d direct in the tae model");
-                tae_first_stage->enable_conv2d_direct();
+                tae_first_stage->set_conv2d_direct_enabled(true);
             }
         }
         // first_stage_model->get_param_tensors(tensors, "first_stage_model.");
@@ -537,7 +537,7 @@ class StableDiffusionGGML {
                                                version);
             if (sd_ctx_params->diffusion_conv_direct) {
                 LOG_INFO("Using Conv2d direct in the control net");
-                control_net->enable_conv2d_direct();
+                control_net->set_conv2d_direct_enabled(true);
             }
         }
diff --git a/tae.hpp b/tae.hpp
index 8069de0f2..21617b3f1 100644
--- a/tae.hpp
+++ b/tae.hpp
@@ -207,17 +207,6 @@ struct TinyAutoEncoder : public GGMLRunner {
         taesd.init(params_ctx, tensor_types, prefix);
     }

-    void enable_conv2d_direct() {
-        std::vector blocks;
-        taesd.get_all_blocks(blocks);
-        for (auto block : blocks) {
-            if (block->get_desc() == "Conv2d") {
-                auto conv_block = (Conv2d*)block;
-                conv_block->enable_direct();
-            }
-        }
-    }
-
     std::string get_desc() override {
         return "taesd";
     }
diff --git a/unet.hpp b/unet.hpp
index 5153e5ed5..91af9f7ce 100644
--- a/unet.hpp
+++ b/unet.hpp
@@ -588,18 +588,6 @@ struct UNetModelRunner : public GGMLRunner {
         unet.init(params_ctx, tensor_types, prefix);
     }

-    void enable_conv2d_direct() {
-        std::vector blocks;
-        unet.get_all_blocks(blocks);
-        for (auto block : blocks) {
-            if (block->get_desc() == "Conv2d") {
-                LOG_DEBUG("block %s", block->get_desc().c_str());
-                auto conv_block = (Conv2d*)block;
-                conv_block->enable_direct();
-            }
-        }
-    }
-
     std::string get_desc() override {
         return "unet";
     }
diff --git a/upscaler.cpp b/upscaler.cpp
index 4081150ed..a9e5f6a17 100644
--- a/upscaler.cpp
+++ b/upscaler.cpp
@@ -53,7 +53,7 @@ struct UpscalerGGML {
         LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
         esrgan_upscaler = std::make_shared(backend, offload_params_to_cpu, model_loader.tensor_storages_types);
         if (direct) {
-            esrgan_upscaler->enable_conv2d_direct();
+            esrgan_upscaler->set_conv2d_direct_enabled(true);
         }
         if (!esrgan_upscaler->load_from_file(esrgan_path, n_threads)) {
             return false;
diff --git a/vae.hpp b/vae.hpp
index cad5961de..8c82d2f89 100644
--- a/vae.hpp
+++ b/vae.hpp
@@ -529,7 +529,6 @@ struct VAE : public GGMLRunner {
                         struct ggml_tensor** output,
                         struct ggml_context* output_ctx)                                     = 0;
     virtual void get_param_tensors(std::map& tensors, const std::string prefix) = 0;
-    virtual void enable_conv2d_direct(){};
     virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
 };
@@ -572,17 +571,6 @@ struct AutoEncoderKL : public VAE {
         ae.init(params_ctx, tensor_types, prefix);
     }

-    void enable_conv2d_direct() override {
-        std::vector blocks;
-        ae.get_all_blocks(blocks);
-        for (auto block : blocks) {
-            if (block->get_desc() == "Conv2d") {
-                auto conv_block = (Conv2d*)block;
-                conv_block->enable_direct();
-            }
-        }
-    }
-
     void set_conv2d_scale(float scale) override {
         std::vector blocks;
         ae.get_all_blocks(blocks);

From aba0997e94ad54027374cb69ac19fccc8f4a8f93 Mon Sep 17 00:00:00 2001
From: leejet
Date: Sat, 1 Nov 2025 00:41:16 +0800
Subject: [PATCH 4/4] format code

---
 common.hpp          |  2 +-
 diffusion_model.hpp |  2 +-
 flux.hpp            | 12 ++++++------
 ggml_extend.hpp     | 16 ++++++++--------
 mmdit.hpp           |  4 ++--
 qwen_image.hpp      |  4 ++--
 wan.hpp             | 16 ++++++++--------
 7 files changed, 28 insertions(+), 28 deletions(-)

diff --git a/common.hpp b/common.hpp
index 45d3de396..03c931bdf 100644
--- a/common.hpp
+++ b/common.hpp
@@ -339,7 +339,7 @@ class BasicTransformerBlock : public GGMLBlock {
                           int64_t n_head,
                           int64_t d_head,
                           int64_t context_dim,
-                          bool ff_in      = false)
+                          bool ff_in = false)
         : n_head(n_head), d_head(d_head), ff_in(ff_in) {
         // disable_self_attn is always False
         // disable_temporal_crossattention is always False
diff --git a/diffusion_model.hpp b/diffusion_model.hpp
index fd63050e6..651f7a45a 100644
--- a/diffusion_model.hpp
+++ b/diffusion_model.hpp
@@ -36,7 +36,7 @@ struct DiffusionModel {
     virtual void get_param_tensors(std::map& tensors)                           = 0;
     virtual size_t get_params_buffer_size()                                     = 0;
     virtual int64_t get_adm_in_channels()                                       = 0;
-    virtual void set_flash_attn_enabled(bool enabled) = 0;
+    virtual void set_flash_attn_enabled(bool enabled)                           = 0;
 };

 struct UNetModel : public DiffusionModel {
diff --git a/flux.hpp b/flux.hpp
index eb149256f..9dd2c9f7f 100644
--- a/flux.hpp
+++ b/flux.hpp
@@ -126,9 +126,9 @@ namespace Flux {
             // x: [N, n_token, dim]
             // pe: [n_token, d_head/2, 2, 2]
             // return [N, n_token, dim]
-            auto qkv = pre_attention(ctx, x);                                    // q,k,v: [N, n_token, n_head, d_head]
+            auto qkv = pre_attention(ctx, x);                               // q,k,v: [N, n_token, n_head, d_head]
             x        = Rope::attention(ctx, qkv[0], qkv[1], qkv[2], pe, mask);  // [N, n_token, dim]
-            x        = post_attention(ctx, x);  // [N, n_token, dim]
+            x        = post_attention(ctx, x);                              // [N, n_token, dim]
             return x;
         }
     };
@@ -203,9 +203,9 @@ namespace Flux {
         DoubleStreamBlock(int64_t hidden_size,
                           int64_t num_heads,
                           float mlp_ratio,
-                          int idx          = 0,
-                          bool qkv_bias    = false,
-                          bool prune_mod   = false)
+                          int idx        = 0,
+                          bool qkv_bias  = false,
+                          bool prune_mod = false)
             : idx(idx), prune_mod(prune_mod) {
             int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
             if (!prune_mod) {
@@ -313,7 +313,7 @@ namespace Flux {
             auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
             auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]

-            auto attn = Rope::attention(ctx, q, k, v, pe, mask);  // [N, n_txt_token + n_img_token, n_head*d_head]
+            auto attn = Rope::attention(ctx, q, k, v, pe, mask);                                  // [N, n_txt_token + n_img_token, n_head*d_head]
             attn      = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3));  // [n_txt_token + n_img_token, N, hidden_size]
             auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
                                              attn,
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 61e15c337..41d59e485 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -1158,7 +1158,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context
                                                              bool diag_mask_inf       = false,
                                                              bool skip_reshape        = false,
                                                              bool flash_attn          = false,
-                                                             float kv_scale           = 1.0f ) {  // avoid overflow
+                                                             float kv_scale           = 1.0f) {  // avoid overflow
     int64_t L_q;
     int64_t L_k;
     int64_t C;
@@ -1463,9 +1463,9 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) {
 typedef std::map String2GGMLType;

 struct GGMLRunnerContext {
-    ggml_backend_t backend = nullptr;
-    ggml_context* ggml_ctx = nullptr;
-    bool flash_attn_enabled = false;
+    ggml_backend_t backend     = nullptr;
+    ggml_context* ggml_ctx     = nullptr;
+    bool flash_attn_enabled    = false;
     bool conv2d_direct_enabled = false;
 };
@@ -1495,7 +1495,7 @@ struct GGMLRunner {
     std::map cache_tensor_map;  // name -> tensor
     const std::string final_result_name = "ggml_runner_final_result_tensor";

-    bool flash_attn_enabled = false;
+    bool flash_attn_enabled    = false;
     bool conv2d_direct_enabled = false;

     void alloc_params_ctx() {
@@ -1756,9 +1756,9 @@ struct GGMLRunner {
     virtual GGMLRunnerContext get_context() {
         GGMLRunnerContext runner_ctx;
-        runner_ctx.ggml_ctx = compute_ctx;
-        runner_ctx.backend  = runtime_backend;
-        runner_ctx.flash_attn_enabled = flash_attn_enabled;
+        runner_ctx.ggml_ctx              = compute_ctx;
+        runner_ctx.backend               = runtime_backend;
+        runner_ctx.flash_attn_enabled    = flash_attn_enabled;
         runner_ctx.conv2d_direct_enabled = conv2d_direct_enabled;
         return runner_ctx;
     }
diff --git a/mmdit.hpp b/mmdit.hpp
index 18e921d49..6189783c5 100644
--- a/mmdit.hpp
+++ b/mmdit.hpp
@@ -208,7 +208,7 @@ class SelfAttention : public GGMLBlock {
                                 struct ggml_tensor* x) {
         auto qkv = pre_attention(ctx, x);
         x        = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], num_heads, nullptr, false, false, ctx->flash_attn_enabled);  // [N, n_token, dim]
-        x        = post_attention(ctx, x);  // [N, n_token, dim]
+        x        = post_attention(ctx, x);                                                                                                                  // [N, n_token, dim]
         return x;
     }
 };
@@ -497,7 +497,7 @@ block_mixing(GGMLRunnerContext* ctx,
     }

     auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled);  // [N, n_context + n_token, hidden_size]
-    attn      = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3));  // [n_context + n_token, N, hidden_size]
+    attn      = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3));                                                                          // [n_context + n_token, N, hidden_size]
     auto context_attn = ggml_view_3d(ctx->ggml_ctx,
                                      attn,
                                      attn->ne[0],
diff --git a/qwen_image.hpp b/qwen_image.hpp
index ea5b3d310..6288d9aae 100644
--- a/qwen_image.hpp
+++ b/qwen_image.hpp
@@ -158,7 +158,7 @@ namespace Qwen {
             auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
             auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]

-            auto attn = Rope::attention(ctx, q, k, v, pe, mask, (1.0f / 128.f));  // [N, n_txt_token + n_img_token, n_head*d_head]
+            auto attn = Rope::attention(ctx, q, k, v, pe, mask, (1.0f / 128.f));                  // [N, n_txt_token + n_img_token, n_head*d_head]
             attn      = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3));  // [n_txt_token + n_img_token, N, hidden_size]
             auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
                                              attn,
@@ -191,7 +191,7 @@ namespace Qwen {
         QwenImageTransformerBlock(int64_t dim,
                                   int64_t num_attention_heads,
                                   int64_t attention_head_dim,
-                                  float eps       = 1e-6) {
+                                  float eps = 1e-6) {
             // img_mod.0 is nn.SiLU()
             blocks["img_mod.1"] = std::shared_ptr(new Linear(dim, 6 * dim, true));
diff --git a/wan.hpp b/wan.hpp
index b29d73a40..db4f5aaaa 100644
--- a/wan.hpp
+++ b/wan.hpp
@@ -1291,8 +1291,8 @@ namespace WAN {
     public:
         WanSelfAttention(int64_t dim,
                          int64_t num_heads,
-                         bool qk_norm = true,
-                         float eps    = 1e-6)
+                         bool qk_norm    = true,
+                         float eps       = 1e-6)
             : num_heads(num_heads) {
             head_dim    = dim / num_heads;
             blocks["q"] = std::shared_ptr(new Linear(dim, dim));
@@ -1347,8 +1347,8 @@ namespace WAN {
     public:
         WanCrossAttention(int64_t dim,
                           int64_t num_heads,
-                          bool qk_norm = true,
-                          float eps    = 1e-6)
+                          bool qk_norm    = true,
+                          float eps       = 1e-6)
             : WanSelfAttention(dim, num_heads, qk_norm, eps) {}
         virtual struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                             struct ggml_tensor* x,
@@ -1360,8 +1360,8 @@ namespace WAN {
     public:
         WanT2VCrossAttention(int64_t dim,
                              int64_t num_heads,
-                             bool qk_norm = true,
-                             float eps    = 1e-6)
+                             bool qk_norm    = true,
+                             float eps       = 1e-6)
             : WanCrossAttention(dim, num_heads, qk_norm, eps) {}
         struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                     struct ggml_tensor* x,
@@ -1398,8 +1398,8 @@ namespace WAN {
     public:
         WanI2VCrossAttention(int64_t dim,
                              int64_t num_heads,
-                             bool qk_norm = true,
-                             float eps    = 1e-6)
+                             bool qk_norm    = true,
+                             float eps       = 1e-6)
             : WanCrossAttention(dim, num_heads, qk_norm, eps) {
             blocks["k_img"] = std::shared_ptr(new Linear(dim, dim));
             blocks["v_img"] = std::shared_ptr(new Linear(dim, dim));