From fb748bb8a48edf0ba0424d3ee761eb0af71ab0f5 Mon Sep 17 00:00:00 2001 From: stduhpf Date: Fri, 7 Nov 2025 15:58:59 +0100 Subject: [PATCH 1/7] fix: TAE encoding (#935) --- stable-diffusion.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 9faba955..d285f6ca 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -1645,7 +1645,9 @@ class StableDiffusionGGML { } else { latent = gaussian_latent_sample(work_ctx, vae_output); } - process_latent_in(latent); + if (!use_tiny_autoencoder) { + process_latent_in(latent); + } if (sd_version_is_qwen_image(version)) { latent = ggml_reshape_4d(work_ctx, latent, latent->ne[0], latent->ne[1], latent->ne[3], 1); } From c2d8ffc22c030bbbd00c9431dc08db43fe2a3ff5 Mon Sep 17 00:00:00 2001 From: leejet Date: Fri, 7 Nov 2025 23:04:41 +0800 Subject: [PATCH 2/7] fix: compatibility for models with modified tensor shapes (#951) --- common.hpp | 16 ++++++++++++++++ ggml_extend.hpp | 2 +- vae.hpp | 19 +++++++++++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/common.hpp b/common.hpp index 59540752..c68ddafe 100644 --- a/common.hpp +++ b/common.hpp @@ -410,6 +410,22 @@ class SpatialTransformer : public GGMLBlock { int64_t context_dim = 768; // hidden_size, 1024 for VERSION_SD2 bool use_linear = false; + void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") { + auto iter = tensor_storage_map.find(prefix + "proj_out.weight"); + if (iter != tensor_storage_map.end()) { + int64_t inner_dim = n_head * d_head; + if (iter->second.n_dims == 4 && use_linear) { + use_linear = false; + blocks["proj_in"] = std::make_shared(in_channels, inner_dim, std::pair{1, 1}); + blocks["proj_out"] = std::make_shared(inner_dim, in_channels, std::pair{1, 1}); + } else if (iter->second.n_dims == 2 && !use_linear) { + use_linear = true; + blocks["proj_in"] = std::make_shared(in_channels, inner_dim); + blocks["proj_out"] = std::make_shared(inner_dim, in_channels); + } + } + } + public: SpatialTransformer(int64_t in_channels, int64_t n_head, diff --git a/ggml_extend.hpp b/ggml_extend.hpp index d11e07a1..ac6a2ccc 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -1926,8 +1926,8 @@ class GGMLBlock { if (prefix.size() > 0) { prefix = prefix + "."; } - init_blocks(ctx, tensor_storage_map, prefix); init_params(ctx, tensor_storage_map, prefix); + init_blocks(ctx, tensor_storage_map, prefix); } size_t get_params_num() { diff --git a/vae.hpp b/vae.hpp index ddf970c9..9fc8fb75 100644 --- a/vae.hpp +++ b/vae.hpp @@ -66,6 +66,25 @@ class AttnBlock : public UnaryBlock { int64_t in_channels; bool use_linear; + void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") { + auto iter = tensor_storage_map.find(prefix + "proj_out.weight"); + if (iter != tensor_storage_map.end()) { + if (iter->second.n_dims == 4 && use_linear) { + use_linear = false; + blocks["q"] = std::make_shared(in_channels, in_channels, std::pair{1, 1}); + blocks["k"] = std::make_shared(in_channels, in_channels, std::pair{1, 1}); + blocks["v"] = std::make_shared(in_channels, in_channels, std::pair{1, 1}); + blocks["proj_out"] = std::make_shared(in_channels, in_channels, std::pair{1, 1}); + } else if (iter->second.n_dims == 2 && !use_linear) { + use_linear = true; + blocks["q"] = std::make_shared(in_channels, in_channels); + blocks["k"] = std::make_shared(in_channels, in_channels); + blocks["v"] = std::make_shared(in_channels, in_channels); + blocks["proj_out"] = std::make_shared(in_channels, in_channels); + } + } + } + public: AttnBlock(int64_t in_channels, bool use_linear) : in_channels(in_channels), use_linear(use_linear) { From 0fa3e1a3830f2708a3f8ae04c51de33e76dbd4f4 Mon Sep 17 00:00:00 2001 From: akleine Date: Sun, 9 Nov 2025 15:36:43 +0100 Subject: [PATCH 3/7] fix: prevent core dump in PM V2 in case of incomplete cmd line (#950) --- docs/photo_maker.md | 4 ++-- stable-diffusion.cpp | 30 ++++++++++++++++++------------ 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/docs/photo_maker.md b/docs/photo_maker.md index dae2c9b2..31203ef7 100644 --- a/docs/photo_maker.md +++ b/docs/photo_maker.md @@ -40,7 +40,7 @@ Running PMV2 is now a two-step process: ``` python face_detect.py input_image_dir ``` -An ```id_embeds.safetensors``` file will be generated in ```input_images_dir``` +An ```id_embeds.bin``` file will be generated in ```input_images_dir``` **Note: this step is only needed to run once; the same ```id_embeds``` can be reused** @@ -48,6 +48,6 @@ An ```id_embeds.safetensors``` file will be generated in ```input_images_dir``` You can download ```photomaker-v2.safetensors``` from [here](https://huggingface.co/bssrdf/PhotoMakerV2) -- All the command line parameters from Version 1 remain the same for Version 2 +- All the command line parameters from Version 1 remain the same for Version 2 plus one extra pointing to a valid ```id_embeds``` file: --pm-id-embed-path [path_to__id_embeds.bin] diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index d285f6ca..88806c86 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -2222,18 +2222,24 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, id_embeds = load_tensor_from_file(work_ctx, pm_params.id_embed_path); // print_ggml_tensor(id_embeds, true, "id_embeds:"); } - id_cond.c_crossattn = sd_ctx->sd->id_encoder(work_ctx, init_img, id_cond.c_crossattn, id_embeds, class_tokens_mask); - int64_t t1 = ggml_time_ms(); - LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0); - if (sd_ctx->sd->free_params_immediately) { - sd_ctx->sd->pmid_model->free_params_buffer(); - } - // Encode input prompt without the trigger word for delayed conditioning - prompt_text_only = sd_ctx->sd->cond_stage_model->remove_trigger_from_prompt(work_ctx, prompt); - // printf("%s || %s \n", prompt.c_str(), prompt_text_only.c_str()); - prompt = prompt_text_only; // - if (sample_steps < 50) { - LOG_WARN("It's recommended to use >= 50 steps for photo maker!"); + if (pmv2 && id_embeds == nullptr) { + LOG_WARN("Provided PhotoMaker images, but NO valid ID embeds file for PM v2"); + LOG_WARN("Turn off PhotoMaker"); + sd_ctx->sd->stacked_id = false; + } else { + id_cond.c_crossattn = sd_ctx->sd->id_encoder(work_ctx, init_img, id_cond.c_crossattn, id_embeds, class_tokens_mask); + int64_t t1 = ggml_time_ms(); + LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0); + if (sd_ctx->sd->free_params_immediately) { + sd_ctx->sd->pmid_model->free_params_buffer(); + } + // Encode input prompt without the trigger word for delayed conditioning + prompt_text_only = sd_ctx->sd->cond_stage_model->remove_trigger_from_prompt(work_ctx, prompt); + // printf("%s || %s \n", prompt.c_str(), prompt_text_only.c_str()); + prompt = prompt_text_only; // + if (sample_steps < 50) { + LOG_WARN("It's recommended to use >= 50 steps for photo maker!"); + } } } else { LOG_WARN("Provided PhotoMaker model file, but NO input ID images"); From d2d3944f500cb9acc6652f3e9dc4a3676e7387de Mon Sep 17 00:00:00 2001 From: akleine Date: Sun, 9 Nov 2025 15:47:37 +0100 Subject: [PATCH 4/7] feat: add support for SD2.x with TINY U-Nets (#939) --- docs/distilled_sd.md | 87 +++++++++++++++++++++++++------------------- model.cpp | 3 ++ model.h | 3 +- stable-diffusion.cpp | 1 + unet.hpp | 21 ++++++----- 5 files changed, 68 insertions(+), 47 deletions(-) diff --git a/docs/distilled_sd.md b/docs/distilled_sd.md index f235f56b..478305f2 100644 --- a/docs/distilled_sd.md +++ b/docs/distilled_sd.md @@ -1,40 +1,66 @@ -# Running distilled models: SSD1B and SD1.x with tiny U-Nets +# Running distilled models: SSD1B and SDx.x with tiny U-Nets -## Preface +## Preface -This kind of models have a reduced U-Net part. -Unlike other SDXL models the U-Net of SSD1B has only one middle block and lesser attention layers in up and down blocks, resulting in relatively smaller files. Running these models saves more than 33% of the time. For more details, refer to Segmind's paper on https://arxiv.org/abs/2401.02677v1 . -Unlike other SD 1.x models Tiny-UNet models consist of only 6 U-Net blocks, resulting in relatively smaller files (approximately 1 GB). Running these models saves almost 50% of the time. For more details, refer to the paper: https://arxiv.org/pdf/2305.15798.pdf . +These models feature a reduced U-Net architecture. Unlike standard SDXL models, the SSD-1B U-Net contains only one middle block and fewer attention layers in its up- and down-blocks, resulting in significantly smaller file sizes. Using these models can reduce inference time by more than 33%. For more details, refer to Segmind's paper: https://arxiv.org/abs/2401.02677v1. +Similarly, SD1.x- and SD2.x-style models with a tiny U-Net consist of only 6 U-Net blocks, leading to very small files and time savings of up to 50%. For more information, see the paper: https://arxiv.org/pdf/2305.15798.pdf. ## SSD1B -Unfortunately not all of this models follow the standard model parameter naming mapping. -Anyway there are some very useful SSD1B models available online, such as: +Note that not all of these models follow the standard parameter naming conventions. However, several useful SSD-1B models are available online, such as: * https://huggingface.co/segmind/SSD-1B/resolve/main/SSD-1B-A1111.safetensors - * https://huggingface.co/hassenhamdi/SSD-1B-fp8_e4m3fn/resolve/main/SSD-1B_fp8_e4m3fn.safetensors + * https://huggingface.co/hassenhamdi/SSD-1B-fp8_e4m3fn/resolve/main/SSD-1B_fp8_e4m3fn.safetensors -Also there are useful LORAs available: +Useful LoRAs are also available: * https://huggingface.co/seungminh/lora-swarovski-SSD-1B/resolve/main/pytorch_lora_weights.safetensors - * https://huggingface.co/kylielee505/mylcmlorassd/resolve/main/pytorch_lora_weights.safetensors + * https://huggingface.co/kylielee505/mylcmlorassd/resolve/main/pytorch_lora_weights.safetensors -You can use this files **out-of-the-box** - unlike models in next section. +These files can be used out-of-the-box, unlike the models described in the next section. -## SD1.x with tiny U-Nets +## SD1.x, SD2.x with tiny U-Nets -There are some Tiny SD 1.x models available online, such as: +These models require conversion before use. You will need a Python script provided by the diffusers team, available on GitHub: + + * https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/scripts/convert_diffusers_to_original_stable_diffusion.py + +### SD2.x + +NotaAI provides the following model online: + +* https://huggingface.co/nota-ai/bk-sdm-v2-tiny + +Creating a .safetensors file involves two steps. First, run this short Python script to download the model from Hugging Face: + +```python +from diffusers import StableDiffusionPipeline +pipe = StableDiffusionPipeline.from_pretrained("nota-ai/bk-sdm-v2-tiny",cache_dir="./") +``` + +Second, create the .safetensors file by running: + +```bash +python convert_diffusers_to_original_stable_diffusion.py \ + --model_path models--nota-ai--bk-sdm-v2-tiny/snapshots/68277af553777858cd47e133f92e4db47321bc74 \ + --checkpoint_path bk-sdm-v2-tiny.safetensors --half --use_safetensors +``` + +This will generate the **file bk-sdm-v2-tiny.safetensors**, which is now ready for use with sd.cpp. + +### SD1.x + +Several Tiny SD 1.x models are available online, such as: * https://huggingface.co/segmind/tiny-sd * https://huggingface.co/segmind/portrait-finetuned * https://huggingface.co/nota-ai/bk-sdm-tiny -These models need some conversion, for example because partially tensors are **non contiguous** stored. To create a usable checkpoint file, follow these **easy** steps: +These models also require conversion, partly because some tensors are stored in a non-contiguous manner. To create a usable checkpoint file, follow these simple steps: +Download and prepare the model using Python: -### Download model from Hugging Face - -Download the model using Python on your computer, for example this way: +##### Download the model using Python on your computer, for example this way: ```python import torch @@ -46,35 +72,22 @@ for param in unet.parameters(): pipe.save_pretrained("segmindtiny-sd", safe_serialization=True) ``` -### Convert that to a ckpt file - -To convert the downloaded model to a checkpoint file, you need another Python script. Download the conversion script from here: - - * https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/scripts/convert_diffusers_to_original_stable_diffusion.py - - -### Run convert script - -Now, run that conversion script: +##### Run the conversion script: ```bash python convert_diffusers_to_original_stable_diffusion.py \ - --model_path ./segmindtiny-sd \ - --checkpoint_path ./segmind_tiny-sd.ckpt --half + --model_path ./segmindtiny-sd \ + --checkpoint_path ./segmind_tiny-sd.ckpt --half ``` -The file **segmind_tiny-sd.ckpt** will be generated and is now ready to use with sd.cpp +The file segmind_tiny-sd.ckpt will be generated and is now ready for use with sd.cpp. You can follow a similar process for the other models mentioned above. -You can follow a similar process for other models mentioned above from Hugging Face. - -### Another ckpt file on the net - -There is another model file available online: +### Another available .ckpt file: * https://huggingface.co/ClashSAN/small-sd/resolve/main/tinySDdistilled.ckpt - -If you want to use that, you have to adjust some **non-contiguous tensors** first: + +To use this file, you must first adjust its non-contiguous tensors: ```python import torch diff --git a/model.cpp b/model.cpp index cec69663..79fde749 100644 --- a/model.cpp +++ b/model.cpp @@ -1788,6 +1788,9 @@ SDVersion ModelLoader::get_sd_version() { if (is_inpaint) { return VERSION_SD2_INPAINT; } + if (!has_middle_block_1) { + return VERSION_SD2_TINY_UNET; + } return VERSION_SD2; } return VERSION_COUNT; diff --git a/model.h b/model.h index a29160cf..583a0146 100644 --- a/model.h +++ b/model.h @@ -26,6 +26,7 @@ enum SDVersion { VERSION_SD1_TINY_UNET, VERSION_SD2, VERSION_SD2_INPAINT, + VERSION_SD2_TINY_UNET, VERSION_SDXL, VERSION_SDXL_INPAINT, VERSION_SDXL_PIX2PIX, @@ -52,7 +53,7 @@ static inline bool sd_version_is_sd1(SDVersion version) { } static inline bool sd_version_is_sd2(SDVersion version) { - if (version == VERSION_SD2 || version == VERSION_SD2_INPAINT) { + if (version == VERSION_SD2 || version == VERSION_SD2_INPAINT || version == VERSION_SD2_TINY_UNET) { return true; } return false; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 88806c86..8976a447 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -23,6 +23,7 @@ const char* model_version_to_str[] = { "SD 1.x Tiny UNet", "SD 2.x", "SD 2.x Inpaint", + "SD 2.x Tiny UNet", "SDXL", "SDXL Inpaint", "SDXL Instruct-Pix2Pix", diff --git a/unet.hpp b/unet.hpp index 0e0d049b..8f0adf38 100644 --- a/unet.hpp +++ b/unet.hpp @@ -180,6 +180,7 @@ class UnetModelBlock : public GGMLBlock { int num_head_channels = -1; // channels // num_heads int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL bool use_linear_projection = false; + bool tiny_unet = false; public: int model_channels = 320; @@ -208,15 +209,17 @@ class UnetModelBlock : public GGMLBlock { num_head_channels = 64; num_heads = -1; use_linear_projection = true; - } else if (version == VERSION_SD1_TINY_UNET) { - num_res_blocks = 1; - channel_mult = {1, 2, 4}; } if (sd_version_is_inpaint(version)) { in_channels = 9; } else if (sd_version_is_unet_edit(version)) { in_channels = 8; } + if (version == VERSION_SD1_TINY_UNET || version == VERSION_SD2_TINY_UNET) { + num_res_blocks = 1; + channel_mult = {1, 2, 4}; + tiny_unet = true; + } // dims is always 2 // use_temporal_attention is always True for SVD @@ -290,7 +293,7 @@ class UnetModelBlock : public GGMLBlock { context_dim)); } input_block_chans.push_back(ch); - if (version == VERSION_SD1_TINY_UNET) { + if (tiny_unet) { input_block_idx++; } } @@ -311,7 +314,7 @@ class UnetModelBlock : public GGMLBlock { d_head = num_head_channels; n_head = ch / d_head; } - if (version != VERSION_SD1_TINY_UNET) { + if (!tiny_unet) { blocks["middle_block.0"] = std::shared_ptr(get_resblock(ch, time_embed_dim, ch)); if (version != VERSION_SDXL_SSD1B) { blocks["middle_block.1"] = std::shared_ptr(get_attention_layer(ch, @@ -358,7 +361,7 @@ class UnetModelBlock : public GGMLBlock { } if (i > 0 && j == num_res_blocks) { - if (version == VERSION_SD1_TINY_UNET) { + if (tiny_unet) { output_block_idx++; if (output_block_idx == 2) { up_sample_idx = 1; @@ -495,7 +498,7 @@ class UnetModelBlock : public GGMLBlock { } hs.push_back(h); } - if (version == VERSION_SD1_TINY_UNET) { + if (tiny_unet) { input_block_idx++; } if (i != len_mults - 1) { @@ -512,7 +515,7 @@ class UnetModelBlock : public GGMLBlock { // [N, 4*model_channels, h/8, w/8] // middle_block - if (version != VERSION_SD1_TINY_UNET) { + if (!tiny_unet) { h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8] if (version != VERSION_SDXL_SSD1B) { h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8] @@ -554,7 +557,7 @@ class UnetModelBlock : public GGMLBlock { } if (i > 0 && j == num_res_blocks) { - if (version == VERSION_SD1_TINY_UNET) { + if (tiny_unet) { output_block_idx++; if (output_block_idx == 2) { up_sample_idx = 1; From ee89afc878d6cd84433f1768b6a64982edfb9757 Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 9 Nov 2025 22:47:53 +0800 Subject: [PATCH 5/7] fix: resolve issue with pmid (#957) --- stable-diffusion.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 8976a447..675a02d4 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -2200,7 +2200,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, } ggml_ext_tensor_iter(init_img, [&](ggml_tensor* init_img, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { - float value = sd_image_get_f32(processed_id_images[i3], i0, i1, i2); + float value = sd_image_get_f32(processed_id_images[i3], i0, i1, i2, false); ggml_ext_tensor_set_f32(init_img, value, i0, i1, i2, i3); }); From 8ecdf053acaf48326e50eab1effd1fcbc4db28a2 Mon Sep 17 00:00:00 2001 From: stduhpf Date: Sun, 9 Nov 2025 17:12:02 +0100 Subject: [PATCH 6/7] feat: add image preview support (#522) --- .gitignore | 1 + examples/cli/README.md | 6 ++ examples/cli/main.cpp | 92 +++++++++++++++- ggml_extend.hpp | 2 +- latent-preview.h | 173 +++++++++++++++++++++++++++++ stable-diffusion.cpp | 240 +++++++++++++++++++++++++++++++++++++++-- stable-diffusion.h | 13 +++ util.cpp | 37 +++++++ util.h | 9 ++ 9 files changed, 563 insertions(+), 10 deletions(-) create mode 100644 latent-preview.h diff --git a/.gitignore b/.gitignore index dd4f6435..b0e3af83 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ test/ output*.png models* *.log +preview.png diff --git a/examples/cli/README.md b/examples/cli/README.md index ee17d17d..00e0942f 100644 --- a/examples/cli/README.md +++ b/examples/cli/README.md @@ -32,6 +32,7 @@ Options: -o, --output path to write result image to (default: ./output.png) -p, --prompt the prompt to render -n, --negative-prompt the negative prompt (default: "") + --preview-path path to write preview image to (default: ./preview.png) --upscale-model path to esrgan model. -t, --threads number of threads to use during computation (default: -1). If threads <= 0, then threads will be set to the number of CPU physical cores @@ -48,6 +49,8 @@ Options: --fps fps (default: 24) --timestep-shift shift timestep for NitroFusion models (default: 0). recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant + --preview-interval interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at + every step) --cfg-scale unconditional guidance scale: (default: 7.0) --img-cfg-scale image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale) --guidance distilled guidance scale for models with guidance input (default: 3.5) @@ -86,6 +89,8 @@ Options: --chroma-enable-t5-mask enable t5 mask for chroma --increase-ref-index automatically increase the indices of references images based on the order they are listed (starting with 1). --disable-auto-resize-ref-image disable auto resize of ref images + --taesd-preview-only prevents usage of taesd for decoding the final image. (for use with --preview tae) + --preview-noisy enables previewing noisy inputs of the models rather than the denoised outputs -M, --mode run mode, one of [img_gen, vid_gen, upscale, convert], default: img_gen --type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K). If not specified, the default is the type of the weight file @@ -107,4 +112,5 @@ Options: --vae-tile-size tile size for vae tiling, format [X]x[Y] (default: 32x32) --vae-relative-tile-size relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size) + --preview preview method. must be one of the following [none, proj, tae, vae] (default is none) ``` \ No newline at end of file diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 8f938c9b..619c4284 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -46,6 +46,13 @@ const char* modes_str[] = { }; #define SD_ALL_MODES_STR "img_gen, vid_gen, convert, upscale" +const char* previews_str[] = { + "none", + "proj", + "tae", + "vae", +}; + enum SDMode { IMG_GEN, VID_GEN, @@ -135,6 +142,12 @@ struct SDParams { sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; bool force_sdxl_vae_conv_scale = false; + preview_t preview_method = PREVIEW_NONE; + int preview_interval = 1; + std::string preview_path = "preview.png"; + bool taesd_preview = false; + bool preview_noisy = false; + SDParams() { sd_sample_params_init(&sample_params); sd_sample_params_init(&high_noise_sample_params); @@ -210,6 +223,8 @@ void print_params(SDParams params) { printf(" video_frames: %d\n", params.video_frames); printf(" vace_strength: %.2f\n", params.vace_strength); printf(" fps: %d\n", params.fps); + printf(" preview_mode: %s (%s)\n", previews_str[params.preview_method], params.preview_noisy ? "noisy" : "denoised"); + printf(" preview_interval: %d\n", params.preview_interval); free(sample_params_str); free(high_noise_sample_params_str); } @@ -589,6 +604,10 @@ void parse_args(int argc, const char** argv, SDParams& params) { "--negative-prompt", "the negative prompt (default: \"\")", ¶ms.negative_prompt}, + {"", + "--preview-path", + "path to write preview image to (default: ./preview.png)", + ¶ms.preview_path}, {"", "--upscale-model", "path to esrgan model.", @@ -647,6 +666,10 @@ void parse_args(int argc, const char** argv, SDParams& params) { "shift timestep for NitroFusion models (default: 0). " "recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant", ¶ms.sample_params.shifted_timestep}, + {"", + "--preview-interval", + "interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at every step)", + ¶ms.preview_interval}, }; options.float_options = { @@ -801,7 +824,14 @@ void parse_args(int argc, const char** argv, SDParams& params) { "--disable-auto-resize-ref-image", "disable auto resize of ref images", false, ¶ms.auto_resize_ref_image}, - }; + {"", + "--taesd-preview-only", + std::string("prevents usage of taesd for decoding the final image. (for use with --preview ") + previews_str[PREVIEW_TAE] + ")", + true, ¶ms.taesd_preview}, + {"", + "--preview-noisy", + "enables previewing noisy inputs of the models rather than the denoised outputs", + true, ¶ms.preview_noisy}}; auto on_mode_arg = [&](int argc, const char** argv, int index) { if (++index >= argc) { @@ -1046,6 +1076,26 @@ void parse_args(int argc, const char** argv, SDParams& params) { return 1; }; + auto on_preview_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* preview = argv[index]; + int preview_method = -1; + for (int m = 0; m < PREVIEW_COUNT; m++) { + if (!strcmp(preview, previews_str[m])) { + preview_method = m; + } + } + if (preview_method == -1) { + fprintf(stderr, "error: preview method %s\n", + preview); + return -1; + } + params.preview_method = (preview_t)preview_method; + return 1; + }; + options.manual_options = { {"-M", "--mode", @@ -1110,6 +1160,10 @@ void parse_args(int argc, const char** argv, SDParams& params) { "--vae-relative-tile-size", "relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size)", on_relative_tile_size_arg}, + {"", + "--preview", + std::string("preview method. must be one of the following [") + previews_str[0] + ", " + previews_str[1] + ", " + previews_str[2] + ", " + previews_str[3] + "] (default is " + previews_str[PREVIEW_NONE] + ")\n", + on_preview_arg}, }; if (!parse_options(argc, argv, options)) { @@ -1452,15 +1506,50 @@ bool load_images_from_dir(const std::string dir, return true; } +const char* preview_path; +float preview_fps; + +void step_callback(int step, int frame_count, sd_image_t* image, bool is_noisy) { + (void)step; + (void)is_noisy; + // is_noisy is set to true if the preview corresponds to noisy latents, false if it's denoised latents + // unused in this app, it will either be always noisy or always denoised here + if (frame_count == 1) { + stbi_write_png(preview_path, image->width, image->height, image->channel, image->data, 0); + } else { + create_mjpg_avi_from_sd_images(preview_path, image, frame_count, preview_fps); + } +} + int main(int argc, const char* argv[]) { SDParams params; parse_args(argc, argv, params); + preview_path = params.preview_path.c_str(); + if (params.video_frames > 4) { + size_t last_dot_pos = params.preview_path.find_last_of("."); + std::string base_path = params.preview_path; + std::string file_ext = ""; + if (last_dot_pos != std::string::npos) { // filename has extension + base_path = params.preview_path.substr(0, last_dot_pos); + file_ext = params.preview_path.substr(last_dot_pos); + std::transform(file_ext.begin(), file_ext.end(), file_ext.begin(), ::tolower); + } + if (file_ext == ".png") { + base_path = base_path + ".avi"; + preview_path = base_path.c_str(); + } + } + preview_fps = params.fps; + if (params.preview_method == PREVIEW_PROJ) + preview_fps /= 4.0f; + params.sample_params.guidance.slg.layers = params.skip_layers.data(); params.sample_params.guidance.slg.layer_count = params.skip_layers.size(); params.high_noise_sample_params.guidance.slg.layers = params.high_noise_skip_layers.data(); params.high_noise_sample_params.guidance.slg.layer_count = params.high_noise_skip_layers.size(); sd_set_log_callback(sd_log_cb, (void*)¶ms); + sd_set_preview_callback((sd_preview_cb_t)step_callback, params.preview_method, params.preview_interval, !params.preview_noisy, params.preview_noisy); if (params.verbose) { print_params(params); @@ -1654,6 +1743,7 @@ int main(int argc, const char* argv[]) { params.control_net_cpu, params.vae_on_cpu, params.diffusion_flash_attn, + params.taesd_preview, params.diffusion_conv_direct, params.vae_conv_direct, params.force_sdxl_vae_conv_scale, diff --git a/ggml_extend.hpp b/ggml_extend.hpp index ac6a2ccc..a634bc80 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -875,7 +875,7 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input, ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size_x, input_tile_size_y, input->ne[2], input->ne[3]); ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size_x, output_tile_size_y, output->ne[2], output->ne[3]); int num_tiles = num_tiles_x * num_tiles_y; - LOG_INFO("processing %i tiles", num_tiles); + LOG_DEBUG("processing %i tiles", num_tiles); pretty_progress(0, num_tiles, 0.0f); int tile_count = 1; bool last_y = false, last_x = false; diff --git a/latent-preview.h b/latent-preview.h new file mode 100644 index 00000000..97409a7d --- /dev/null +++ b/latent-preview.h @@ -0,0 +1,173 @@ +#include +#include +#include "ggml.h" + +const float wan_21_latent_rgb_proj[16][3] = { + {0.015123f, -0.148418f, 0.479828f}, + {0.003652f, -0.010680f, -0.037142f}, + {0.212264f, 0.063033f, 0.016779f}, + {0.232999f, 0.406476f, 0.220125f}, + {-0.051864f, -0.082384f, -0.069396f}, + {0.085005f, -0.161492f, 0.010689f}, + {-0.245369f, -0.506846f, -0.117010f}, + {-0.151145f, 0.017721f, 0.007207f}, + {-0.293239f, -0.207936f, -0.421135f}, + {-0.187721f, 0.050783f, 0.177649f}, + {-0.013067f, 0.265964f, 0.166578f}, + {0.028327f, 0.109329f, 0.108642f}, + {-0.205343f, 0.043991f, 0.148914f}, + {0.014307f, -0.048647f, -0.007219f}, + {0.217150f, 0.053074f, 0.319923f}, + {0.155357f, 0.083156f, 0.064780f}}; +float wan_21_latent_rgb_bias[3] = {-0.270270f, -0.234976f, -0.456853f}; + +const float wan_22_latent_rgb_proj[48][3] = { + {0.017126f, -0.027230f, -0.019257f}, + {-0.113739f, -0.028715f, -0.022885f}, + {-0.000106f, 0.021494f, 0.004629f}, + {-0.013273f, -0.107137f, -0.033638f}, + {-0.000381f, 0.000279f, 0.025877f}, + {-0.014216f, -0.003975f, 0.040528f}, + {0.001638f, -0.000748f, 0.011022f}, + {0.029238f, -0.006697f, 0.035933f}, + {0.021641f, -0.015874f, 0.040531f}, + {-0.101984f, -0.070160f, -0.028855f}, + {0.033207f, -0.021068f, 0.002663f}, + {-0.104711f, 0.121673f, 0.102981f}, + {0.082647f, -0.004991f, 0.057237f}, + {-0.027375f, 0.031581f, 0.006868f}, + {-0.045434f, 0.029444f, 0.019287f}, + {-0.046572f, -0.012537f, 0.006675f}, + {0.074709f, 0.033690f, 0.025289f}, + {-0.008251f, -0.002745f, -0.006999f}, + {0.012685f, -0.061856f, -0.048658f}, + {0.042304f, -0.007039f, 0.000295f}, + {-0.007644f, -0.060843f, -0.033142f}, + {0.159909f, 0.045628f, 0.367541f}, + {0.095171f, 0.086438f, 0.010271f}, + {0.006812f, 0.019643f, 0.029637f}, + {0.003467f, -0.010705f, 0.014252f}, + {-0.099681f, -0.066272f, -0.006243f}, + {0.047357f, 0.037040f, 0.000185f}, + {-0.041797f, -0.089225f, -0.032257f}, + {0.008928f, 0.017028f, 0.018684f}, + {-0.042255f, 0.016045f, 0.006849f}, + {0.011268f, 0.036462f, 0.037387f}, + {0.011553f, -0.016375f, -0.048589f}, + {0.046266f, -0.027189f, 0.056979f}, + {0.009640f, -0.017576f, 0.030324f}, + {-0.045794f, -0.036083f, -0.010616f}, + {0.022418f, 0.039783f, -0.032939f}, + {-0.052714f, -0.015525f, 0.007438f}, + {0.193004f, 0.223541f, 0.264175f}, + {-0.059406f, -0.008188f, 0.022867f}, + {-0.156742f, -0.263791f, -0.007385f}, + {-0.015717f, 0.016570f, 0.033969f}, + {0.037969f, 0.109835f, 0.200449f}, + {-0.000782f, -0.009566f, -0.008058f}, + {0.010709f, 0.052960f, -0.044195f}, + {0.017271f, 0.045839f, 0.034569f}, + {0.009424f, 0.013088f, -0.001714f}, + {-0.024805f, -0.059378f, -0.033756f}, + {-0.078293f, 0.029070f, 0.026129f}}; +float wan_22_latent_rgb_bias[3] = {0.013160f, -0.096492f, -0.071323f}; + +const float flux_latent_rgb_proj[16][3] = { + {-0.041168f, 0.019917f, 0.097253f}, + {0.028096f, 0.026730f, 0.129576f}, + {0.065618f, -0.067950f, -0.014651f}, + {-0.012998f, -0.014762f, 0.081251f}, + {0.078567f, 0.059296f, -0.024687f}, + {-0.015987f, -0.003697f, 0.005012f}, + {0.033605f, 0.138999f, 0.068517f}, + {-0.024450f, -0.063567f, -0.030101f}, + {-0.040194f, -0.016710f, 0.127185f}, + {0.112681f, 0.088764f, -0.041940f}, + {-0.023498f, 0.093664f, 0.025543f}, + {0.082899f, 0.048320f, 0.007491f}, + {0.075712f, 0.074139f, 0.081965f}, + {-0.143501f, 0.018263f, -0.136138f}, + {-0.025767f, -0.082035f, -0.040023f}, + {-0.111849f, -0.055589f, -0.032361f}}; +float flux_latent_rgb_bias[3] = {0.024600f, -0.006937f, -0.008089f}; + +// This one was taken straight from +// https://github.com/Stability-AI/sd3.5/blob/8565799a3b41eb0c7ba976d18375f0f753f56402/sd3_impls.py#L288-L303 +// (MiT Licence) +const float sd3_latent_rgb_proj[16][3] = { + {-0.0645f, 0.0177f, 0.1052f}, + {0.0028f, 0.0312f, 0.0650f}, + {0.1848f, 0.0762f, 0.0360f}, + {0.0944f, 0.0360f, 0.0889f}, + {0.0897f, 0.0506f, -0.0364f}, + {-0.0020f, 0.1203f, 0.0284f}, + {0.0855f, 0.0118f, 0.0283f}, + {-0.0539f, 0.0658f, 0.1047f}, + {-0.0057f, 0.0116f, 0.0700f}, + {-0.0412f, 0.0281f, -0.0039f}, + {0.1106f, 0.1171f, 0.1220f}, + {-0.0248f, 0.0682f, -0.0481f}, + {0.0815f, 0.0846f, 0.1207f}, + {-0.0120f, -0.0055f, -0.0867f}, + {-0.0749f, -0.0634f, -0.0456f}, + {-0.1418f, -0.1457f, -0.1259f}, +}; +float sd3_latent_rgb_bias[3] = {0, 0, 0}; + +const float sdxl_latent_rgb_proj[4][3] = { + {0.258303f, 0.277640f, 0.329699f}, + {-0.299701f, 0.105446f, 0.014194f}, + {0.050522f, 0.186163f, -0.143257f}, + {-0.211938f, -0.149892f, -0.080036f}}; +float sdxl_latent_rgb_bias[3] = {0.144381f, -0.033313f, 0.007061f}; + +const float sd_latent_rgb_proj[4][3] = { + {0.337366f, 0.216344f, 0.257386f}, + {0.165636f, 0.386828f, 0.046994f}, + {-0.267803f, 0.237036f, 0.223517f}, + {-0.178022f, -0.200862f, -0.678514f}}; +float sd_latent_rgb_bias[3] = {-0.017478f, -0.055834f, -0.105825f}; + +void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const float (*latent_rgb_proj)[3], const float latent_rgb_bias[3], int width, int height, int frames, int dim) { + size_t buffer_head = 0; + for (int k = 0; k < frames; k++) { + for (int j = 0; j < height; j++) { + for (int i = 0; i < width; i++) { + size_t latent_id = (i * latents->nb[0] + j * latents->nb[1] + k * latents->nb[2]); + float r = 0, g = 0, b = 0; + if (latent_rgb_proj != nullptr) { + for (int d = 0; d < dim; d++) { + float value = *(float*)((char*)latents->data + latent_id + d * latents->nb[ggml_n_dims(latents) - 1]); + r += value * latent_rgb_proj[d][0]; + g += value * latent_rgb_proj[d][1]; + b += value * latent_rgb_proj[d][2]; + } + } else { + // interpret first 3 channels as RGB + r = *(float*)((char*)latents->data + latent_id + 0 * latents->nb[ggml_n_dims(latents) - 1]); + g = *(float*)((char*)latents->data + latent_id + 1 * latents->nb[ggml_n_dims(latents) - 1]); + b = *(float*)((char*)latents->data + latent_id + 2 * latents->nb[ggml_n_dims(latents) - 1]); + } + if (latent_rgb_bias != nullptr) { + // bias + r += latent_rgb_bias[0]; + g += latent_rgb_bias[1]; + b += latent_rgb_bias[2]; + } + // change range + r = r * .5f + .5f; + g = g * .5f + .5f; + b = b * .5f + .5f; + + // clamp rgb values to [0,1] range + r = r >= 0 ? r <= 1 ? r : 1 : 0; + g = g >= 0 ? g <= 1 ? g : 1 : 0; + b = b >= 0 ? b <= 1 ? b : 1 : 0; + + buffer[buffer_head++] = (uint8_t)(r * 255); + buffer[buffer_head++] = (uint8_t)(g * 255); + buffer[buffer_head++] = (uint8_t)(b * 255); + } + } + } +} diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 675a02d4..90f005fe 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -16,6 +16,8 @@ #include "tae.hpp" #include "vae.hpp" +#include "latent-preview.h" + const char* model_version_to_str[] = { "SD 1.x", "SD 1.x Inpaint", @@ -74,6 +76,14 @@ void calculate_alphas_cumprod(float* alphas_cumprod, } } +void suppress_pp(int step, int steps, float time, void* data) { + (void)step; + (void)steps; + (void)time; + (void)data; + return; +} + /*=============================================== StableDiffusionGGML ================================================*/ class StableDiffusionGGML { @@ -487,7 +497,7 @@ class StableDiffusionGGML { } else if (version == VERSION_CHROMA_RADIANCE) { first_stage_model = std::make_shared(vae_backend, offload_params_to_cpu); - } else if (!use_tiny_autoencoder) { + } else if (!use_tiny_autoencoder || sd_ctx_params->tae_preview_only) { first_stage_model = std::make_shared(vae_backend, offload_params_to_cpu, tensor_storage_map, @@ -510,7 +520,8 @@ class StableDiffusionGGML { } first_stage_model->alloc_params_buffer(); first_stage_model->get_param_tensors(tensors, "first_stage_model"); - } else { + } + if (use_tiny_autoencoder) { tae_first_stage = std::make_shared(vae_backend, offload_params_to_cpu, tensor_storage_map, @@ -626,9 +637,10 @@ class StableDiffusionGGML { unet_params_mem_size += high_noise_diffusion_model->get_params_buffer_size(); } size_t vae_params_mem_size = 0; - if (!use_tiny_autoencoder) { + if (!use_tiny_autoencoder || sd_ctx_params->tae_preview_only) { vae_params_mem_size = first_stage_model->get_params_buffer_size(); - } else { + } + if (use_tiny_autoencoder) { if (!tae_first_stage->load_from_file(taesd_path, n_threads)) { return false; } @@ -801,6 +813,7 @@ class StableDiffusionGGML { LOG_DEBUG("finished loaded file"); ggml_free(ctx); + use_tiny_autoencoder = use_tiny_autoencoder && !sd_ctx_params->tae_preview_only; return true; } @@ -1109,6 +1122,156 @@ class StableDiffusionGGML { } } + void silent_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) { + sd_progress_cb_t cb = sd_get_progress_callback(); + void* cbd = sd_get_progress_callback_data(); + sd_set_progress_callback((sd_progress_cb_t)suppress_pp, nullptr); + sd_tiling(input, output, scale, tile_size, tile_overlap_factor, on_processing); + sd_set_progress_callback(cb, cbd); + } + + void preview_image(ggml_context* work_ctx, + int step, + struct ggml_tensor* latents, + enum SDVersion version, + preview_t preview_mode, + ggml_tensor* result, + std::function step_callback, + bool is_noisy) { + const uint32_t channel = 3; + uint32_t width = latents->ne[0]; + uint32_t height = latents->ne[1]; + uint32_t dim = latents->ne[ggml_n_dims(latents) - 1]; + + if (preview_mode == PREVIEW_PROJ) { + const float(*latent_rgb_proj)[channel] = nullptr; + float* latent_rgb_bias = nullptr; + + if (dim == 48) { + if (sd_version_is_wan(version)) { + latent_rgb_proj = wan_22_latent_rgb_proj; + latent_rgb_bias = wan_22_latent_rgb_bias; + } else { + LOG_WARN("No latent to RGB projection known for this model"); + // unknown model + return; + } + } else if (dim == 16) { + // 16 channels VAE -> Flux or SD3 + + if (sd_version_is_sd3(version)) { + latent_rgb_proj = sd3_latent_rgb_proj; + latent_rgb_bias = sd3_latent_rgb_bias; + } else if (sd_version_is_flux(version)) { + latent_rgb_proj = flux_latent_rgb_proj; + latent_rgb_bias = flux_latent_rgb_bias; + } else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) { + latent_rgb_proj = wan_21_latent_rgb_proj; + latent_rgb_bias = wan_21_latent_rgb_bias; + } else { + LOG_WARN("No latent to RGB projection known for this model"); + // unknown model + return; + } + + } else if (dim == 4) { + // 4 channels VAE + if (sd_version_is_sdxl(version)) { + latent_rgb_proj = sdxl_latent_rgb_proj; + latent_rgb_bias = sdxl_latent_rgb_bias; + } else if (sd_version_is_sd1(version) || sd_version_is_sd2(version)) { + latent_rgb_proj = sd_latent_rgb_proj; + latent_rgb_bias = sd_latent_rgb_bias; + } else { + // unknown model + LOG_WARN("No latent to RGB projection known for this model"); + return; + } + } else if (dim == 3) { + // Do nothing, assuming already RGB latents + } else { + LOG_WARN("No latent to RGB projection known for this model"); + // unknown latent space + return; + } + + uint32_t frames = 1; + if (ggml_n_dims(latents) == 4) { + frames = latents->ne[2]; + } + + uint8_t* data = (uint8_t*)malloc(frames * width * height * channel * sizeof(uint8_t)); + + preview_latent_video(data, latents, latent_rgb_proj, latent_rgb_bias, width, height, frames, dim); + sd_image_t* images = (sd_image_t*)malloc(frames * sizeof(sd_image_t)); + for (int i = 0; i < frames; i++) { + images[i] = {width, height, channel, data + i * width * height * channel}; + } + step_callback(step, frames, images, is_noisy); + free(data); + free(images); + } else { + if (preview_mode == PREVIEW_VAE) { + process_latent_out(latents); + if (vae_tiling_params.enabled) { + // split latent in 32x32 tiles and compute in several steps + auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { + first_stage_model->compute(n_threads, in, true, &out, nullptr); + }; + silent_tiling(latents, result, get_vae_scale_factor(), 32, 0.5f, on_tiling); + + } else { + first_stage_model->compute(n_threads, latents, true, &result, work_ctx); + } + + first_stage_model->free_compute_buffer(); + process_vae_output_tensor(result); + process_latent_in(latents); + } else if (preview_mode == PREVIEW_TAE) { + if (tae_first_stage == nullptr) { + LOG_WARN("TAE not found for preview"); + return; + } + if (vae_tiling_params.enabled) { + // split latent in 64x64 tiles and compute in several steps + auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { + tae_first_stage->compute(n_threads, in, true, &out, nullptr); + }; + silent_tiling(latents, result, get_vae_scale_factor(), 64, 0.5f, on_tiling); + } else { + tae_first_stage->compute(n_threads, latents, true, &result, work_ctx); + } + tae_first_stage->free_compute_buffer(); + } else { + return; + } + + ggml_ext_tensor_clamp_inplace(result, 0.0f, 1.0f); + uint32_t frames = 1; + if (ggml_n_dims(latents) == 4) { + frames = result->ne[2]; + } + + sd_image_t* images = (sd_image_t*)malloc(frames * sizeof(sd_image_t)); + // print_ggml_tensor(result,true); + for (size_t i = 0; i < frames; i++) { + images[i].width = result->ne[0]; + images[i].height = result->ne[1]; + images[i].channel = 3; + images[i].data = ggml_tensor_to_sd_image(result, i, ggml_n_dims(latents) == 4); + } + + step_callback(step, frames, images, is_noisy); + + ggml_ext_tensor_scale_inplace(result, 0); + for (int i = 0; i < frames; i++) { + free(images[i].data); + } + + free(images); + } + } + ggml_tensor* sample(ggml_context* work_ctx, std::shared_ptr work_diffusion_model, bool inverse_noise_scaling, @@ -1184,7 +1347,34 @@ class StableDiffusionGGML { int64_t t0 = ggml_time_us(); + struct ggml_tensor* preview_tensor = nullptr; + auto sd_preview_mode = sd_get_preview_mode(); + if (sd_preview_mode != PREVIEW_NONE && sd_preview_mode != PREVIEW_PROJ) { + int64_t W = x->ne[0] * get_vae_scale_factor(); + int64_t H = x->ne[1] * get_vae_scale_factor(); + if (ggml_n_dims(x) == 4) { + // assuming video mode (if batch processing gets implemented this will break) + int T = x->ne[2]; + if (sd_version_is_wan(version)) { + T = ((T - 1) * 4) + 1; + } + preview_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, + W, + H, + T, + 3); + } else { + preview_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, + W, + H, + 3, + x->ne[3]); + } + } + auto denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* { + auto sd_preview_cb = sd_get_preview_callback(); + auto sd_preview_mode = sd_get_preview_mode(); if (step == 1 || step == -1) { pretty_progress(0, (int)steps, 0); } @@ -1219,6 +1409,11 @@ class StableDiffusionGGML { if (denoise_mask != nullptr && version == VERSION_WAN2_2_TI2V) { apply_mask(noised_input, init_latent, denoise_mask); } + if (sd_preview_cb != nullptr && sd_should_preview_noisy()) { + if (step % sd_get_preview_interval() == 0) { + preview_image(work_ctx, step, noised_input, version, sd_preview_mode, preview_tensor, sd_preview_cb, true); + } + } std::vector controls; @@ -1340,16 +1535,22 @@ class StableDiffusionGGML { vec_denoised[i] = latent_result * c_out + vec_input[i] * c_skip; } + if (denoise_mask != nullptr) { + apply_mask(denoised, init_latent, denoise_mask); + } + + if (sd_preview_cb != nullptr && sd_should_preview_denoised()) { + if (step % sd_get_preview_interval() == 0) { + preview_image(work_ctx, step, denoised, version, sd_preview_mode, preview_tensor, sd_preview_cb, false); + } + } + int64_t t1 = ggml_time_us(); if (step > 0 || step == -(int)steps) { int showstep = std::abs(step); pretty_progress(showstep, (int)steps, (t1 - t0) / 1000000.f / showstep); // LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000); } - if (denoise_mask != nullptr) { - apply_mask(denoised, init_latent, denoise_mask); - } - return denoised; }; @@ -1855,6 +2056,29 @@ enum prediction_t str_to_prediction(const char* str) { return PREDICTION_COUNT; } +const char* preview_to_str[] = { + "none", + "proj", + "tae", + "vae", +}; + +const char* sd_preview_name(enum preview_t preview) { + if (preview < PREVIEW_COUNT) { + return preview_to_str[preview]; + } + return NONE_STR; +} + +enum preview_t str_to_preview(const char* str) { + for (int i = 0; i < PREVIEW_COUNT; i++) { + if (!strcmp(str, preview_to_str[i])) { + return (enum preview_t)i; + } + } + return PREVIEW_COUNT; +} + void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { *sd_ctx_params = {}; sd_ctx_params->vae_decode_only = true; diff --git a/stable-diffusion.h b/stable-diffusion.h index f618d457..9e99d53d 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -126,6 +126,14 @@ enum sd_log_level_t { SD_LOG_ERROR }; +enum preview_t { + PREVIEW_NONE, + PREVIEW_PROJ, + PREVIEW_TAE, + PREVIEW_VAE, + PREVIEW_COUNT +}; + typedef struct { bool enabled; int tile_size_x; @@ -162,6 +170,7 @@ typedef struct { bool keep_control_net_on_cpu; bool keep_vae_on_cpu; bool diffusion_flash_attn; + bool tae_preview_only; bool diffusion_conv_direct; bool vae_conv_direct; bool force_sdxl_vae_conv_scale; @@ -254,9 +263,11 @@ typedef struct sd_ctx_t sd_ctx_t; typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data); typedef void (*sd_progress_cb_t)(int step, int steps, float time, void* data); +typedef void (*sd_preview_cb_t)(int step, int frame_count, sd_image_t* frames, bool is_noisy); SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data); SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data); +SD_API void sd_set_preview_callback(sd_preview_cb_t cb, preview_t mode, int interval, bool denoised, bool noisy); SD_API int32_t get_num_physical_cores(); SD_API const char* sd_get_system_info(); @@ -270,6 +281,8 @@ SD_API const char* sd_schedule_name(enum scheduler_t scheduler); SD_API enum scheduler_t str_to_schedule(const char* str); SD_API const char* sd_prediction_name(enum prediction_t prediction); SD_API enum prediction_t str_to_prediction(const char* str); +SD_API const char* sd_preview_name(enum preview_t preview); +SD_API enum preview_t str_to_preview(const char* str); SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params); SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params); diff --git a/util.cpp b/util.cpp index b205f522..1aa9beff 100644 --- a/util.cpp +++ b/util.cpp @@ -185,6 +185,12 @@ int32_t get_num_physical_cores() { static sd_progress_cb_t sd_progress_cb = nullptr; void* sd_progress_cb_data = nullptr; +static sd_preview_cb_t sd_preview_cb = nullptr; +preview_t sd_preview_mode = PREVIEW_NONE; +int sd_preview_interval = 1; +bool sd_preview_denoised = true; +bool sd_preview_noisy = false; + std::u32string utf8_to_utf32(const std::string& utf8_str) { std::wstring_convert, char32_t> converter; return converter.from_bytes(utf8_str); @@ -328,6 +334,37 @@ void sd_set_progress_callback(sd_progress_cb_t cb, void* data) { sd_progress_cb = cb; sd_progress_cb_data = data; } +void sd_set_preview_callback(sd_preview_cb_t cb, preview_t mode = PREVIEW_PROJ, int interval = 1, bool denoised = true, bool noisy = false) { + sd_preview_cb = cb; + sd_preview_mode = mode; + sd_preview_interval = interval; + sd_preview_denoised = denoised; + sd_preview_noisy = noisy; +} + +sd_preview_cb_t sd_get_preview_callback() { + return sd_preview_cb; +} + +preview_t sd_get_preview_mode() { + return sd_preview_mode; +} +int sd_get_preview_interval() { + return sd_preview_interval; +} +bool sd_should_preview_denoised() { + return sd_preview_denoised; +} +bool sd_should_preview_noisy() { + return sd_preview_noisy; +} + +sd_progress_cb_t sd_get_progress_callback() { + return sd_progress_cb; +} +void* sd_get_progress_callback_data() { + return sd_progress_cb_data; +} const char* sd_get_system_info() { static char buffer[1024]; std::stringstream ss; diff --git a/util.h b/util.h index 17bcd1d3..5bd69a62 100644 --- a/util.h +++ b/util.h @@ -54,6 +54,15 @@ std::string trim(const std::string& s); std::vector> parse_prompt_attention(const std::string& text); +sd_progress_cb_t sd_get_progress_callback(); +void* sd_get_progress_callback_data(); + +sd_preview_cb_t sd_get_preview_callback(); +preview_t sd_get_preview_mode(); +int sd_get_preview_interval(); +bool sd_should_preview_denoised(); +bool sd_should_preview_noisy(); + #define LOG_DEBUG(format, ...) log_printf(SD_LOG_DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__) #define LOG_INFO(format, ...) log_printf(SD_LOG_INFO, __FILE__, __LINE__, format, ##__VA_ARGS__) #define LOG_WARN(format, ...) log_printf(SD_LOG_WARN, __FILE__, __LINE__, format, ##__VA_ARGS__) From 694f0d923578262c4f12ae93ded8e56c116085fe Mon Sep 17 00:00:00 2001 From: leejet Date: Mon, 10 Nov 2025 00:12:20 +0800 Subject: [PATCH 7/7] refactor: optimize the logic for name conversion and the processing of the LoRA model (#955) --- conditioner.hpp | 2 +- control.hpp | 2 +- esrgan.hpp | 2 +- flux.hpp | 2 +- ggml_extend.hpp | 6 +- lora.hpp | 1067 ++++++++++++++---------------------------- mmdit.hpp | 2 +- model.cpp | 687 ++------------------------- model.h | 14 +- name_conversion.cpp | 1028 ++++++++++++++++++++++++++++++++++++++++ name_conversion.h | 10 + ordered_map.hpp | 177 +++++++ pmid.hpp | 2 +- qwen_image.hpp | 2 +- qwenvl.hpp | 2 +- stable-diffusion.cpp | 10 +- t5.hpp | 2 +- tae.hpp | 2 +- upscaler.cpp | 2 +- wan.hpp | 4 +- 20 files changed, 1640 insertions(+), 1385 deletions(-) create mode 100644 name_conversion.cpp create mode 100644 name_conversion.h create mode 100644 ordered_map.hpp diff --git a/conditioner.hpp b/conditioner.hpp index b7d80595..93e0c281 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -111,7 +111,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { bool load_embedding(std::string embd_name, std::string embd_path, std::vector& bpe_tokens) { // the order matters ModelLoader model_loader; - if (!model_loader.init_from_file(embd_path)) { + if (!model_loader.init_from_file_and_convert_name(embd_path)) { LOG_ERROR("embedding '%s' failed", embd_name.c_str()); return false; } diff --git a/control.hpp b/control.hpp index 856bde81..b34140ef 100644 --- a/control.hpp +++ b/control.hpp @@ -442,7 +442,7 @@ struct ControlNet : public GGMLRunner { std::set ignore_tensors; ModelLoader model_loader; - if (!model_loader.init_from_file(file_path)) { + if (!model_loader.init_from_file_and_convert_name(file_path)) { LOG_ERROR("init control net model loader from file failed: '%s'", file_path.c_str()); return false; } diff --git a/esrgan.hpp b/esrgan.hpp index dd112439..adce6234 100644 --- a/esrgan.hpp +++ b/esrgan.hpp @@ -169,7 +169,7 @@ struct ESRGAN : public GGMLRunner { LOG_INFO("loading esrgan from '%s'", file_path.c_str()); ModelLoader model_loader; - if (!model_loader.init_from_file(file_path)) { + if (!model_loader.init_from_file_and_convert_name(file_path)) { LOG_ERROR("init esrgan model loader from file failed: '%s'", file_path.c_str()); return false; } diff --git a/flux.hpp b/flux.hpp index 95927f8b..8a255aa1 100644 --- a/flux.hpp +++ b/flux.hpp @@ -1398,7 +1398,7 @@ namespace Flux { ggml_type model_data_type = GGML_TYPE_Q8_0; ModelLoader model_loader; - if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) { + if (!model_loader.init_from_file_and_convert_name(file_path, "model.diffusion_model.")) { LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); return; } diff --git a/ggml_extend.hpp b/ggml_extend.hpp index a634bc80..eaf50165 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -1568,8 +1568,10 @@ struct GGMLRunner { struct ggml_cgraph* get_compute_graph(get_graph_cb_t get_graph) { prepare_build_in_tensor_before(); struct ggml_cgraph* gf = get_graph(); - auto result = ggml_graph_node(gf, -1); - ggml_set_name(result, final_result_name.c_str()); + if (ggml_graph_n_nodes(gf) > 0) { + auto result = ggml_graph_node(gf, -1); + ggml_set_name(result, final_result_name.c_str()); + } prepare_build_in_tensor_after(gf); return gf; } diff --git a/lora.hpp b/lora.hpp index c5683c3d..6da9d833 100644 --- a/lora.hpp +++ b/lora.hpp @@ -7,91 +7,6 @@ #define LORA_GRAPH_BASE_SIZE 10240 struct LoraModel : public GGMLRunner { - enum lora_t { - REGULAR = 0, - DIFFUSERS = 1, - DIFFUSERS_2 = 2, - DIFFUSERS_3 = 3, - TRANSFORMERS = 4, - LORA_TYPE_COUNT - }; - - const std::string lora_ups[LORA_TYPE_COUNT] = { - ".lora_up", - "_lora.up", - ".lora_B", - ".lora.up", - ".lora_linear_layer.up", - }; - - const std::string lora_downs[LORA_TYPE_COUNT] = { - ".lora_down", - "_lora.down", - ".lora_A", - ".lora.down", - ".lora_linear_layer.down", - }; - - const std::string lora_pre[LORA_TYPE_COUNT] = { - "lora.", - "", - "", - "", - "", - }; - - const std::map alt_names = { - // mmdit - {"final_layer.adaLN_modulation.1", "norm_out.linear"}, - {"pos_embed", "pos_embed.proj"}, - {"final_layer.linear", "proj_out"}, - {"y_embedder.mlp.0", "time_text_embed.text_embedder.linear_1"}, - {"y_embedder.mlp.2", "time_text_embed.text_embedder.linear_2"}, - {"t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1"}, - {"t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2"}, - {"x_block.mlp.fc1", "ff.net.0.proj"}, - {"x_block.mlp.fc2", "ff.net.2"}, - {"context_block.mlp.fc1", "ff_context.net.0.proj"}, - {"context_block.mlp.fc2", "ff_context.net.2"}, - {"x_block.adaLN_modulation.1", "norm1.linear"}, - {"context_block.adaLN_modulation.1", "norm1_context.linear"}, - {"context_block.attn.proj", "attn.to_add_out"}, - {"x_block.attn.proj", "attn.to_out.0"}, - {"x_block.attn2.proj", "attn2.to_out.0"}, - // flux - {"img_in", "x_embedder"}, - // singlestream - {"linear2", "proj_out"}, - {"modulation.lin", "norm.linear"}, - // doublestream - {"txt_attn.proj", "attn.to_add_out"}, - {"img_attn.proj", "attn.to_out.0"}, - {"txt_mlp.0", "ff_context.net.0.proj"}, - {"txt_mlp.2", "ff_context.net.2"}, - {"img_mlp.0", "ff.net.0.proj"}, - {"img_mlp.2", "ff.net.2"}, - {"txt_mod.lin", "norm1_context.linear"}, - {"img_mod.lin", "norm1.linear"}, - }; - - const std::map qkv_prefixes = { - // mmdit - {"context_block.attn.qkv", "attn.add_"}, // suffix "_proj" - {"x_block.attn.qkv", "attn.to_"}, - {"x_block.attn2.qkv", "attn2.to_"}, - // flux - // doublestream - {"txt_attn.qkv", "attn.add_"}, // suffix "_proj" - {"img_attn.qkv", "attn.to_"}, - }; - const std::map qkvm_prefixes = { - // flux - // singlestream - {"linear1", ""}, - }; - - const std::string* type_fingerprints = lora_ups; - float multiplier = 1.0f; std::map lora_tensors; std::map original_tensor_to_final_tensor; @@ -99,15 +14,17 @@ struct LoraModel : public GGMLRunner { ModelLoader model_loader; bool load_failed = false; bool applied = false; + bool tensor_preprocessed = false; std::vector zero_index_vec = {0}; ggml_tensor* zero_index = nullptr; - enum lora_t type = REGULAR; LoraModel(ggml_backend_t backend, const std::string& file_path = "", - const std::string prefix = "") + std::string prefix = "", + SDVersion version = VERSION_COUNT) : file_path(file_path), GGMLRunner(backend, false) { - if (!model_loader.init_from_file(file_path, prefix)) { + prefix = "lora." + prefix; + if (!model_loader.init_from_file_and_convert_name(file_path, prefix, version)) { load_failed = true; } } @@ -131,18 +48,12 @@ struct LoraModel : public GGMLRunner { if (dry_run) { const std::string& name = tensor_storage.name; - if (filter_tensor && !contains(name, "lora")) { + if (filter_tensor && !contains(name, "lora.model")) { return true; } { std::lock_guard lock(lora_mutex); - for (int i = 0; i < LORA_TYPE_COUNT; i++) { - if (name.find(type_fingerprints[i]) != std::string::npos) { - type = (lora_t)i; - break; - } - } tensors_to_create[name] = tensor_storage; } } else { @@ -172,8 +83,6 @@ struct LoraModel : public GGMLRunner { dry_run = false; model_loader.load_tensors(on_new_tensor_cb, n_threads); - LOG_DEBUG("lora type: \"%s\"/\"%s\"", lora_downs[type].c_str(), lora_ups[type].c_str()); - LOG_DEBUG("finished loaded lora"); return true; } @@ -186,669 +95,411 @@ struct LoraModel : public GGMLRunner { return out; } - std::vector to_lora_keys(std::string blk_name, SDVersion version) { - std::vector keys; - // if (!sd_version_is_sd3(version) || blk_name != "model.diffusion_model.pos_embed") { - size_t k_pos = blk_name.find(".weight"); - if (k_pos == std::string::npos) { - return keys; + void preprocess_lora_tensors(const std::map& model_tensors) { + if (tensor_preprocessed) { + return; } - blk_name = blk_name.substr(0, k_pos); - // } - keys.push_back(blk_name); - keys.push_back("lora." + blk_name); - if (sd_version_is_dit(version)) { - if (blk_name.find("model.diffusion_model") != std::string::npos) { - blk_name.replace(blk_name.find("model.diffusion_model"), sizeof("model.diffusion_model") - 1, "transformer"); + tensor_preprocessed = true; + // I really hate these hardcoded processes. + if (model_tensors.find("cond_stage_model.1.transformer.text_model.encoder.layers.0.self_attn.in_proj.weight") != model_tensors.end()) { + std::map new_lora_tensors; + for (auto& [old_name, tensor] : lora_tensors) { + std::string new_name = old_name; + + if (contains(new_name, "cond_stage_model.1.transformer.text_model.encoder.layers")) { + std::vector> qkv_name_map = { + {"self_attn.q_proj.weight", "self_attn.in_proj.weight"}, + {"self_attn.q_proj.bias", "self_attn.in_proj.bias"}, + {"self_attn.k_proj.weight", "self_attn.in_proj.weight.1"}, + {"self_attn.k_proj.bias", "self_attn.in_proj.bias.1"}, + {"self_attn.v_proj.weight", "self_attn.in_proj.weight.2"}, + {"self_attn.v_proj.bias", "self_attn.in_proj.bias.2"}, + }; + for (auto kv : qkv_name_map) { + size_t pos = new_name.find(kv.first); + if (pos != std::string::npos) { + new_name.replace(pos, kv.first.size(), kv.second); + } + } + } + + new_lora_tensors[new_name] = tensor; } - if (blk_name.find(".single_blocks") != std::string::npos) { - blk_name.replace(blk_name.find(".single_blocks"), sizeof(".single_blocks") - 1, ".single_transformer_blocks"); + lora_tensors = std::move(new_lora_tensors); + } + } + + ggml_tensor* get_lora_diff(const std::string& model_tensor_name, std::set& applied_lora_tensors) { + ggml_tensor* updown = nullptr; + int index = 0; + while (true) { + std::string key; + if (index == 0) { + key = model_tensor_name; + } else { + key = model_tensor_name + "." + std::to_string(index); } - if (blk_name.find(".double_blocks") != std::string::npos) { - blk_name.replace(blk_name.find(".double_blocks"), sizeof(".double_blocks") - 1, ".transformer_blocks"); + + std::string lora_down_name = "lora." + key + ".lora_down"; + std::string lora_up_name = "lora." + key + ".lora_up"; + std::string lora_mid_name = "lora." + key + ".lora_mid"; + std::string scale_name = "lora." + key + ".scale"; + std::string alpha_name = "lora." + key + ".alpha"; + + ggml_tensor* lora_up = nullptr; + ggml_tensor* lora_mid = nullptr; + ggml_tensor* lora_down = nullptr; + + auto iter = lora_tensors.find(lora_up_name); + if (iter != lora_tensors.end()) { + lora_up = to_f32(compute_ctx, iter->second); } - if (blk_name.find(".joint_blocks") != std::string::npos) { - blk_name.replace(blk_name.find(".joint_blocks"), sizeof(".joint_blocks") - 1, ".transformer_blocks"); + iter = lora_tensors.find(lora_mid_name); + if (iter != lora_tensors.end()) { + lora_mid = to_f32(compute_ctx, iter->second); } - if (blk_name.find("text_encoders.clip_l") != std::string::npos) { - blk_name.replace(blk_name.find("text_encoders.clip_l"), sizeof("text_encoders.clip_l") - 1, "cond_stage_model"); + iter = lora_tensors.find(lora_down_name); + if (iter != lora_tensors.end()) { + lora_down = to_f32(compute_ctx, iter->second); } - for (const auto& item : alt_names) { - size_t match = blk_name.find(item.first); - if (match != std::string::npos) { - blk_name = blk_name.substr(0, match) + item.second; - } + if (lora_up == nullptr || lora_down == nullptr) { + break; } - for (const auto& prefix : qkv_prefixes) { - size_t match = blk_name.find(prefix.first); - if (match != std::string::npos) { - std::string split_blk = "SPLIT|" + blk_name.substr(0, match) + prefix.second; - keys.push_back(split_blk); - } + + applied_lora_tensors.insert(lora_up_name); + applied_lora_tensors.insert(lora_down_name); + + if (lora_mid) { + applied_lora_tensors.insert(lora_mid_name); } - for (const auto& prefix : qkvm_prefixes) { - size_t match = blk_name.find(prefix.first); - if (match != std::string::npos) { - std::string split_blk = "SPLIT_L|" + blk_name.substr(0, match) + prefix.second; - keys.push_back(split_blk); + + float scale_value = 1.0f; + + int64_t rank = lora_down->ne[ggml_n_dims(lora_down) - 1]; + iter = lora_tensors.find(scale_name); + if (iter != lora_tensors.end()) { + scale_value = ggml_ext_backend_tensor_get_f32(iter->second); + applied_lora_tensors.insert(scale_name); + } else { + iter = lora_tensors.find(alpha_name); + if (iter != lora_tensors.end()) { + float alpha = ggml_ext_backend_tensor_get_f32(iter->second); + scale_value = alpha / rank; + // LOG_DEBUG("rank %s %ld %.2f %.2f", alpha_name.c_str(), rank, alpha, scale_value); + applied_lora_tensors.insert(alpha_name); } } - keys.push_back(blk_name); + scale_value *= multiplier; + + auto curr_updown = ggml_ext_merge_lora(compute_ctx, lora_down, lora_up, lora_mid); + curr_updown = ggml_scale_inplace(compute_ctx, curr_updown, scale_value); + + if (updown == nullptr) { + updown = curr_updown; + } else { + updown = ggml_concat(compute_ctx, updown, curr_updown, ggml_n_dims(updown) - 1); + } + + index++; } - std::vector ret; - for (std::string& key : keys) { - ret.push_back(key); - replace_all_chars(key, '.', '_'); - // fix for some sdxl lora, like lcm-lora-xl - if (key == "model_diffusion_model_output_blocks_2_2_conv") { - ret.push_back("model_diffusion_model_output_blocks_2_1_conv"); + // diff + if (updown == nullptr) { + std::string lora_diff_name = "lora." + model_tensor_name + ".diff"; + + if (lora_tensors.find(lora_diff_name) != lora_tensors.end()) { + updown = to_f32(compute_ctx, lora_tensors[lora_diff_name]); + applied_lora_tensors.insert(lora_diff_name); } - ret.push_back(key); } - return ret; + + return updown; } - struct ggml_cgraph* build_lora_graph(std::map model_tensors, SDVersion version) { - size_t lora_graph_size = LORA_GRAPH_BASE_SIZE + lora_tensors.size() * 10; - struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, lora_graph_size, false); + ggml_tensor* get_loha_diff(const std::string& model_tensor_name, std::set& applied_lora_tensors) { + ggml_tensor* updown = nullptr; + int index = 0; + while (true) { + std::string key; + if (index == 0) { + key = model_tensor_name; + } else { + key = model_tensor_name + "." + std::to_string(index); + } + std::string hada_1_down_name = "lora." + key + ".hada_w1_b"; + std::string hada_1_mid_name = "lora." + key + ".hada_t1"; + std::string hada_1_up_name = "lora." + key + ".hada_w1_a"; + std::string hada_2_down_name = "lora." + key + ".hada_w2_b"; + std::string hada_2_mid_name = "lora." + key + ".hada_t2"; + std::string hada_2_up_name = "lora." + key + ".hada_w2_a"; + std::string alpha_name = "lora." + key + ".alpha"; + + ggml_tensor* hada_1_mid = nullptr; // tau for tucker decomposition + ggml_tensor* hada_1_up = nullptr; + ggml_tensor* hada_1_down = nullptr; + + ggml_tensor* hada_2_mid = nullptr; // tau for tucker decomposition + ggml_tensor* hada_2_up = nullptr; + ggml_tensor* hada_2_down = nullptr; + + auto iter = lora_tensors.find(hada_1_down_name); + if (iter != lora_tensors.end()) { + hada_1_down = to_f32(compute_ctx, iter->second); + } - zero_index = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_I32, 1); - set_backend_tensor_data(zero_index, zero_index_vec.data()); - ggml_build_forward_expand(gf, zero_index); + iter = lora_tensors.find(hada_1_up_name); + if (iter != lora_tensors.end()) { + hada_1_up = to_f32(compute_ctx, iter->second); + } - original_tensor_to_final_tensor.clear(); + iter = lora_tensors.find(hada_1_mid_name); + if (iter != lora_tensors.end()) { + hada_1_mid = to_f32(compute_ctx, iter->second); + hada_1_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_1_up)); + } - std::set applied_lora_tensors; - for (auto it : model_tensors) { - std::string model_tensor_name = it.first; - struct ggml_tensor* model_tensor = model_tensors[it.first]; - - std::vector keys = to_lora_keys(model_tensor_name, version); - bool is_bias = ends_with(model_tensor_name, ".bias"); - if (keys.size() == 0) { - if (is_bias) { - keys.push_back(model_tensor_name.substr(0, model_tensor_name.size() - 5)); // remove .bias - } else { - continue; - } + iter = lora_tensors.find(hada_2_down_name); + if (iter != lora_tensors.end()) { + hada_2_down = to_f32(compute_ctx, iter->second); } - for (auto& key : keys) { - bool is_qkv_split = starts_with(key, "SPLIT|"); - if (is_qkv_split) { - key = key.substr(sizeof("SPLIT|") - 1); - } - bool is_qkvm_split = starts_with(key, "SPLIT_L|"); - if (is_qkvm_split) { - key = key.substr(sizeof("SPLIT_L|") - 1); + iter = lora_tensors.find(hada_2_up_name); + if (iter != lora_tensors.end()) { + hada_2_up = to_f32(compute_ctx, iter->second); + } + + iter = lora_tensors.find(hada_2_mid_name); + if (iter != lora_tensors.end()) { + hada_2_mid = to_f32(compute_ctx, iter->second); + hada_2_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_2_up)); + } + + if (hada_1_up == nullptr || hada_1_down == nullptr || hada_2_up == nullptr || hada_2_down == nullptr) { + break; + } + + applied_lora_tensors.insert(hada_1_down_name); + applied_lora_tensors.insert(hada_1_up_name); + applied_lora_tensors.insert(hada_2_down_name); + applied_lora_tensors.insert(hada_2_up_name); + applied_lora_tensors.insert(alpha_name); + + if (hada_1_mid) { + applied_lora_tensors.insert(hada_1_mid_name); + } + + if (hada_2_mid) { + applied_lora_tensors.insert(hada_2_mid_name); + } + + float scale_value = 1.0f; + + // calc_scale + // TODO: .dora_scale? + int64_t rank = hada_1_down->ne[ggml_n_dims(hada_1_down) - 1]; + iter = lora_tensors.find(alpha_name); + if (iter != lora_tensors.end()) { + float alpha = ggml_ext_backend_tensor_get_f32(iter->second); + scale_value = alpha / rank; + applied_lora_tensors.insert(alpha_name); + } + scale_value *= multiplier; + + struct ggml_tensor* updown_1 = ggml_ext_merge_lora(compute_ctx, hada_1_down, hada_1_up, hada_1_mid); + struct ggml_tensor* updown_2 = ggml_ext_merge_lora(compute_ctx, hada_2_down, hada_2_up, hada_2_mid); + auto curr_updown = ggml_mul_inplace(compute_ctx, updown_1, updown_2); + curr_updown = ggml_scale_inplace(compute_ctx, curr_updown, scale_value); + if (updown == nullptr) { + updown = curr_updown; + } else { + updown = ggml_concat(compute_ctx, updown, curr_updown, ggml_n_dims(updown) - 1); + } + index++; + } + return updown; + } + + ggml_tensor* get_lokr_diff(const std::string& model_tensor_name, std::set& applied_lora_tensors) { + ggml_tensor* updown = nullptr; + int index = 0; + while (true) { + std::string key; + if (index == 0) { + key = model_tensor_name; + } else { + key = model_tensor_name + "." + std::to_string(index); + } + std::string lokr_w1_name = "lora." + key + ".lokr_w1"; + std::string lokr_w1_a_name = "lora." + key + ".lokr_w1_a"; + std::string lokr_w1_b_name = "lora." + key + ".lokr_w1_b"; + std::string lokr_w2_name = "lora." + key + ".lokr_w2"; + std::string lokr_w2_a_name = "lora." + key + ".lokr_w2_a"; + std::string lokr_w2_b_name = "lora." + key + ".lokr_w2_b"; + std::string alpha_name = "lora." + key + ".alpha"; + + ggml_tensor* lokr_w1 = nullptr; + ggml_tensor* lokr_w1_a = nullptr; + ggml_tensor* lokr_w1_b = nullptr; + ggml_tensor* lokr_w2 = nullptr; + ggml_tensor* lokr_w2_a = nullptr; + ggml_tensor* lokr_w2_b = nullptr; + + auto iter = lora_tensors.find(lokr_w1_name); + if (iter != lora_tensors.end()) { + lokr_w1 = to_f32(compute_ctx, iter->second); + } + + iter = lora_tensors.find(lokr_w2_name); + if (iter != lora_tensors.end()) { + lokr_w2 = to_f32(compute_ctx, iter->second); + } + + int64_t rank = 1; + if (lokr_w1 == nullptr) { + iter = lora_tensors.find(lokr_w1_a_name); + if (iter != lora_tensors.end()) { + lokr_w1_a = to_f32(compute_ctx, iter->second); } - struct ggml_tensor* updown = nullptr; - float scale_value = 1.0f; - std::string full_key = lora_pre[type] + key; - if (is_bias) { - if (lora_tensors.find(full_key + ".diff_b") != lora_tensors.end()) { - std::string diff_name = full_key + ".diff_b"; - ggml_tensor* diff = lora_tensors[diff_name]; - updown = to_f32(compute_ctx, diff); - applied_lora_tensors.insert(diff_name); - } else { - continue; - } - } else if (lora_tensors.find(full_key + ".diff") != lora_tensors.end()) { - std::string diff_name = full_key + ".diff"; - ggml_tensor* diff = lora_tensors[diff_name]; - updown = to_f32(compute_ctx, diff); - applied_lora_tensors.insert(diff_name); - } else if (lora_tensors.find(full_key + ".hada_w1_a") != lora_tensors.end()) { - // LoHa mode - - // TODO: split qkv convention for LoHas (is it ever used?) - if (is_qkv_split || is_qkvm_split) { - LOG_ERROR("Split qkv isn't supported for LoHa models."); - break; - } - std::string alpha_name = ""; - ggml_tensor* hada_1_mid = nullptr; // tau for tucker decomposition - ggml_tensor* hada_1_up = nullptr; - ggml_tensor* hada_1_down = nullptr; + iter = lora_tensors.find(lokr_w1_b_name); + if (iter != lora_tensors.end()) { + lokr_w1_b = to_f32(compute_ctx, iter->second); + } - ggml_tensor* hada_2_mid = nullptr; // tau for tucker decomposition - ggml_tensor* hada_2_up = nullptr; - ggml_tensor* hada_2_down = nullptr; + if (lokr_w1_a == nullptr || lokr_w1_b == nullptr) { + break; + } - std::string hada_1_mid_name = ""; - std::string hada_1_down_name = ""; - std::string hada_1_up_name = ""; + rank = lokr_w1_b->ne[ggml_n_dims(lokr_w1_b) - 1]; - std::string hada_2_mid_name = ""; - std::string hada_2_down_name = ""; - std::string hada_2_up_name = ""; + lokr_w1 = ggml_ext_merge_lora(compute_ctx, lokr_w1_b, lokr_w1_a); + } - hada_1_down_name = full_key + ".hada_w1_b"; - hada_1_up_name = full_key + ".hada_w1_a"; - hada_1_mid_name = full_key + ".hada_t1"; - if (lora_tensors.find(hada_1_down_name) != lora_tensors.end()) { - hada_1_down = to_f32(compute_ctx, lora_tensors[hada_1_down_name]); - } - if (lora_tensors.find(hada_1_up_name) != lora_tensors.end()) { - hada_1_up = to_f32(compute_ctx, lora_tensors[hada_1_up_name]); - } - if (lora_tensors.find(hada_1_mid_name) != lora_tensors.end()) { - hada_1_mid = to_f32(compute_ctx, lora_tensors[hada_1_mid_name]); - applied_lora_tensors.insert(hada_1_mid_name); - hada_1_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_1_up)); - } + if (lokr_w2 == nullptr) { + iter = lora_tensors.find(lokr_w2_a_name); + if (iter != lora_tensors.end()) { + lokr_w2_a = to_f32(compute_ctx, iter->second); + } - hada_2_down_name = full_key + ".hada_w2_b"; - hada_2_up_name = full_key + ".hada_w2_a"; - hada_2_mid_name = full_key + ".hada_t2"; - if (lora_tensors.find(hada_2_down_name) != lora_tensors.end()) { - hada_2_down = to_f32(compute_ctx, lora_tensors[hada_2_down_name]); - } - if (lora_tensors.find(hada_2_up_name) != lora_tensors.end()) { - hada_2_up = to_f32(compute_ctx, lora_tensors[hada_2_up_name]); - } - if (lora_tensors.find(hada_2_mid_name) != lora_tensors.end()) { - hada_2_mid = to_f32(compute_ctx, lora_tensors[hada_2_mid_name]); - applied_lora_tensors.insert(hada_2_mid_name); - hada_2_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_2_up)); - } + iter = lora_tensors.find(lokr_w2_b_name); + if (iter != lora_tensors.end()) { + lokr_w2_b = to_f32(compute_ctx, iter->second); + } - alpha_name = full_key + ".alpha"; + if (lokr_w2_a == nullptr || lokr_w2_b == nullptr) { + break; + } - applied_lora_tensors.insert(hada_1_down_name); - applied_lora_tensors.insert(hada_1_up_name); - applied_lora_tensors.insert(hada_2_down_name); - applied_lora_tensors.insert(hada_2_up_name); + rank = lokr_w2_b->ne[ggml_n_dims(lokr_w2_b) - 1]; - applied_lora_tensors.insert(alpha_name); - if (hada_1_up == nullptr || hada_1_down == nullptr || hada_2_up == nullptr || hada_2_down == nullptr) { - continue; - } + lokr_w2 = ggml_ext_merge_lora(compute_ctx, lokr_w2_b, lokr_w2_a); + } - struct ggml_tensor* updown_1 = ggml_ext_merge_lora(compute_ctx, hada_1_down, hada_1_up, hada_1_mid); - struct ggml_tensor* updown_2 = ggml_ext_merge_lora(compute_ctx, hada_2_down, hada_2_up, hada_2_mid); - updown = ggml_mul_inplace(compute_ctx, updown_1, updown_2); + if (!lokr_w1_a) { + applied_lora_tensors.insert(lokr_w1_name); + } else { + applied_lora_tensors.insert(lokr_w1_a_name); + applied_lora_tensors.insert(lokr_w1_b_name); + } - // calc_scale - // TODO: .dora_scale? - int64_t rank = hada_1_down->ne[ggml_n_dims(hada_1_down) - 1]; - if (lora_tensors.find(alpha_name) != lora_tensors.end()) { - float alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[alpha_name]); - scale_value = alpha / rank; - } - } else if (lora_tensors.find(full_key + ".lokr_w1") != lora_tensors.end() || lora_tensors.find(full_key + ".lokr_w1_a") != lora_tensors.end()) { - // LoKr mode + if (!lokr_w2_a) { + applied_lora_tensors.insert(lokr_w2_name); + } else { + applied_lora_tensors.insert(lokr_w2_a_name); + applied_lora_tensors.insert(lokr_w2_b_name); + } - // TODO: split qkv convention for LoKrs (is it ever used?) - if (is_qkv_split || is_qkvm_split) { - LOG_ERROR("Split qkv isn't supported for LoKr models."); - break; - } + float scale_value = 1.0f; + iter = lora_tensors.find(alpha_name); + if (iter != lora_tensors.end()) { + float alpha = ggml_ext_backend_tensor_get_f32(iter->second); + scale_value = alpha / rank; + applied_lora_tensors.insert(alpha_name); + } - std::string alpha_name = full_key + ".alpha"; - - ggml_tensor* lokr_w1 = nullptr; - ggml_tensor* lokr_w2 = nullptr; - - std::string lokr_w1_name = ""; - std::string lokr_w2_name = ""; - - lokr_w1_name = full_key + ".lokr_w1"; - lokr_w2_name = full_key + ".lokr_w2"; - - if (lora_tensors.find(lokr_w1_name) != lora_tensors.end()) { - lokr_w1 = to_f32(compute_ctx, lora_tensors[lokr_w1_name]); - applied_lora_tensors.insert(lokr_w1_name); - } else { - ggml_tensor* down = nullptr; - ggml_tensor* up = nullptr; - std::string down_name = lokr_w1_name + "_b"; - std::string up_name = lokr_w1_name + "_a"; - if (lora_tensors.find(down_name) != lora_tensors.end()) { - // w1 should not be low rank normally, sometimes w1 and w2 are swapped - down = to_f32(compute_ctx, lora_tensors[down_name]); - applied_lora_tensors.insert(down_name); - - int64_t rank = down->ne[ggml_n_dims(down) - 1]; - if (lora_tensors.find(alpha_name) != lora_tensors.end()) { - float alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[alpha_name]); - scale_value = alpha / rank; - } - } - if (lora_tensors.find(up_name) != lora_tensors.end()) { - up = to_f32(compute_ctx, lora_tensors[up_name]); - applied_lora_tensors.insert(up_name); - } - lokr_w1 = ggml_ext_merge_lora(compute_ctx, down, up); - } - if (lora_tensors.find(lokr_w2_name) != lora_tensors.end()) { - lokr_w2 = to_f32(compute_ctx, lora_tensors[lokr_w2_name]); - applied_lora_tensors.insert(lokr_w2_name); - } else { - ggml_tensor* down = nullptr; - ggml_tensor* up = nullptr; - std::string down_name = lokr_w2_name + "_b"; - std::string up_name = lokr_w2_name + "_a"; - if (lora_tensors.find(down_name) != lora_tensors.end()) { - down = to_f32(compute_ctx, lora_tensors[down_name]); - applied_lora_tensors.insert(down_name); - - int64_t rank = down->ne[ggml_n_dims(down) - 1]; - if (lora_tensors.find(alpha_name) != lora_tensors.end()) { - float alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[alpha_name]); - scale_value = alpha / rank; - } - } - if (lora_tensors.find(up_name) != lora_tensors.end()) { - up = to_f32(compute_ctx, lora_tensors[up_name]); - applied_lora_tensors.insert(up_name); - } - lokr_w2 = ggml_ext_merge_lora(compute_ctx, down, up); - } + if (rank == 1) { + scale_value = 1.0f; + } - // Technically it might be unused, but I believe it's the expected behavior - applied_lora_tensors.insert(alpha_name); + scale_value *= multiplier; - updown = ggml_ext_kronecker(compute_ctx, lokr_w1, lokr_w2); + auto curr_updown = ggml_ext_kronecker(compute_ctx, lokr_w1, lokr_w2); + curr_updown = ggml_scale_inplace(compute_ctx, curr_updown, scale_value); - } else { - // LoRA mode - ggml_tensor* lora_mid = nullptr; // tau for tucker decomposition - ggml_tensor* lora_up = nullptr; - ggml_tensor* lora_down = nullptr; + if (updown == nullptr) { + updown = curr_updown; + } else { + updown = ggml_concat(compute_ctx, updown, curr_updown, ggml_n_dims(updown) - 1); + } + index++; + } + return updown; + } - std::string alpha_name = ""; - std::string scale_name = ""; - std::string split_q_scale_name = ""; - std::string lora_mid_name = ""; - std::string lora_down_name = ""; - std::string lora_up_name = ""; + struct ggml_cgraph* build_lora_graph(const std::map& model_tensors, SDVersion version) { + size_t lora_graph_size = LORA_GRAPH_BASE_SIZE + lora_tensors.size() * 10; + struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, lora_graph_size, false); - if (is_qkv_split) { - std::string suffix = ""; - auto split_q_d_name = full_key + "q" + suffix + lora_downs[type] + ".weight"; + zero_index = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_I32, 1); + set_backend_tensor_data(zero_index, zero_index_vec.data()); + ggml_build_forward_expand(gf, zero_index); - if (lora_tensors.find(split_q_d_name) == lora_tensors.end()) { - suffix = "_proj"; - split_q_d_name = full_key + "q" + suffix + lora_downs[type] + ".weight"; - } - if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) { - // print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1] - // find qkv and mlp up parts in LoRA model - auto split_k_d_name = full_key + "k" + suffix + lora_downs[type] + ".weight"; - auto split_v_d_name = full_key + "v" + suffix + lora_downs[type] + ".weight"; - - auto split_q_u_name = full_key + "q" + suffix + lora_ups[type] + ".weight"; - auto split_k_u_name = full_key + "k" + suffix + lora_ups[type] + ".weight"; - auto split_v_u_name = full_key + "v" + suffix + lora_ups[type] + ".weight"; - - auto split_q_scale_name = full_key + "q" + suffix + ".scale"; - auto split_k_scale_name = full_key + "k" + suffix + ".scale"; - auto split_v_scale_name = full_key + "v" + suffix + ".scale"; - - auto split_q_alpha_name = full_key + "q" + suffix + ".alpha"; - auto split_k_alpha_name = full_key + "k" + suffix + ".alpha"; - auto split_v_alpha_name = full_key + "v" + suffix + ".alpha"; - - ggml_tensor* lora_q_down = nullptr; - ggml_tensor* lora_q_up = nullptr; - ggml_tensor* lora_k_down = nullptr; - ggml_tensor* lora_k_up = nullptr; - ggml_tensor* lora_v_down = nullptr; - ggml_tensor* lora_v_up = nullptr; - - lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]); - - if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) { - lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]); - } - - if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) { - lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]); - } - - if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) { - lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]); - } - - if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) { - lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]); - } - - if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) { - lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]); - } - - float q_rank = lora_q_up->ne[0]; - float k_rank = lora_k_up->ne[0]; - float v_rank = lora_v_up->ne[0]; - - float lora_q_scale = 1; - float lora_k_scale = 1; - float lora_v_scale = 1; - - if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) { - lora_q_scale = ggml_ext_backend_tensor_get_f32(lora_tensors[split_q_scale_name]); - applied_lora_tensors.insert(split_q_scale_name); - } - if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) { - lora_k_scale = ggml_ext_backend_tensor_get_f32(lora_tensors[split_k_scale_name]); - applied_lora_tensors.insert(split_k_scale_name); - } - if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) { - lora_v_scale = ggml_ext_backend_tensor_get_f32(lora_tensors[split_v_scale_name]); - applied_lora_tensors.insert(split_v_scale_name); - } - - if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) { - float lora_q_alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]); - applied_lora_tensors.insert(split_q_alpha_name); - lora_q_scale = lora_q_alpha / q_rank; - } - if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) { - float lora_k_alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]); - applied_lora_tensors.insert(split_k_alpha_name); - lora_k_scale = lora_k_alpha / k_rank; - } - if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) { - float lora_v_alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]); - applied_lora_tensors.insert(split_v_alpha_name); - lora_v_scale = lora_v_alpha / v_rank; - } - - ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale); - ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale); - ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale); - - // print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1] - // print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1] - // print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1] - // print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1] - // print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1] - // print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1] - - // these need to be stitched together this way: - // |q_up,0 ,0 | - // |0 ,k_up,0 | - // |0 ,0 ,v_up| - // (q_down,k_down,v_down) . (q ,k ,v) - - // up_concat will be [9216, R*3, 1, 1] - // down_concat will be [R*3, 3072, 1, 1] - ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), lora_v_down, 1); - - ggml_tensor* z = ggml_dup_tensor(compute_ctx, lora_q_up); - ggml_scale(compute_ctx, z, 0); - ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1); - - ggml_tensor* q_up = ggml_concat(compute_ctx, lora_q_up, zz, 1); - ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), z, 1); - ggml_tensor* v_up = ggml_concat(compute_ctx, zz, lora_v_up, 1); - // print_ggml_tensor(q_up, true); //[R, 9216, 1, 1] - // print_ggml_tensor(k_up, true); //[R, 9216, 1, 1] - // print_ggml_tensor(v_up, true); //[R, 9216, 1, 1] - ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), v_up, 0); - // print_ggml_tensor(lora_up_concat, true); //[R*3, 9216, 1, 1] - - lora_down = ggml_cont(compute_ctx, lora_down_concat); - lora_up = ggml_cont(compute_ctx, lora_up_concat); - - applied_lora_tensors.insert(split_q_u_name); - applied_lora_tensors.insert(split_k_u_name); - applied_lora_tensors.insert(split_v_u_name); - - applied_lora_tensors.insert(split_q_d_name); - applied_lora_tensors.insert(split_k_d_name); - applied_lora_tensors.insert(split_v_d_name); - } - } else if (is_qkvm_split) { - auto split_q_d_name = full_key + "attn.to_q" + lora_downs[type] + ".weight"; - if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) { - // print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1] - // find qkv and mlp up parts in LoRA model - auto split_k_d_name = full_key + "attn.to_k" + lora_downs[type] + ".weight"; - auto split_v_d_name = full_key + "attn.to_v" + lora_downs[type] + ".weight"; - - auto split_q_u_name = full_key + "attn.to_q" + lora_ups[type] + ".weight"; - auto split_k_u_name = full_key + "attn.to_k" + lora_ups[type] + ".weight"; - auto split_v_u_name = full_key + "attn.to_v" + lora_ups[type] + ".weight"; - - auto split_m_d_name = full_key + "proj_mlp" + lora_downs[type] + ".weight"; - auto split_m_u_name = full_key + "proj_mlp" + lora_ups[type] + ".weight"; - - auto split_q_scale_name = full_key + "attn.to_q" + ".scale"; - auto split_k_scale_name = full_key + "attn.to_k" + ".scale"; - auto split_v_scale_name = full_key + "attn.to_v" + ".scale"; - auto split_m_scale_name = full_key + "proj_mlp" + ".scale"; - - auto split_q_alpha_name = full_key + "attn.to_q" + ".alpha"; - auto split_k_alpha_name = full_key + "attn.to_k" + ".alpha"; - auto split_v_alpha_name = full_key + "attn.to_v" + ".alpha"; - auto split_m_alpha_name = full_key + "proj_mlp" + ".alpha"; - - ggml_tensor* lora_q_down = nullptr; - ggml_tensor* lora_q_up = nullptr; - ggml_tensor* lora_k_down = nullptr; - ggml_tensor* lora_k_up = nullptr; - ggml_tensor* lora_v_down = nullptr; - ggml_tensor* lora_v_up = nullptr; - - ggml_tensor* lora_m_down = nullptr; - ggml_tensor* lora_m_up = nullptr; - - lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]); - - if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) { - lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]); - } - - if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) { - lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]); - } - - if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) { - lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]); - } - - if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) { - lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]); - } - - if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) { - lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]); - } - - if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) { - lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]); - } - - if (lora_tensors.find(split_m_d_name) != lora_tensors.end()) { - lora_m_down = to_f32(compute_ctx, lora_tensors[split_m_d_name]); - } - - if (lora_tensors.find(split_m_u_name) != lora_tensors.end()) { - lora_m_up = to_f32(compute_ctx, lora_tensors[split_m_u_name]); - } - - float q_rank = lora_q_up->ne[0]; - float k_rank = lora_k_up->ne[0]; - float v_rank = lora_v_up->ne[0]; - float m_rank = lora_v_up->ne[0]; - - float lora_q_scale = 1; - float lora_k_scale = 1; - float lora_v_scale = 1; - float lora_m_scale = 1; - - if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) { - lora_q_scale = ggml_ext_backend_tensor_get_f32(lora_tensors[split_q_scale_name]); - applied_lora_tensors.insert(split_q_scale_name); - } - if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) { - lora_k_scale = ggml_ext_backend_tensor_get_f32(lora_tensors[split_k_scale_name]); - applied_lora_tensors.insert(split_k_scale_name); - } - if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) { - lora_v_scale = ggml_ext_backend_tensor_get_f32(lora_tensors[split_v_scale_name]); - applied_lora_tensors.insert(split_v_scale_name); - } - if (lora_tensors.find(split_m_scale_name) != lora_tensors.end()) { - lora_m_scale = ggml_ext_backend_tensor_get_f32(lora_tensors[split_m_scale_name]); - applied_lora_tensors.insert(split_m_scale_name); - } - - if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) { - float lora_q_alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]); - applied_lora_tensors.insert(split_q_alpha_name); - lora_q_scale = lora_q_alpha / q_rank; - } - if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) { - float lora_k_alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]); - applied_lora_tensors.insert(split_k_alpha_name); - lora_k_scale = lora_k_alpha / k_rank; - } - if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) { - float lora_v_alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]); - applied_lora_tensors.insert(split_v_alpha_name); - lora_v_scale = lora_v_alpha / v_rank; - } - if (lora_tensors.find(split_m_alpha_name) != lora_tensors.end()) { - float lora_m_alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[split_m_alpha_name]); - applied_lora_tensors.insert(split_m_alpha_name); - lora_m_scale = lora_m_alpha / m_rank; - } - - ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale); - ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale); - ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale); - ggml_scale_inplace(compute_ctx, lora_m_down, lora_m_scale); - - // print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1] - // print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1] - // print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1] - // print_ggml_tensor(lora_m_down, true); //[3072, R, 1, 1] - // print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1] - // print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1] - // print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1] - // print_ggml_tensor(lora_m_up, true); //[R, 12288, 1, 1] - - // these need to be stitched together this way: - // |q_up,0 ,0 ,0 | - // |0 ,k_up,0 ,0 | - // |0 ,0 ,v_up,0 | - // |0 ,0 ,0 ,m_up| - // (q_down,k_down,v_down,m_down) . (q ,k ,v ,m) - - // up_concat will be [21504, R*4, 1, 1] - // down_concat will be [R*4, 3072, 1, 1] - - ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), ggml_concat(compute_ctx, lora_v_down, lora_m_down, 1), 1); - // print_ggml_tensor(lora_down_concat, true); //[3072, R*4, 1, 1] - - // this also means that if rank is bigger than 672, it is less memory efficient to do it this way (should be fine) - // print_ggml_tensor(lora_q_up, true); //[3072, R, 1, 1] - ggml_tensor* z = ggml_dup_tensor(compute_ctx, lora_q_up); - ggml_tensor* mlp_z = ggml_dup_tensor(compute_ctx, lora_m_up); - ggml_scale(compute_ctx, z, 0); - ggml_scale(compute_ctx, mlp_z, 0); - ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1); - - ggml_tensor* q_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_up, zz, 1), mlp_z, 1); - ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), ggml_concat(compute_ctx, z, mlp_z, 1), 1); - ggml_tensor* v_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, lora_v_up, 1), mlp_z, 1); - ggml_tensor* m_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, z, 1), lora_m_up, 1); - // print_ggml_tensor(q_up, true); //[R, 21504, 1, 1] - // print_ggml_tensor(k_up, true); //[R, 21504, 1, 1] - // print_ggml_tensor(v_up, true); //[R, 21504, 1, 1] - // print_ggml_tensor(m_up, true); //[R, 21504, 1, 1] - - ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), ggml_concat(compute_ctx, v_up, m_up, 0), 0); - // print_ggml_tensor(lora_up_concat, true); //[R*4, 21504, 1, 1] - - lora_down = ggml_cont(compute_ctx, lora_down_concat); - lora_up = ggml_cont(compute_ctx, lora_up_concat); - - applied_lora_tensors.insert(split_q_u_name); - applied_lora_tensors.insert(split_k_u_name); - applied_lora_tensors.insert(split_v_u_name); - applied_lora_tensors.insert(split_m_u_name); - - applied_lora_tensors.insert(split_q_d_name); - applied_lora_tensors.insert(split_k_d_name); - applied_lora_tensors.insert(split_v_d_name); - applied_lora_tensors.insert(split_m_d_name); - } - } else { - lora_up_name = full_key + lora_ups[type] + ".weight"; - lora_down_name = full_key + lora_downs[type] + ".weight"; - lora_mid_name = full_key + ".lora_mid.weight"; + preprocess_lora_tensors(model_tensors); - alpha_name = full_key + ".alpha"; - scale_name = full_key + ".scale"; + original_tensor_to_final_tensor.clear(); - if (lora_tensors.find(lora_up_name) != lora_tensors.end()) { - lora_up = to_f32(compute_ctx, lora_tensors[lora_up_name]); - applied_lora_tensors.insert(lora_up_name); - } + std::set applied_lora_tensors; + for (auto it : model_tensors) { + std::string model_tensor_name = it.first; + ggml_tensor* model_tensor = it.second; + + // lora + ggml_tensor* updown = get_lora_diff(model_tensor_name, applied_lora_tensors); + // loha + if (updown == nullptr) { + updown = get_loha_diff(model_tensor_name, applied_lora_tensors); + } - if (lora_tensors.find(lora_down_name) != lora_tensors.end()) { - lora_down = to_f32(compute_ctx, lora_tensors[lora_down_name]); - applied_lora_tensors.insert(lora_down_name); - } + // lokr + if (updown == nullptr) { + updown = get_lokr_diff(model_tensor_name, applied_lora_tensors); + } - if (lora_tensors.find(lora_mid_name) != lora_tensors.end()) { - lora_mid = to_f32(compute_ctx, lora_tensors[lora_mid_name]); - applied_lora_tensors.insert(lora_mid_name); - } - } + if (updown == nullptr) { + continue; + } - if (lora_up == nullptr || lora_down == nullptr) { - continue; - } - // calc_scale - // TODO: .dora_scale? - int64_t rank = lora_down->ne[ggml_n_dims(lora_down) - 1]; - if (lora_tensors.find(scale_name) != lora_tensors.end()) { - scale_value = ggml_ext_backend_tensor_get_f32(lora_tensors[scale_name]); - applied_lora_tensors.insert(scale_name); - } else if (lora_tensors.find(alpha_name) != lora_tensors.end()) { - float alpha = ggml_ext_backend_tensor_get_f32(lora_tensors[alpha_name]); - scale_value = alpha / rank; - // LOG_DEBUG("rank %s %ld %.2f %.2f", alpha_name.c_str(), rank, alpha, scale_value); - applied_lora_tensors.insert(alpha_name); - } + ggml_tensor* original_tensor = model_tensor; + if (!ggml_backend_is_cpu(runtime_backend) && ggml_backend_buffer_is_host(original_tensor->buffer)) { + model_tensor = ggml_dup_tensor(compute_ctx, model_tensor); + set_backend_tensor_data(model_tensor, original_tensor->data); + } - updown = ggml_ext_merge_lora(compute_ctx, lora_down, lora_up, lora_mid); + if (ggml_nelements(updown) < ggml_nelements(model_tensor)) { + if (ggml_n_dims(updown) == 2 && ggml_n_dims(model_tensor) == 2 && updown->ne[0] == model_tensor->ne[0]) { + LOG_WARN("pad for %s", model_tensor_name.c_str()); + auto pad_tensor = ggml_ext_zeros(compute_ctx, updown->ne[0], model_tensor->ne[1] - updown->ne[1], 1, 1); + updown = ggml_concat(compute_ctx, updown, pad_tensor, 1); } - scale_value *= multiplier; - ggml_tensor* original_tensor = model_tensor; - if (!ggml_backend_is_cpu(runtime_backend) && ggml_backend_buffer_is_host(original_tensor->buffer)) { - model_tensor = ggml_dup_tensor(compute_ctx, model_tensor); - set_backend_tensor_data(model_tensor, original_tensor->data); - } - updown = ggml_reshape(compute_ctx, updown, model_tensor); - GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(model_tensor)); - updown = ggml_scale_inplace(compute_ctx, updown, scale_value); - ggml_tensor* final_tensor; - if (model_tensor->type != GGML_TYPE_F32 && model_tensor->type != GGML_TYPE_F16) { - final_tensor = to_f32(compute_ctx, model_tensor); - final_tensor = ggml_add_inplace(compute_ctx, final_tensor, updown); - final_tensor = ggml_cpy(compute_ctx, final_tensor, model_tensor); - } else { - final_tensor = ggml_add_inplace(compute_ctx, model_tensor, updown); - } - ggml_build_forward_expand(gf, final_tensor); - if (!ggml_backend_is_cpu(runtime_backend) && ggml_backend_buffer_is_host(original_tensor->buffer)) { - original_tensor_to_final_tensor[original_tensor] = final_tensor; - } - break; + } + + GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(model_tensor)); + updown = ggml_reshape(compute_ctx, updown, model_tensor); + ggml_tensor* final_tensor; + if (model_tensor->type != GGML_TYPE_F32 && model_tensor->type != GGML_TYPE_F16) { + final_tensor = to_f32(compute_ctx, model_tensor); + final_tensor = ggml_add_inplace(compute_ctx, final_tensor, updown); + final_tensor = ggml_cpy(compute_ctx, final_tensor, model_tensor); + } else { + final_tensor = ggml_add_inplace(compute_ctx, model_tensor, updown); + } + ggml_build_forward_expand(gf, final_tensor); + if (!ggml_backend_is_cpu(runtime_backend) && ggml_backend_buffer_is_host(original_tensor->buffer)) { + original_tensor_to_final_tensor[original_tensor] = final_tensor; } } size_t total_lora_tensors_count = 0; diff --git a/mmdit.hpp b/mmdit.hpp index 7249a13e..3ca01d95 100644 --- a/mmdit.hpp +++ b/mmdit.hpp @@ -961,7 +961,7 @@ struct MMDiTRunner : public GGMLRunner { mmdit->get_param_tensors(tensors, "model.diffusion_model"); ModelLoader model_loader; - if (!model_loader.init_from_file(file_path)) { + if (!model_loader.init_from_file_and_convert_name(file_path)) { LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); return; } diff --git a/model.cpp b/model.cpp index 79fde749..519284e6 100644 --- a/model.cpp +++ b/model.cpp @@ -25,6 +25,7 @@ #include "ggml-cpu.h" #include "ggml.h" +#include "name_conversion.h" #include "stable-diffusion.h" #ifdef SD_USE_METAL @@ -75,15 +76,6 @@ uint16_t read_short(uint8_t* buffer) { /*================================================= Preprocess ==================================================*/ -std::string self_attn_names[] = { - "self_attn.q_proj.weight", - "self_attn.k_proj.weight", - "self_attn.v_proj.weight", - "self_attn.q_proj.bias", - "self_attn.k_proj.bias", - "self_attn.v_proj.bias", -}; - const char* unused_tensors[] = { "betas", "alphas_cumprod_prev", @@ -97,9 +89,9 @@ const char* unused_tensors[] = { "posterior_mean_coef1", "posterior_mean_coef2", "cond_stage_model.transformer.text_model.embeddings.position_ids", + "cond_stage_model.1.model.text_model.embeddings.position_ids", "cond_stage_model.transformer.vision_model.embeddings.position_ids", "cond_stage_model.model.logit_scale", - "cond_stage_model.model.text_projection", "conditioner.embedders.0.transformer.text_model.embeddings.position_ids", "conditioner.embedders.0.model.logit_scale", "conditioner.embedders.1.model.logit_scale", @@ -110,6 +102,7 @@ const char* unused_tensors[] = { "model_ema.diffusion_model", "embedding_manager", "denoiser.sigmas", + "edm_vpred.sigma_max", "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training "text_encoders.qwen2vl.output.weight", "text_encoders.qwen2vl.lm_head.", @@ -124,622 +117,6 @@ bool is_unused_tensor(std::string name) { return false; } -std::unordered_map open_clip_to_hf_clip_model = { - {"model.ln_final.bias", "transformer.text_model.final_layer_norm.bias"}, - {"model.ln_final.weight", "transformer.text_model.final_layer_norm.weight"}, - {"model.positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"}, - {"model.token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"}, - {"model.text_projection", "transformer.text_model.text_projection"}, - {"model.visual.class_embedding", "transformer.vision_model.embeddings.class_embedding"}, - {"model.visual.conv1.weight", "transformer.vision_model.embeddings.patch_embedding.weight"}, - {"model.visual.ln_post.bias", "transformer.vision_model.post_layernorm.bias"}, - {"model.visual.ln_post.weight", "transformer.vision_model.post_layernorm.weight"}, - {"model.visual.ln_pre.bias", "transformer.vision_model.pre_layernorm.bias"}, - {"model.visual.ln_pre.weight", "transformer.vision_model.pre_layernorm.weight"}, - {"model.visual.positional_embedding", "transformer.vision_model.embeddings.position_embedding.weight"}, - {"model.visual.proj", "transformer.visual_projection.weight"}, -}; - -std::unordered_map open_clip_to_hf_clip_resblock = { - {"attn.in_proj_bias", "self_attn.in_proj.bias"}, - {"attn.in_proj_weight", "self_attn.in_proj.weight"}, - {"attn.out_proj.bias", "self_attn.out_proj.bias"}, - {"attn.out_proj.weight", "self_attn.out_proj.weight"}, - {"ln_1.bias", "layer_norm1.bias"}, - {"ln_1.weight", "layer_norm1.weight"}, - {"ln_2.bias", "layer_norm2.bias"}, - {"ln_2.weight", "layer_norm2.weight"}, - {"mlp.c_fc.bias", "mlp.fc1.bias"}, - {"mlp.c_fc.weight", "mlp.fc1.weight"}, - {"mlp.c_proj.bias", "mlp.fc2.bias"}, - {"mlp.c_proj.weight", "mlp.fc2.weight"}, -}; - -std::unordered_map cond_model_name_map = { - {"transformer.vision_model.pre_layrnorm.weight", "transformer.vision_model.pre_layernorm.weight"}, - {"transformer.vision_model.pre_layrnorm.bias", "transformer.vision_model.pre_layernorm.bias"}, -}; - -std::unordered_map vae_decoder_name_map = { - {"first_stage_model.decoder.mid.attn_1.to_k.bias", "first_stage_model.decoder.mid.attn_1.k.bias"}, - {"first_stage_model.decoder.mid.attn_1.to_k.weight", "first_stage_model.decoder.mid.attn_1.k.weight"}, - {"first_stage_model.decoder.mid.attn_1.to_out.0.bias", "first_stage_model.decoder.mid.attn_1.proj_out.bias"}, - {"first_stage_model.decoder.mid.attn_1.to_out.0.weight", "first_stage_model.decoder.mid.attn_1.proj_out.weight"}, - {"first_stage_model.decoder.mid.attn_1.to_q.bias", "first_stage_model.decoder.mid.attn_1.q.bias"}, - {"first_stage_model.decoder.mid.attn_1.to_q.weight", "first_stage_model.decoder.mid.attn_1.q.weight"}, - {"first_stage_model.decoder.mid.attn_1.to_v.bias", "first_stage_model.decoder.mid.attn_1.v.bias"}, - {"first_stage_model.decoder.mid.attn_1.to_v.weight", "first_stage_model.decoder.mid.attn_1.v.weight"}, -}; - -std::unordered_map pmid_v2_name_map = { - {"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.weight", - "pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc1.weight"}, - {"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.3.weight", - "pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc2.weight"}, - {"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.weight", - "pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc1.weight"}, - {"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.3.weight", - "pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc2.weight"}, - {"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.weight", - "pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc1.weight"}, - {"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.3.weight", - "pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc2.weight"}, - {"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.weight", - "pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc1.weight"}, - {"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.3.weight", - "pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc2.weight"}, - {"pmid.qformer_perceiver.token_proj.0.bias", - "pmid.qformer_perceiver.token_proj.fc1.bias"}, - {"pmid.qformer_perceiver.token_proj.2.bias", - "pmid.qformer_perceiver.token_proj.fc2.bias"}, - {"pmid.qformer_perceiver.token_proj.0.weight", - "pmid.qformer_perceiver.token_proj.fc1.weight"}, - {"pmid.qformer_perceiver.token_proj.2.weight", - "pmid.qformer_perceiver.token_proj.fc2.weight"}, -}; - -std::unordered_map qwenvl_name_map{ - {"token_embd.", "model.embed_tokens."}, - {"blk.", "model.layers."}, - {"attn_q.", "self_attn.q_proj."}, - {"attn_k.", "self_attn.k_proj."}, - {"attn_v.", "self_attn.v_proj."}, - {"attn_output.", "self_attn.o_proj."}, - {"attn_norm.", "input_layernorm."}, - {"ffn_down.", "mlp.down_proj."}, - {"ffn_gate.", "mlp.gate_proj."}, - {"ffn_up.", "mlp.up_proj."}, - {"ffn_norm.", "post_attention_layernorm."}, - {"output_norm.", "model.norm."}, -}; - -std::unordered_map qwenvl_vision_name_map{ - {"mm.", "merger.mlp."}, - {"v.post_ln.", "merger.ln_q."}, - {"v.patch_embd.weight", "patch_embed.proj.0.weight"}, - {"patch_embed.proj.0.weight.1", "patch_embed.proj.1.weight"}, - {"v.patch_embd.weight.1", "patch_embed.proj.1.weight"}, - {"v.blk.", "blocks."}, - {"attn_q.", "attn.q_proj."}, - {"attn_k.", "attn.k_proj."}, - {"attn_v.", "attn.v_proj."}, - {"attn_out.", "attn.proj."}, - {"ffn_down.", "mlp.down_proj."}, - {"ffn_gate.", "mlp.gate_proj."}, - {"ffn_up.", "mlp.up_proj."}, - {"ln1.", "norm1."}, - {"ln2.", "norm2."}, -}; - -std::string convert_cond_model_name(const std::string& name) { - std::string new_name = name; - std::string prefix; - if (contains(new_name, ".enc.")) { - // llama.cpp naming convention for T5 - size_t pos = new_name.find(".enc."); - if (pos != std::string::npos) { - new_name.replace(pos, 5, ".encoder."); - } - pos = new_name.find("blk."); - if (pos != std::string::npos) { - new_name.replace(pos, 4, "block."); - } - pos = new_name.find("output_norm."); - if (pos != std::string::npos) { - new_name.replace(pos, 12, "final_layer_norm."); - } - pos = new_name.find("attn_k."); - if (pos != std::string::npos) { - new_name.replace(pos, 7, "layer.0.SelfAttention.k."); - } - pos = new_name.find("attn_v."); - if (pos != std::string::npos) { - new_name.replace(pos, 7, "layer.0.SelfAttention.v."); - } - pos = new_name.find("attn_o."); - if (pos != std::string::npos) { - new_name.replace(pos, 7, "layer.0.SelfAttention.o."); - } - pos = new_name.find("attn_q."); - if (pos != std::string::npos) { - new_name.replace(pos, 7, "layer.0.SelfAttention.q."); - } - pos = new_name.find("attn_norm."); - if (pos != std::string::npos) { - new_name.replace(pos, 10, "layer.0.layer_norm."); - } - pos = new_name.find("ffn_norm."); - if (pos != std::string::npos) { - new_name.replace(pos, 9, "layer.1.layer_norm."); - } - pos = new_name.find("ffn_up."); - if (pos != std::string::npos) { - new_name.replace(pos, 7, "layer.1.DenseReluDense.wi_1."); - } - pos = new_name.find("ffn_down."); - if (pos != std::string::npos) { - new_name.replace(pos, 9, "layer.1.DenseReluDense.wo."); - } - pos = new_name.find("ffn_gate."); - if (pos != std::string::npos) { - new_name.replace(pos, 9, "layer.1.DenseReluDense.wi_0."); - } - pos = new_name.find("attn_rel_b."); - if (pos != std::string::npos) { - new_name.replace(pos, 11, "layer.0.SelfAttention.relative_attention_bias."); - } - } else if (contains(name, "qwen2vl")) { - if (contains(name, "qwen2vl.visual")) { - for (auto kv : qwenvl_vision_name_map) { - size_t pos = new_name.find(kv.first); - if (pos != std::string::npos) { - new_name.replace(pos, kv.first.size(), kv.second); - } - } - } else { - for (auto kv : qwenvl_name_map) { - size_t pos = new_name.find(kv.first); - if (pos != std::string::npos) { - new_name.replace(pos, kv.first.size(), kv.second); - } - } - } - } else if (name == "text_encoders.t5xxl.transformer.token_embd.weight") { - new_name = "text_encoders.t5xxl.transformer.shared.weight"; - } - - if (starts_with(new_name, "conditioner.embedders.0.open_clip.")) { - prefix = "cond_stage_model."; - new_name = new_name.substr(strlen("conditioner.embedders.0.open_clip.")); - } else if (starts_with(new_name, "conditioner.embedders.0.")) { - prefix = "cond_stage_model."; - new_name = new_name.substr(strlen("conditioner.embedders.0.")); - } else if (starts_with(new_name, "conditioner.embedders.1.")) { - prefix = "cond_stage_model.1."; - new_name = new_name.substr(strlen("conditioner.embedders.0.")); - } else if (starts_with(new_name, "cond_stage_model.")) { - prefix = "cond_stage_model."; - new_name = new_name.substr(strlen("cond_stage_model.")); - } else if (ends_with(new_name, "vision_model.visual_projection.weight")) { - prefix = new_name.substr(0, new_name.size() - strlen("vision_model.visual_projection.weight")); - new_name = prefix + "visual_projection.weight"; - return new_name; - } else if (ends_with(new_name, "transformer.text_projection.weight")) { - prefix = new_name.substr(0, new_name.size() - strlen("transformer.text_projection.weight")); - new_name = prefix + "transformer.text_model.text_projection"; - return new_name; - } else { - return new_name; - } - - if (new_name == "model.text_projection.weight") { - new_name = "transformer.text_model.text_projection"; - } - - if (open_clip_to_hf_clip_model.find(new_name) != open_clip_to_hf_clip_model.end()) { - new_name = open_clip_to_hf_clip_model[new_name]; - } - - if (cond_model_name_map.find(new_name) != cond_model_name_map.end()) { - new_name = cond_model_name_map[new_name]; - } - - std::string open_clip_resblock_prefix = "model.transformer.resblocks."; - std::string hf_clip_resblock_prefix = "transformer.text_model.encoder.layers."; - - auto replace_suffix = [&]() { - if (new_name.find(open_clip_resblock_prefix) == 0) { - std::string remain = new_name.substr(open_clip_resblock_prefix.length()); - std::string idx = remain.substr(0, remain.find(".")); - std::string suffix = remain.substr(idx.length() + 1); - - if (open_clip_to_hf_clip_resblock.find(suffix) != open_clip_to_hf_clip_resblock.end()) { - std::string new_suffix = open_clip_to_hf_clip_resblock[suffix]; - new_name = hf_clip_resblock_prefix + idx + "." + new_suffix; - } - } - }; - - replace_suffix(); - - open_clip_resblock_prefix = "model.visual.transformer.resblocks."; - hf_clip_resblock_prefix = "transformer.vision_model.encoder.layers."; - - replace_suffix(); - - return prefix + new_name; -} - -std::string convert_vae_decoder_name(const std::string& name) { - if (vae_decoder_name_map.find(name) != vae_decoder_name_map.end()) { - return vae_decoder_name_map[name]; - } - return name; -} - -std::string convert_pmid_v2_name(const std::string& name) { - if (pmid_v2_name_map.find(name) != pmid_v2_name_map.end()) { - return pmid_v2_name_map[name]; - } - return name; -} - -/* If not a SDXL LoRA the unet" prefix will have already been replaced by this - * point and "te2" and "te1" don't seem to appear in non-SDXL only "te_" */ -std::string convert_sdxl_lora_name(std::string tensor_name) { - const std::pair sdxl_lora_name_lookup[] = { - {"unet", "model_diffusion_model"}, - {"te2", "cond_stage_model_1_transformer"}, - {"te1", "cond_stage_model_transformer"}, - {"text_encoder_2", "cond_stage_model_1_transformer"}, - {"text_encoder", "cond_stage_model_transformer"}, - }; - for (auto& pair_i : sdxl_lora_name_lookup) { - if (tensor_name.compare(0, pair_i.first.length(), pair_i.first) == 0) { - tensor_name = std::regex_replace(tensor_name, std::regex(pair_i.first), pair_i.second); - break; - } - } - return tensor_name; -} - -std::unordered_map> suffix_conversion_underline = { - { - "attentions", - { - {"to_k", "k"}, - {"to_q", "q"}, - {"to_v", "v"}, - {"to_out_0", "proj_out"}, - {"group_norm", "norm"}, - {"key", "k"}, - {"query", "q"}, - {"value", "v"}, - {"proj_attn", "proj_out"}, - }, - }, - { - "resnets", - { - {"conv1", "in_layers_2"}, - {"conv2", "out_layers_3"}, - {"norm1", "in_layers_0"}, - {"norm2", "out_layers_0"}, - {"time_emb_proj", "emb_layers_1"}, - {"conv_shortcut", "skip_connection"}, - }, - }, -}; - -std::unordered_map> suffix_conversion_dot = { - { - "attentions", - { - {"to_k", "k"}, - {"to_q", "q"}, - {"to_v", "v"}, - {"to_out.0", "proj_out"}, - {"group_norm", "norm"}, - {"key", "k"}, - {"query", "q"}, - {"value", "v"}, - {"proj_attn", "proj_out"}, - }, - }, - { - "resnets", - { - {"conv1", "in_layers.2"}, - {"conv2", "out_layers.3"}, - {"norm1", "in_layers.0"}, - {"norm2", "out_layers.0"}, - {"time_emb_proj", "emb_layers.1"}, - {"conv_shortcut", "skip_connection"}, - }, - }, -}; - -std::string convert_diffusers_name_to_compvis(std::string key, char seq) { - std::vector m; - - auto match = [](std::vector& match_list, const std::regex& regex, const std::string& key) { - auto r = std::smatch{}; - if (!std::regex_match(key, r, regex)) { - return false; - } - - match_list.clear(); - for (size_t i = 1; i < r.size(); ++i) { - match_list.push_back(r.str(i)); - } - return true; - }; - - std::unordered_map> suffix_conversion; - if (seq == '_') { - suffix_conversion = suffix_conversion_underline; - } else { - suffix_conversion = suffix_conversion_dot; - } - - auto get_converted_suffix = [&suffix_conversion](const std::string& outer_key, const std::string& inner_key) { - auto outer_iter = suffix_conversion.find(outer_key); - if (outer_iter != suffix_conversion.end()) { - auto inner_iter = outer_iter->second.find(inner_key); - if (inner_iter != outer_iter->second.end()) { - return inner_iter->second; - } - } - return inner_key; - }; - - // convert attn to out - if (ends_with(key, "to_out")) { - key += format("%c0", seq); - } - - // unet - if (match(m, std::regex(format("unet%cconv_in(.*)", seq)), key)) { - return format("model%cdiffusion_model%cinput_blocks%c0%c0", seq, seq, seq, seq) + m[0]; - } - - if (match(m, std::regex(format("unet%cconv%cout(.*)", seq, seq)), key)) { - return format("model%cdiffusion_model%cout%c2", seq, seq, seq) + m[0]; - } - - if (match(m, std::regex(format("unet%cconv_norm_out(.*)", seq)), key)) { - return format("model%cdiffusion_model%cout%c0", seq, seq, seq) + m[0]; - } - - if (match(m, std::regex(format("unet%ctime_embedding%clinear_(\\d+)(.*)", seq, seq)), key)) { - return format("model%cdiffusion_model%ctime_embed%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1]; - } - - if (match(m, std::regex(format("unet%cadd_embedding%clinear_(\\d+)(.*)", seq, seq)), key)) { - return format("model%cdiffusion_model%clabel_emb%c0%c", seq, seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1]; - } - - if (match(m, std::regex(format("unet%cdown_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) { - std::string suffix = get_converted_suffix(m[1], m[3]); - // LOG_DEBUG("%s %s %s %s", m[0].c_str(), m[1].c_str(), m[2].c_str(), m[3].c_str()); - return format("model%cdiffusion_model%cinput_blocks%c", seq, seq, seq) + std::to_string(1 + std::stoi(m[0]) * 3 + std::stoi(m[2])) + seq + - (m[1] == "attentions" ? "1" : "0") + seq + suffix; - } - - if (match(m, std::regex(format("unet%cmid_block%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq)), key)) { - std::string suffix = get_converted_suffix(m[0], m[2]); - return format("model%cdiffusion_model%cmiddle_block%c", seq, seq, seq) + (m[0] == "attentions" ? "1" : std::to_string(std::stoi(m[1]) * 2)) + - seq + suffix; - } - - if (match(m, std::regex(format("unet%cup_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) { - std::string suffix = get_converted_suffix(m[1], m[3]); - return format("model%cdiffusion_model%coutput_blocks%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 3 + std::stoi(m[2])) + seq + - (m[1] == "attentions" ? "1" : "0") + seq + suffix; - } - - if (match(m, std::regex(format("unet%cdown_blocks%c(\\d+)%cdownsamplers%c0%cconv", seq, seq, seq, seq, seq)), key)) { - return format("model%cdiffusion_model%cinput_blocks%c", seq, seq, seq) + std::to_string(3 + std::stoi(m[0]) * 3) + seq + "0" + seq + "op"; - } - - if (match(m, std::regex(format("unet%cup_blocks%c(\\d+)%cupsamplers%c0%cconv", seq, seq, seq, seq, seq)), key)) { - return format("model%cdiffusion_model%coutput_blocks%c", seq, seq, seq) + std::to_string(2 + std::stoi(m[0]) * 3) + seq + - (std::stoi(m[0]) > 0 ? "2" : "1") + seq + "conv"; - } - - // clip - if (match(m, std::regex(format("te%ctext_model%cencoder%clayers%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) { - return format("cond_stage_model%ctransformer%ctext_model%cencoder%clayers%c", seq, seq, seq, seq, seq) + m[0] + seq + m[1]; - } - - if (match(m, std::regex(format("te%ctext_model(.*)", seq)), key)) { - return format("cond_stage_model%ctransformer%ctext_model", seq, seq) + m[0]; - } - - // clip-g - if (match(m, std::regex(format("te%c1%ctext_model%cencoder%clayers%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) { - return format("cond_stage_model%c1%ctransformer%ctext_model%cencoder%clayers%c", seq, seq, seq, seq, seq, seq) + m[0] + seq + m[1]; - } - - if (match(m, std::regex(format("te%c1%ctext_model(.*)", seq, seq)), key)) { - return format("cond_stage_model%c1%ctransformer%ctext_model", seq, seq, seq) + m[0]; - } - - if (match(m, std::regex(format("te%c1%ctext_projection", seq, seq)), key)) { - return format("cond_stage_model%c1%ctransformer%ctext_model%ctext_projection", seq, seq, seq, seq); - } - - // vae - if (match(m, std::regex(format("vae%c(.*)%cconv_norm_out(.*)", seq, seq)), key)) { - return format("first_stage_model%c%s%cnorm_out%s", seq, m[0].c_str(), seq, m[1].c_str()); - } - - if (match(m, std::regex(format("vae%c(.*)%cmid_block%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) { - std::string suffix; - std::string block_name; - if (m[1] == "attentions") { - block_name = "attn"; - suffix = get_converted_suffix(m[1], m[3]); - } else { - block_name = "block"; - suffix = m[3]; - } - return format("first_stage_model%c%s%cmid%c%s_%d%c%s", - seq, m[0].c_str(), seq, seq, block_name.c_str(), std::stoi(m[2]) + 1, seq, suffix.c_str()); - } - - if (match(m, std::regex(format("vae%c(.*)%cup_blocks%c(\\d+)%cresnets%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) { - std::string suffix = m[3]; - if (suffix == "conv_shortcut") { - suffix = "nin_shortcut"; - } - return format("first_stage_model%c%s%cup%c%d%cblock%c%s%c%s", - seq, m[0].c_str(), seq, seq, 3 - std::stoi(m[1]), seq, seq, m[2].c_str(), seq, suffix.c_str()); - } - - if (match(m, std::regex(format("vae%c(.*)%cdown_blocks%c(\\d+)%cdownsamplers%c0%cconv", seq, seq, seq, seq, seq, seq)), key)) { - return format("first_stage_model%c%s%cdown%c%d%cdownsample%cconv", - seq, m[0].c_str(), seq, seq, std::stoi(m[1]), seq, seq); - } - - if (match(m, std::regex(format("vae%c(.*)%cdown_blocks%c(\\d+)%cresnets%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) { - std::string suffix = m[3]; - if (suffix == "conv_shortcut") { - suffix = "nin_shortcut"; - } - return format("first_stage_model%c%s%cdown%c%d%cblock%c%s%c%s", - seq, m[0].c_str(), seq, seq, std::stoi(m[1]), seq, seq, m[2].c_str(), seq, suffix.c_str()); - } - - if (match(m, std::regex(format("vae%c(.*)%cup_blocks%c(\\d+)%cupsamplers%c0%cconv", seq, seq, seq, seq, seq, seq)), key)) { - return format("first_stage_model%c%s%cup%c%d%cupsample%cconv", - seq, m[0].c_str(), seq, seq, 3 - std::stoi(m[1]), seq, seq); - } - - if (match(m, std::regex(format("vae%c(.*)", seq)), key)) { - return format("first_stage_model%c", seq) + m[0]; - } - - return key; -} - -std::string convert_tensor_name(std::string name) { - if (starts_with(name, "diffusion_model")) { - name = "model." + name; - } - if (starts_with(name, "model.diffusion_model.up_blocks.0.attentions.0.")) { - name.replace(0, sizeof("model.diffusion_model.up_blocks.0.attentions.0.") - 1, - "model.diffusion_model.output_blocks.0.1."); - } - if (starts_with(name, "model.diffusion_model.up_blocks.0.attentions.1.")) { - name.replace(0, sizeof("model.diffusion_model.up_blocks.0.attentions.1.") - 1, - "model.diffusion_model.output_blocks.1.1."); - } - // size_t pos = name.find("lora_A"); - // if (pos != std::string::npos) { - // name.replace(pos, strlen("lora_A"), "lora_up"); - // } - // pos = name.find("lora_B"); - // if (pos != std::string::npos) { - // name.replace(pos, strlen("lora_B"), "lora_down"); - // } - std::string new_name = name; - if (starts_with(name, "cond_stage_model.") || - starts_with(name, "conditioner.embedders.") || - starts_with(name, "text_encoders.") || - ends_with(name, ".vision_model.visual_projection.weight") || - starts_with(name, "qwen2vl")) { - new_name = convert_cond_model_name(name); - } else if (starts_with(name, "first_stage_model.decoder")) { - new_name = convert_vae_decoder_name(name); - } else if (starts_with(name, "pmid.qformer_perceiver")) { - new_name = convert_pmid_v2_name(name); - } else if (starts_with(name, "control_model.")) { // for controlnet pth models - size_t pos = name.find('.'); - if (pos != std::string::npos) { - new_name = name.substr(pos + 1); - } - } else if (starts_with(name, "lora_")) { // for lora - size_t pos = name.find('.'); - if (pos != std::string::npos) { - std::string name_without_network_parts = name.substr(5, pos - 5); - std::string network_part = name.substr(pos + 1); - - // LOG_DEBUG("%s %s", name_without_network_parts.c_str(), network_part.c_str()); - std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '_'); - /* For dealing with the new SDXL LoRA tensor naming convention */ - new_key = convert_sdxl_lora_name(new_key); - - if (new_key.empty()) { - new_name = name; - } else { - new_name = "lora." + new_key + "." + network_part; - } - } else { - new_name = name; - } - } else if (ends_with(name, ".diff") || ends_with(name, ".diff_b")) { - new_name = "lora." + name; - } else if (contains(name, "lora_up") || contains(name, "lora_down") || - contains(name, "lora.up") || contains(name, "lora.down") || - contains(name, "lora_linear") || ends_with(name, ".alpha")) { - size_t pos = new_name.find(".processor"); - if (pos != std::string::npos) { - new_name.replace(pos, strlen(".processor"), ""); - } - // if (starts_with(new_name, "transformer.transformer_blocks") || starts_with(new_name, "transformer.single_transformer_blocks")) { - // new_name = "model.diffusion_model." + new_name; - // } - if (ends_with(name, ".alpha")) { - pos = new_name.rfind("alpha"); - } else { - pos = new_name.rfind("lora"); - } - if (pos != std::string::npos) { - std::string name_without_network_parts = new_name.substr(0, pos - 1); - std::string network_part = new_name.substr(pos); - // LOG_DEBUG("%s %s", name_without_network_parts.c_str(), network_part.c_str()); - std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '.'); - new_key = convert_sdxl_lora_name(new_key); - replace_all_chars(new_key, '.', '_'); - size_t npos = network_part.rfind("_linear_layer"); - if (npos != std::string::npos) { - network_part.replace(npos, strlen("_linear_layer"), ""); - } - if (starts_with(network_part, "lora.")) { - network_part = "lora_" + network_part.substr(5); - } - if (new_key.size() > 0) { - new_name = "lora." + new_key + "." + network_part; - } - // LOG_DEBUG("new name: %s", new_name.c_str()); - } - } else if (starts_with(name, "unet") || starts_with(name, "vae") || starts_with(name, "te")) { // for diffuser - size_t pos = name.find_last_of('.'); - if (pos != std::string::npos) { - std::string name_without_network_parts = name.substr(0, pos); - std::string network_part = name.substr(pos + 1); - // LOG_DEBUG("%s %s", name_without_network_parts.c_str(), network_part.c_str()); - std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '.'); - if (new_key.empty()) { - new_name = name; - } else if (new_key == "cond_stage_model.1.transformer.text_model.text_projection") { - new_name = new_key; - } else { - new_name = new_key + "." + network_part; - } - } else { - new_name = name; - } - } else { - new_name = name; - } - // if (new_name != name) { - // LOG_DEBUG("%s => %s", name.c_str(), new_name.c_str()); - // } - return new_name; -} - float bf16_to_f32(uint16_t bfloat16) { uint32_t val_bits = (static_cast(bfloat16) << 16); return *reinterpret_cast(&val_bits); @@ -916,9 +293,7 @@ void convert_tensor(void* src, /*================================================= ModelLoader ==================================================*/ void ModelLoader::add_tensor_storage(const TensorStorage& tensor_storage) { - TensorStorage copy = tensor_storage; - copy.name = convert_tensor_name(copy.name); - tensor_storage_map[copy.name] = std::move(copy); + tensor_storage_map[tensor_storage.name] = tensor_storage; } bool is_zip_file(const std::string& file_path) { @@ -1012,6 +387,31 @@ bool ModelLoader::init_from_file(const std::string& file_path, const std::string } } +void ModelLoader::convert_tensors_name() { + SDVersion version = (version_ == VERSION_COUNT) ? get_sd_version() : version_; + String2TensorStorage new_map; + + for (auto& [_, tensor_storage] : tensor_storage_map) { + auto new_name = convert_tensor_name(tensor_storage.name, version); + // LOG_DEBUG("%s -> %s", tensor_storage.name.c_str(), new_name.c_str()); + tensor_storage.name = new_name; + new_map[new_name] = std::move(tensor_storage); + } + + tensor_storage_map.swap(new_map); +} + +bool ModelLoader::init_from_file_and_convert_name(const std::string& file_path, const std::string& prefix, SDVersion version) { + if (version_ == VERSION_COUNT && version != VERSION_COUNT) { + version_ = version; + } + if (!init_from_file(file_path, prefix)) { + return false; + } + convert_tensors_name(); + return true; +} + /*================================================= GGUFModelLoader ==================================================*/ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::string& prefix) { @@ -1259,32 +659,6 @@ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const s if (!init_from_safetensors_file(unet_path, "unet.")) { return false; } - for (auto& [name, tensor_storage] : tensor_storage_map) { - if (name.find("add_embedding") != std::string::npos || name.find("label_emb") != std::string::npos) { - // probably SDXL - LOG_DEBUG("Fixing name for SDXL output blocks.2.2"); - String2TensorStorage new_tensor_storage_map; - - for (auto& [name, tensor_storage] : tensor_storage_map) { - int len = 34; - auto pos = tensor_storage.name.find("unet.up_blocks.0.upsamplers.0.conv"); - if (pos == std::string::npos) { - len = 44; - pos = tensor_storage.name.find("model.diffusion_model.output_blocks.2.1.conv"); - } - if (pos != std::string::npos) { - std::string new_name = "model.diffusion_model.output_blocks.2.2.conv" + name.substr(len); - LOG_DEBUG("NEW NAME: %s", new_name.c_str()); - tensor_storage.name = new_name; - new_tensor_storage_map[new_name] = tensor_storage; - } else { - new_tensor_storage_map[name] = tensor_storage; - } - } - tensor_storage_map = new_tensor_storage_map; - break; - } - } if (!init_from_safetensors_file(vae_path, "vae.")) { LOG_WARN("Couldn't find working VAE in %s", file_path.c_str()); @@ -1925,7 +1299,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_thread int64_t start_time = ggml_time_ms(); std::vector processed_tensor_storages; - for (auto& [name, tensor_storage] : tensor_storage_map) { + for (const auto& [name, tensor_storage] : tensor_storage_map) { if (is_unused_tensor(tensor_storage.name)) { continue; } @@ -2394,6 +1768,7 @@ bool convert(const char* input_path, const char* vae_path, const char* output_pa return false; } } + model_loader.convert_tensors_name(); bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules); return success; } diff --git a/model.h b/model.h index 583a0146..588f9821 100644 --- a/model.h +++ b/model.h @@ -15,6 +15,7 @@ #include "ggml.h" #include "gguf.h" #include "json.hpp" +#include "ordered_map.hpp" #include "zip.h" #define SD_MAX_DIMS 5 @@ -108,7 +109,11 @@ static inline bool sd_version_is_qwen_image(SDVersion version) { } static inline bool sd_version_is_inpaint(SDVersion version) { - if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) { + if (version == VERSION_SD1_INPAINT || + version == VERSION_SD2_INPAINT || + version == VERSION_SDXL_INPAINT || + version == VERSION_FLUX_FILL || + version == VERSION_FLEX_2) { return true; } return false; @@ -253,10 +258,11 @@ struct TensorStorage { typedef std::function on_new_tensor_cb_t; -typedef std::map String2TensorStorage; +typedef OrderedMap String2TensorStorage; class ModelLoader { protected: + SDVersion version_ = VERSION_COUNT; std::vector file_paths_; String2TensorStorage tensor_storage_map; @@ -276,6 +282,10 @@ class ModelLoader { public: bool init_from_file(const std::string& file_path, const std::string& prefix = ""); + void convert_tensors_name(); + bool init_from_file_and_convert_name(const std::string& file_path, + const std::string& prefix = "", + SDVersion version = VERSION_COUNT); SDVersion get_sd_version(); std::map get_wtype_stat(); std::map get_conditioner_wtype_stat(); diff --git a/name_conversion.cpp b/name_conversion.cpp new file mode 100644 index 00000000..ea2702a7 --- /dev/null +++ b/name_conversion.cpp @@ -0,0 +1,1028 @@ +#include +#include + +#include "name_conversion.h" +#include "util.h" + +void replace_with_name_map(std::string& name, const std::vector>& name_map) { + for (auto kv : name_map) { + size_t pos = name.find(kv.first); + if (pos != std::string::npos) { + name.replace(pos, kv.first.size(), kv.second); + } + } +} + +void replace_with_prefix_map(std::string& name, const std::vector>& prefix_map) { + for (const auto& [old_prefix, new_prefix] : prefix_map) { + if (starts_with(name, old_prefix)) { + name = new_prefix + name.substr(old_prefix.size()); + break; + } + } +} + +void replace_with_prefix_map(std::string& name, const std::unordered_map& prefix_map) { + for (const auto& [old_prefix, new_prefix] : prefix_map) { + if (starts_with(name, old_prefix)) { + name = new_prefix + name.substr(old_prefix.size()); + break; + } + } +} + +std::string convert_open_clip_to_hf_clip_name(std::string name) { + static std::unordered_map open_clip_to_hf_clip_model = { + {"model.ln_final.bias", "transformer.text_model.final_layer_norm.bias"}, + {"model.ln_final.weight", "transformer.text_model.final_layer_norm.weight"}, + {"model.positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"}, + {"model.token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"}, + {"model.text_projection", "transformer.text_model.text_projection"}, + {"model.visual.class_embedding", "transformer.vision_model.embeddings.class_embedding"}, + {"model.visual.conv1.weight", "transformer.vision_model.embeddings.patch_embedding.weight"}, + {"model.visual.ln_post.bias", "transformer.vision_model.post_layernorm.bias"}, + {"model.visual.ln_post.weight", "transformer.vision_model.post_layernorm.weight"}, + {"model.visual.ln_pre.bias", "transformer.vision_model.pre_layernorm.bias"}, + {"model.visual.ln_pre.weight", "transformer.vision_model.pre_layernorm.weight"}, + {"model.visual.positional_embedding", "transformer.vision_model.embeddings.position_embedding.weight"}, + {"model.visual.proj", "transformer.visual_projection.weight"}, + }; + + static std::unordered_map open_clip_to_hf_clip_resblock = { + {"attn.in_proj_bias", "self_attn.in_proj.bias"}, + {"attn.in_proj_weight", "self_attn.in_proj.weight"}, + {"attn.out_proj.bias", "self_attn.out_proj.bias"}, + {"attn.out_proj.weight", "self_attn.out_proj.weight"}, + {"ln_1.bias", "layer_norm1.bias"}, + {"ln_1.weight", "layer_norm1.weight"}, + {"ln_2.bias", "layer_norm2.bias"}, + {"ln_2.weight", "layer_norm2.weight"}, + {"mlp.c_fc.bias", "mlp.fc1.bias"}, + {"mlp.c_fc.weight", "mlp.fc1.weight"}, + {"mlp.c_proj.bias", "mlp.fc2.bias"}, + {"mlp.c_proj.weight", "mlp.fc2.weight"}, + }; + + static std::unordered_map cond_model_name_map = { + {"transformer.vision_model.pre_layrnorm.weight", "transformer.vision_model.pre_layernorm.weight"}, + {"transformer.vision_model.pre_layrnorm.bias", "transformer.vision_model.pre_layernorm.bias"}, + }; + + if (open_clip_to_hf_clip_model.find(name) != open_clip_to_hf_clip_model.end()) { + name = open_clip_to_hf_clip_model[name]; + } + + if (cond_model_name_map.find(name) != cond_model_name_map.end()) { + name = cond_model_name_map[name]; + } + + std::string open_clip_resblock_prefix = "model.transformer.resblocks."; + std::string hf_clip_resblock_prefix = "transformer.text_model.encoder.layers."; + + auto replace_suffix = [&]() { + if (name.find(open_clip_resblock_prefix) == 0) { + std::string remain = name.substr(open_clip_resblock_prefix.length()); + std::string idx = remain.substr(0, remain.find(".")); + std::string suffix = remain.substr(idx.length() + 1); + + if (open_clip_to_hf_clip_resblock.find(suffix) != open_clip_to_hf_clip_resblock.end()) { + std::string new_suffix = open_clip_to_hf_clip_resblock[suffix]; + name = hf_clip_resblock_prefix + idx + "." + new_suffix; + } + } + }; + + replace_suffix(); + + open_clip_resblock_prefix = "model.visual.transformer.resblocks."; + hf_clip_resblock_prefix = "transformer.vision_model.encoder.layers."; + + replace_suffix(); + + return name; +} + +std::string convert_cond_stage_model_name(std::string name, std::string prefix) { + static const std::vector> clip_name_map{ + {"transformer.text_projection.weight", "transformer.text_model.text_projection"}, + {"model.text_projection.weight", "transformer.text_model.text_projection"}, + {"vision_model.visual_projection.weight", "visual_projection.weight"}, + }; + + // llama.cpp to original + static const std::vector> t5_name_map{ + {"enc.", "encoder."}, + {"blk.", "block."}, + {"output_norm.", "final_layer_norm."}, + {"attn_q.", "layer.0.SelfAttention.q."}, + {"attn_k.", "layer.0.SelfAttention.k."}, + {"attn_v.", "layer.0.SelfAttention.v."}, + {"attn_o.", "layer.0.SelfAttention.o."}, + {"attn_norm.", "layer.0.layer_norm."}, + {"ffn_norm.", "layer.1.layer_norm."}, + {"ffn_up.", "layer.1.DenseReluDense.wi_1."}, + {"ffn_down.", "layer.1.DenseReluDense.wo."}, + {"ffn_gate.", "layer.1.DenseReluDense.wi_0."}, + {"attn_rel_b.", "layer.0.SelfAttention.relative_attention_bias."}, + {"token_embd.", "shared."}, + }; + + static const std::vector> qwenvl_name_map{ + {"token_embd.", "model.embed_tokens."}, + {"blk.", "model.layers."}, + {"attn_q.", "self_attn.q_proj."}, + {"attn_k.", "self_attn.k_proj."}, + {"attn_v.", "self_attn.v_proj."}, + {"attn_output.", "self_attn.o_proj."}, + {"attn_norm.", "input_layernorm."}, + {"ffn_down.", "mlp.down_proj."}, + {"ffn_gate.", "mlp.gate_proj."}, + {"ffn_up.", "mlp.up_proj."}, + {"ffn_norm.", "post_attention_layernorm."}, + {"output_norm.", "model.norm."}, + }; + + static const std::vector> qwenvl_vision_name_map{ + {"mm.", "merger.mlp."}, + {"v.post_ln.", "merger.ln_q."}, + {"v.patch_embd.weight", "patch_embed.proj.0.weight"}, + {"patch_embed.proj.0.weight.1", "patch_embed.proj.1.weight"}, + {"v.patch_embd.weight.1", "patch_embed.proj.1.weight"}, + {"v.blk.", "blocks."}, + {"attn_q.", "attn.q_proj."}, + {"attn_k.", "attn.k_proj."}, + {"attn_v.", "attn.v_proj."}, + {"attn_out.", "attn.proj."}, + {"ffn_down.", "mlp.down_proj."}, + {"ffn_gate.", "mlp.gate_proj."}, + {"ffn_up.", "mlp.up_proj."}, + {"ln1.", "norm1."}, + {"ln2.", "norm2."}, + }; + if (contains(name, "t5xxl")) { + replace_with_name_map(name, t5_name_map); + } else if (contains(name, "qwen2vl")) { + if (contains(name, "qwen2vl.visual")) { + replace_with_name_map(name, qwenvl_vision_name_map); + } else { + replace_with_name_map(name, qwenvl_name_map); + } + } else { + name = convert_open_clip_to_hf_clip_name(name); + replace_with_name_map(name, clip_name_map); + } + return name; +} + +// ref: https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py +std::string convert_diffusers_unet_to_original_sd1(std::string name) { + // (stable-diffusion, HF Diffusers) + static const std::vector> unet_conversion_map = { + {"time_embed.0.weight", "time_embedding.linear_1.weight"}, + {"time_embed.0.bias", "time_embedding.linear_1.bias"}, + {"time_embed.2.weight", "time_embedding.linear_2.weight"}, + {"time_embed.2.bias", "time_embedding.linear_2.bias"}, + {"input_blocks.0.0.weight", "conv_in.weight"}, + {"input_blocks.0.0.bias", "conv_in.bias"}, + {"out.0.weight", "conv_norm_out.weight"}, + {"out.0.bias", "conv_norm_out.bias"}, + {"out.2.weight", "conv_out.weight"}, + {"out.2.bias", "conv_out.bias"}, + }; + + static const std::vector> unet_conversion_map_resnet = { + {"in_layers.0", "norm1"}, + {"in_layers.2", "conv1"}, + {"out_layers.0", "norm2"}, + {"out_layers.3", "conv2"}, + {"emb_layers.1", "time_emb_proj"}, + {"skip_connection", "conv_shortcut"}, + }; + + static std::vector> unet_conversion_map_layer; + if (unet_conversion_map_layer.empty()) { + for (int i = 0; i < 4; ++i) { + // down_blocks + for (int j = 0; j < 2; ++j) { + std::string hf_down_res_prefix = "down_blocks." + std::to_string(i) + ".resnets." + std::to_string(j) + "."; + std::string sd_down_res_prefix = "input_blocks." + std::to_string(3 * i + j + 1) + ".0."; + unet_conversion_map_layer.emplace_back(sd_down_res_prefix, hf_down_res_prefix); + + if (i < 3) { + std::string hf_down_atn_prefix = "down_blocks." + std::to_string(i) + ".attentions." + std::to_string(j) + "."; + std::string sd_down_atn_prefix = "input_blocks." + std::to_string(3 * i + j + 1) + ".1."; + unet_conversion_map_layer.emplace_back(sd_down_atn_prefix, hf_down_atn_prefix); + } + } + + // up_blocks + for (int j = 0; j < 3; ++j) { + std::string hf_up_res_prefix = "up_blocks." + std::to_string(i) + ".resnets." + std::to_string(j) + "."; + std::string sd_up_res_prefix = "output_blocks." + std::to_string(3 * i + j) + ".0."; + unet_conversion_map_layer.emplace_back(sd_up_res_prefix, hf_up_res_prefix); + + if (/*i > 0*/ true) { // for tiny unet + std::string hf_up_atn_prefix = "up_blocks." + std::to_string(i) + ".attentions." + std::to_string(j) + "."; + std::string sd_up_atn_prefix = "output_blocks." + std::to_string(3 * i + j) + ".1."; + unet_conversion_map_layer.emplace_back(sd_up_atn_prefix, hf_up_atn_prefix); + } + } + + if (i < 3) { + std::string hf_downsample_prefix = "down_blocks." + std::to_string(i) + ".downsamplers.0.conv."; + std::string sd_downsample_prefix = "input_blocks." + std::to_string(3 * (i + 1)) + ".0.op."; + unet_conversion_map_layer.emplace_back(sd_downsample_prefix, hf_downsample_prefix); + + std::string hf_upsample_prefix = "up_blocks." + std::to_string(i) + ".upsamplers.0."; + std::string sd_upsample_prefix = "output_blocks." + std::to_string(3 * i + 2) + "." + std::to_string(i == 0 ? 1 : 2) + "."; + unet_conversion_map_layer.emplace_back(sd_upsample_prefix, hf_upsample_prefix); + } + } + + // mid block + unet_conversion_map_layer.emplace_back("middle_block.1.", "mid_block.attentions.0."); + for (int j = 0; j < 2; ++j) { + std::string hf_mid_res_prefix = "mid_block.resnets." + std::to_string(j) + "."; + std::string sd_mid_res_prefix = "middle_block." + std::to_string(2 * j) + "."; + unet_conversion_map_layer.emplace_back(sd_mid_res_prefix, hf_mid_res_prefix); + } + } + + std::string result = name; + + for (const auto& p : unet_conversion_map) { + if (result == p.second) { + result = p.first; + return result; + } + } + + if (contains(result, "resnets")) { + for (const auto& p : unet_conversion_map_resnet) { + size_t pos = result.find(p.second); + if (pos != std::string::npos) { + result.replace(pos, p.second.size(), p.first); + } + } + } + + for (const auto& p : unet_conversion_map_layer) { + size_t pos = result.find(p.second); + if (pos != std::string::npos) { + result.replace(pos, p.second.size(), p.first); + } + } + + return result; +} + +// ref: https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_sdxl.py + +std::string convert_diffusers_unet_to_original_sdxl(std::string name) { + // (stable-diffusion, HF Diffusers) + static const std::vector> unet_conversion_map = { + {"time_embed.0.weight", "time_embedding.linear_1.weight"}, + {"time_embed.0.bias", "time_embedding.linear_1.bias"}, + {"time_embed.2.weight", "time_embedding.linear_2.weight"}, + {"time_embed.2.bias", "time_embedding.linear_2.bias"}, + {"input_blocks.0.0.weight", "conv_in.weight"}, + {"input_blocks.0.0.bias", "conv_in.bias"}, + {"out.0.weight", "conv_norm_out.weight"}, + {"out.0.bias", "conv_norm_out.bias"}, + {"out.2.weight", "conv_out.weight"}, + {"out.2.bias", "conv_out.bias"}, + + // --- SDXL add_embedding mappings --- + {"label_emb.0.0.weight", "add_embedding.linear_1.weight"}, + {"label_emb.0.0.bias", "add_embedding.linear_1.bias"}, + {"label_emb.0.2.weight", "add_embedding.linear_2.weight"}, + {"label_emb.0.2.bias", "add_embedding.linear_2.bias"}, + }; + + static const std::vector> unet_conversion_map_resnet = { + {"in_layers.0", "norm1"}, + {"in_layers.2", "conv1"}, + {"out_layers.0", "norm2"}, + {"out_layers.3", "conv2"}, + {"emb_layers.1", "time_emb_proj"}, + {"skip_connection", "conv_shortcut"}, + }; + + static std::vector> unet_conversion_map_layer; + if (unet_conversion_map_layer.empty()) { + for (int i = 0; i < 3; ++i) { + // --- down_blocks --- + for (int j = 0; j < 2; ++j) { + std::string hf_down_res_prefix = "down_blocks." + std::to_string(i) + ".resnets." + std::to_string(j) + "."; + std::string sd_down_res_prefix = "input_blocks." + std::to_string(3 * i + j + 1) + ".0."; + unet_conversion_map_layer.emplace_back(sd_down_res_prefix, hf_down_res_prefix); + + if (i > 0) { + std::string hf_down_atn_prefix = "down_blocks." + std::to_string(i) + ".attentions." + std::to_string(j) + "."; + std::string sd_down_atn_prefix = "input_blocks." + std::to_string(3 * i + j + 1) + ".1."; + unet_conversion_map_layer.emplace_back(sd_down_atn_prefix, hf_down_atn_prefix); + } + } + + // --- up_blocks --- + for (int j = 0; j < 4; ++j) { + std::string hf_up_res_prefix = "up_blocks." + std::to_string(i) + ".resnets." + std::to_string(j) + "."; + std::string sd_up_res_prefix = "output_blocks." + std::to_string(3 * i + j) + ".0."; + unet_conversion_map_layer.emplace_back(sd_up_res_prefix, hf_up_res_prefix); + + if (i < 2) { + std::string hf_up_atn_prefix = "up_blocks." + std::to_string(i) + ".attentions." + std::to_string(j) + "."; + std::string sd_up_atn_prefix = "output_blocks." + std::to_string(3 * i + j) + ".1."; + unet_conversion_map_layer.emplace_back(sd_up_atn_prefix, hf_up_atn_prefix); + } + } + + if (i < 3) { + std::string hf_downsample_prefix = "down_blocks." + std::to_string(i) + ".downsamplers.0.conv."; + std::string sd_downsample_prefix = "input_blocks." + std::to_string(3 * (i + 1)) + ".0.op."; + unet_conversion_map_layer.emplace_back(sd_downsample_prefix, hf_downsample_prefix); + + std::string hf_upsample_prefix = "up_blocks." + std::to_string(i) + ".upsamplers.0."; + std::string sd_upsample_prefix = + "output_blocks." + std::to_string(3 * i + 2) + "." + std::to_string(i == 0 ? 1 : 2) + "."; + unet_conversion_map_layer.emplace_back(sd_upsample_prefix, hf_upsample_prefix); + } + } + + unet_conversion_map_layer.emplace_back("output_blocks.2.2.conv.", "output_blocks.2.1.conv."); + + // mid block + unet_conversion_map_layer.emplace_back("middle_block.1.", "mid_block.attentions.0."); + for (int j = 0; j < 2; ++j) { + std::string hf_mid_res_prefix = "mid_block.resnets." + std::to_string(j) + "."; + std::string sd_mid_res_prefix = "middle_block." + std::to_string(2 * j) + "."; + unet_conversion_map_layer.emplace_back(sd_mid_res_prefix, hf_mid_res_prefix); + } + } + + std::string result = name; + + for (const auto& p : unet_conversion_map) { + if (result == p.second) { + result = p.first; + return result; + } + } + + if (contains(result, "resnets")) { + for (const auto& p : unet_conversion_map_resnet) { + size_t pos = result.find(p.second); + if (pos != std::string::npos) { + result.replace(pos, p.second.size(), p.first); + } + } + } + + for (const auto& p : unet_conversion_map_layer) { + size_t pos = result.find(p.second); + if (pos != std::string::npos) { + result.replace(pos, p.second.size(), p.first); + } + } + + static const std::vector> name_map{ + {"to_out.weight", "to_out.0.weight"}, + {"to_out.bias", "to_out.0.bias"}, + }; + replace_with_name_map(result, name_map); + + return result; +} + +std::string convert_diffusers_dit_to_original_sd3(std::string name) { + int num_layers = 38; + static std::unordered_map sd3_name_map; + + if (sd3_name_map.empty()) { + // --- time_text_embed --- + sd3_name_map["time_text_embed.timestep_embedder.linear_1.weight"] = "t_embedder.mlp.0.weight"; + sd3_name_map["time_text_embed.timestep_embedder.linear_1.bias"] = "t_embedder.mlp.0.bias"; + sd3_name_map["time_text_embed.timestep_embedder.linear_2.weight"] = "t_embedder.mlp.2.weight"; + sd3_name_map["time_text_embed.timestep_embedder.linear_2.bias"] = "t_embedder.mlp.2.bias"; + + sd3_name_map["time_text_embed.text_embedder.linear_1.weight"] = "y_embedder.mlp.0.weight"; + sd3_name_map["time_text_embed.text_embedder.linear_1.bias"] = "y_embedder.mlp.0.bias"; + sd3_name_map["time_text_embed.text_embedder.linear_2.weight"] = "y_embedder.mlp.2.weight"; + sd3_name_map["time_text_embed.text_embedder.linear_2.bias"] = "y_embedder.mlp.2.bias"; + + sd3_name_map["pos_embed.pos_embed"] = "pos_embed"; + sd3_name_map["pos_embed.proj.weight"] = "x_embedder.proj.weight"; + sd3_name_map["pos_embed.proj.bias"] = "x_embedder.proj.bias"; + + // --- transformer blocks --- + for (int i = 0; i < num_layers; ++i) { + std::string block_prefix = "transformer_blocks." + std::to_string(i) + "."; + std::string dst_prefix = "joint_blocks." + std::to_string(i) + "."; + + sd3_name_map[block_prefix + "norm1.linear.weight"] = dst_prefix + "x_block.adaLN_modulation.1.weight"; + sd3_name_map[block_prefix + "norm1.linear.bias"] = dst_prefix + "x_block.adaLN_modulation.1.bias"; + sd3_name_map[block_prefix + "norm1_context.linear.weight"] = dst_prefix + "context_block.adaLN_modulation.1.weight"; + sd3_name_map[block_prefix + "norm1_context.linear.bias"] = dst_prefix + "context_block.adaLN_modulation.1.bias"; + + // attn + sd3_name_map[block_prefix + "attn.to_q.weight"] = dst_prefix + "x_block.attn.qkv.weight"; + sd3_name_map[block_prefix + "attn.to_q.bias"] = dst_prefix + "x_block.attn.qkv.bias"; + sd3_name_map[block_prefix + "attn.to_k.weight"] = dst_prefix + "x_block.attn.qkv.weight.1"; + sd3_name_map[block_prefix + "attn.to_k.bias"] = dst_prefix + "x_block.attn.qkv.bias.1"; + sd3_name_map[block_prefix + "attn.to_v.weight"] = dst_prefix + "x_block.attn.qkv.weight.2"; + sd3_name_map[block_prefix + "attn.to_v.bias"] = dst_prefix + "x_block.attn.qkv.bias.2"; + + sd3_name_map[block_prefix + "attn.add_q_proj.weight"] = dst_prefix + "context_block.attn.qkv.weight"; + sd3_name_map[block_prefix + "attn.add_q_proj.bias"] = dst_prefix + "context_block.attn.qkv.bias"; + sd3_name_map[block_prefix + "attn.add_k_proj.weight"] = dst_prefix + "context_block.attn.qkv.weight.1"; + sd3_name_map[block_prefix + "attn.add_k_proj.bias"] = dst_prefix + "context_block.attn.qkv.bias.1"; + sd3_name_map[block_prefix + "attn.add_v_proj.weight"] = dst_prefix + "context_block.attn.qkv.weight.2"; + sd3_name_map[block_prefix + "attn.add_v_proj.bias"] = dst_prefix + "context_block.attn.qkv.bias.2"; + + // attn2 + sd3_name_map[block_prefix + "attn2.to_q.weight"] = dst_prefix + "x_block.attn2.qkv.weight"; + sd3_name_map[block_prefix + "attn2.to_q.bias"] = dst_prefix + "x_block.attn2.qkv.bias"; + sd3_name_map[block_prefix + "attn2.to_k.weight"] = dst_prefix + "x_block.attn2.qkv.weight.1"; + sd3_name_map[block_prefix + "attn2.to_k.bias"] = dst_prefix + "x_block.attn2.qkv.bias.1"; + sd3_name_map[block_prefix + "attn2.to_v.weight"] = dst_prefix + "x_block.attn2.qkv.weight.2"; + sd3_name_map[block_prefix + "attn2.to_v.bias"] = dst_prefix + "x_block.attn2.qkv.bias.2"; + + sd3_name_map[block_prefix + "attn2.add_q_proj.weight"] = dst_prefix + "context_block.attn2.qkv.weight"; + sd3_name_map[block_prefix + "attn2.add_q_proj.bias"] = dst_prefix + "context_block.attn2.qkv.bias"; + sd3_name_map[block_prefix + "attn2.add_k_proj.weight"] = dst_prefix + "context_block.attn2.qkv.weight.1"; + sd3_name_map[block_prefix + "attn2.add_k_proj.bias"] = dst_prefix + "context_block.attn2.qkv.bias.1"; + sd3_name_map[block_prefix + "attn2.add_v_proj.weight"] = dst_prefix + "context_block.attn2.qkv.weight.2"; + sd3_name_map[block_prefix + "attn2.add_v_proj.bias"] = dst_prefix + "context_block.attn2.qkv.bias.2"; + + // norm + sd3_name_map[block_prefix + "attn.norm_q.weight"] = dst_prefix + "x_block.attn.ln_q.weight"; + sd3_name_map[block_prefix + "attn.norm_k.weight"] = dst_prefix + "x_block.attn.ln_k.weight"; + sd3_name_map[block_prefix + "attn.norm_added_q.weight"] = dst_prefix + "context_block.attn.ln_q.weight"; + sd3_name_map[block_prefix + "attn.norm_added_k.weight"] = dst_prefix + "context_block.attn.ln_k.weight"; + + // norm2 + sd3_name_map[block_prefix + "attn2.norm_q.weight"] = dst_prefix + "x_block.attn2.ln_q.weight"; + sd3_name_map[block_prefix + "attn2.norm_k.weight"] = dst_prefix + "x_block.attn2.ln_k.weight"; + + // ff + sd3_name_map[block_prefix + "ff.net.0.proj.weight"] = dst_prefix + "x_block.mlp.fc1.weight"; + sd3_name_map[block_prefix + "ff.net.0.proj.bias"] = dst_prefix + "x_block.mlp.fc1.bias"; + sd3_name_map[block_prefix + "ff.net.2.weight"] = dst_prefix + "x_block.mlp.fc2.weight"; + sd3_name_map[block_prefix + "ff.net.2.bias"] = dst_prefix + "x_block.mlp.fc2.bias"; + + sd3_name_map[block_prefix + "ff_context.net.0.proj.weight"] = dst_prefix + "context_block.mlp.fc1.weight"; + sd3_name_map[block_prefix + "ff_context.net.0.proj.bias"] = dst_prefix + "context_block.mlp.fc1.bias"; + sd3_name_map[block_prefix + "ff_context.net.2.weight"] = dst_prefix + "context_block.mlp.fc2.weight"; + sd3_name_map[block_prefix + "ff_context.net.2.bias"] = dst_prefix + "context_block.mlp.fc2.bias"; + + // output projections + sd3_name_map[block_prefix + "attn.to_out.0.weight"] = dst_prefix + "x_block.attn.proj.weight"; + sd3_name_map[block_prefix + "attn.to_out.0.bias"] = dst_prefix + "x_block.attn.proj.bias"; + sd3_name_map[block_prefix + "attn.to_add_out.weight"] = dst_prefix + "context_block.attn.proj.weight"; + sd3_name_map[block_prefix + "attn.to_add_out.bias"] = dst_prefix + "context_block.attn.proj.bias"; + + // output projections 2 + sd3_name_map[block_prefix + "attn2.to_out.0.weight"] = dst_prefix + "x_block.attn2.proj.weight"; + sd3_name_map[block_prefix + "attn2.to_out.0.bias"] = dst_prefix + "x_block.attn2.proj.bias"; + sd3_name_map[block_prefix + "attn2.to_add_out.weight"] = dst_prefix + "context_block.attn2.proj.weight"; + sd3_name_map[block_prefix + "attn2.to_add_out.bias"] = dst_prefix + "context_block.attn2.proj.bias"; + } + + // --- final layers --- + sd3_name_map["proj_out.weight"] = "final_layer.linear.weight"; + sd3_name_map["proj_out.bias"] = "final_layer.linear.bias"; + sd3_name_map["norm_out.linear.weight"] = "final_layer.adaLN_modulation.1.weight"; + sd3_name_map["norm_out.linear.bias"] = "final_layer.adaLN_modulation.1.bias"; + } + + replace_with_prefix_map(name, sd3_name_map); + + return name; +} + +std::string convert_diffusers_dit_to_original_flux(std::string name) { + int num_layers = 19; + int num_single_layers = 38; + static std::unordered_map flux_name_map; + + if (flux_name_map.empty()) { + // --- time_text_embed --- + flux_name_map["time_text_embed.timestep_embedder.linear_1.weight"] = "time_in.in_layer.weight"; + flux_name_map["time_text_embed.timestep_embedder.linear_1.bias"] = "time_in.in_layer.bias"; + flux_name_map["time_text_embed.timestep_embedder.linear_2.weight"] = "time_in.out_layer.weight"; + flux_name_map["time_text_embed.timestep_embedder.linear_2.bias"] = "time_in.out_layer.bias"; + + flux_name_map["time_text_embed.text_embedder.linear_1.weight"] = "vector_in.in_layer.weight"; + flux_name_map["time_text_embed.text_embedder.linear_1.bias"] = "vector_in.in_layer.bias"; + flux_name_map["time_text_embed.text_embedder.linear_2.weight"] = "vector_in.out_layer.weight"; + flux_name_map["time_text_embed.text_embedder.linear_2.bias"] = "vector_in.out_layer.bias"; + + // guidance + flux_name_map["time_text_embed.guidance_embedder.linear_1.weight"] = "guidance_in.in_layer.weight"; + flux_name_map["time_text_embed.guidance_embedder.linear_1.bias"] = "guidance_in.in_layer.bias"; + flux_name_map["time_text_embed.guidance_embedder.linear_2.weight"] = "guidance_in.out_layer.weight"; + flux_name_map["time_text_embed.guidance_embedder.linear_2.bias"] = "guidance_in.out_layer.bias"; + + // --- context_embedder / x_embedder --- + flux_name_map["context_embedder.weight"] = "txt_in.weight"; + flux_name_map["context_embedder.bias"] = "txt_in.bias"; + flux_name_map["x_embedder.weight"] = "img_in.weight"; + flux_name_map["x_embedder.bias"] = "img_in.bias"; + + // --- double transformer blocks --- + for (int i = 0; i < num_layers; ++i) { + std::string block_prefix = "transformer_blocks." + std::to_string(i) + "."; + std::string dst_prefix = "double_blocks." + std::to_string(i) + "."; + + flux_name_map[block_prefix + "norm1.linear.weight"] = dst_prefix + "img_mod.lin.weight"; + flux_name_map[block_prefix + "norm1.linear.bias"] = dst_prefix + "img_mod.lin.bias"; + flux_name_map[block_prefix + "norm1_context.linear.weight"] = dst_prefix + "txt_mod.lin.weight"; + flux_name_map[block_prefix + "norm1_context.linear.bias"] = dst_prefix + "txt_mod.lin.bias"; + + // attn + flux_name_map[block_prefix + "attn.to_q.weight"] = dst_prefix + "img_attn.qkv.weight"; + flux_name_map[block_prefix + "attn.to_q.bias"] = dst_prefix + "img_attn.qkv.bias"; + flux_name_map[block_prefix + "attn.to_k.weight"] = dst_prefix + "img_attn.qkv.weight.1"; + flux_name_map[block_prefix + "attn.to_k.bias"] = dst_prefix + "img_attn.qkv.bias.1"; + flux_name_map[block_prefix + "attn.to_v.weight"] = dst_prefix + "img_attn.qkv.weight.2"; + flux_name_map[block_prefix + "attn.to_v.bias"] = dst_prefix + "img_attn.qkv.bias.2"; + + flux_name_map[block_prefix + "attn.add_q_proj.weight"] = dst_prefix + "txt_attn.qkv.weight"; + flux_name_map[block_prefix + "attn.add_q_proj.bias"] = dst_prefix + "txt_attn.qkv.bias"; + flux_name_map[block_prefix + "attn.add_k_proj.weight"] = dst_prefix + "txt_attn.qkv.weight.1"; + flux_name_map[block_prefix + "attn.add_k_proj.bias"] = dst_prefix + "txt_attn.qkv.bias.1"; + flux_name_map[block_prefix + "attn.add_v_proj.weight"] = dst_prefix + "txt_attn.qkv.weight.2"; + flux_name_map[block_prefix + "attn.add_v_proj.bias"] = dst_prefix + "txt_attn.qkv.bias.2"; + + // norm + flux_name_map[block_prefix + "attn.norm_q.weight"] = dst_prefix + "img_attn.norm.query_norm.scale"; + flux_name_map[block_prefix + "attn.norm_k.weight"] = dst_prefix + "img_attn.norm.key_norm.scale"; + flux_name_map[block_prefix + "attn.norm_added_q.weight"] = dst_prefix + "txt_attn.norm.query_norm.scale"; + flux_name_map[block_prefix + "attn.norm_added_k.weight"] = dst_prefix + "txt_attn.norm.key_norm.scale"; + + // ff + flux_name_map[block_prefix + "ff.net.0.proj.weight"] = dst_prefix + "img_mlp.0.weight"; + flux_name_map[block_prefix + "ff.net.0.proj.bias"] = dst_prefix + "img_mlp.0.bias"; + flux_name_map[block_prefix + "ff.net.2.weight"] = dst_prefix + "img_mlp.2.weight"; + flux_name_map[block_prefix + "ff.net.2.bias"] = dst_prefix + "img_mlp.2.bias"; + + flux_name_map[block_prefix + "ff_context.net.0.proj.weight"] = dst_prefix + "txt_mlp.0.weight"; + flux_name_map[block_prefix + "ff_context.net.0.proj.bias"] = dst_prefix + "txt_mlp.0.bias"; + flux_name_map[block_prefix + "ff_context.net.2.weight"] = dst_prefix + "txt_mlp.2.weight"; + flux_name_map[block_prefix + "ff_context.net.2.bias"] = dst_prefix + "txt_mlp.2.bias"; + + // output projections + flux_name_map[block_prefix + "attn.to_out.0.weight"] = dst_prefix + "img_attn.proj.weight"; + flux_name_map[block_prefix + "attn.to_out.0.bias"] = dst_prefix + "img_attn.proj.bias"; + flux_name_map[block_prefix + "attn.to_add_out.weight"] = dst_prefix + "txt_attn.proj.weight"; + flux_name_map[block_prefix + "attn.to_add_out.bias"] = dst_prefix + "txt_attn.proj.bias"; + } + + // --- single transformer blocks --- + for (int i = 0; i < num_single_layers; ++i) { + std::string block_prefix = "single_transformer_blocks." + std::to_string(i) + "."; + std::string dst_prefix = "single_blocks." + std::to_string(i) + "."; + + flux_name_map[block_prefix + "norm.linear.weight"] = dst_prefix + "modulation.lin.weight"; + flux_name_map[block_prefix + "norm.linear.bias"] = dst_prefix + "modulation.lin.bias"; + + flux_name_map[block_prefix + "attn.to_q.weight"] = dst_prefix + "linear1.weight"; + flux_name_map[block_prefix + "attn.to_q.bias"] = dst_prefix + "linear1.bias"; + flux_name_map[block_prefix + "attn.to_k.weight"] = dst_prefix + "linear1.weight.1"; + flux_name_map[block_prefix + "attn.to_k.bias"] = dst_prefix + "linear1.bias.1"; + flux_name_map[block_prefix + "attn.to_v.weight"] = dst_prefix + "linear1.weight.2"; + flux_name_map[block_prefix + "attn.to_v.bias"] = dst_prefix + "linear1.bias.2"; + flux_name_map[block_prefix + "proj_mlp.weight"] = dst_prefix + "linear1.weight.3"; + flux_name_map[block_prefix + "proj_mlp.bias"] = dst_prefix + "linear1.bias.3"; + + flux_name_map[block_prefix + "attn.norm_q.weight"] = dst_prefix + "norm.query_norm.scale"; + flux_name_map[block_prefix + "attn.norm_k.weight"] = dst_prefix + "norm.key_norm.scale"; + flux_name_map[block_prefix + "proj_out.weight"] = dst_prefix + "linear2.weight"; + flux_name_map[block_prefix + "proj_out.bias"] = dst_prefix + "linear2.bias"; + } + + // --- final layers --- + flux_name_map["proj_out.weight"] = "final_layer.linear.weight"; + flux_name_map["proj_out.bias"] = "final_layer.linear.bias"; + flux_name_map["norm_out.linear.weight"] = "final_layer.adaLN_modulation.1.weight"; + flux_name_map["norm_out.linear.bias"] = "final_layer.adaLN_modulation.1.bias"; + } + + replace_with_prefix_map(name, flux_name_map); + + return name; +} + +std::string convert_diffusion_model_name(std::string name, std::string prefix, SDVersion version) { + if (sd_version_is_sd1(version) || sd_version_is_sd2(version)) { + name = convert_diffusers_unet_to_original_sd1(name); + } else if (sd_version_is_sdxl(version)) { + name = convert_diffusers_unet_to_original_sdxl(name); + } else if (sd_version_is_sd3(version)) { + name = convert_diffusers_dit_to_original_sd3(name); + } else if (sd_version_is_flux(version)) { + name = convert_diffusers_dit_to_original_flux(name); + } + return name; +} + +std::string convert_diffusers_vae_to_original_sd1(std::string name) { + static const std::vector> vae_conversion_map_base = { + {"nin_shortcut", "conv_shortcut"}, + {"norm_out", "conv_norm_out"}, + {"mid.attn_1.", "mid_block.attentions.0."}, + }; + + static std::vector> vae_conversion_map_layer; + if (vae_conversion_map_layer.empty()) { + for (int i = 0; i < 4; ++i) { + // --- encoder down blocks --- + for (int j = 0; j < 2; ++j) { + std::string hf_down_prefix = "encoder.down_blocks." + std::to_string(i) + ".resnets." + std::to_string(j) + "."; + std::string sd_down_prefix = "encoder.down." + std::to_string(i) + ".block." + std::to_string(j) + "."; + vae_conversion_map_layer.emplace_back(sd_down_prefix, hf_down_prefix); + } + + if (i < 3) { + std::string hf_downsample_prefix = "down_blocks." + std::to_string(i) + ".downsamplers.0."; + std::string sd_downsample_prefix = "down." + std::to_string(i) + ".downsample."; + vae_conversion_map_layer.emplace_back(sd_downsample_prefix, hf_downsample_prefix); + + std::string hf_upsample_prefix = "up_blocks." + std::to_string(i) + ".upsamplers.0."; + std::string sd_upsample_prefix = "up." + std::to_string(3 - i) + ".upsample."; + vae_conversion_map_layer.emplace_back(sd_upsample_prefix, hf_upsample_prefix); + } + + // --- decoder up blocks (reverse) --- + for (int j = 0; j < 3; ++j) { + std::string hf_up_prefix = "decoder.up_blocks." + std::to_string(i) + ".resnets." + std::to_string(j) + "."; + std::string sd_up_prefix = "decoder.up." + std::to_string(3 - i) + ".block." + std::to_string(j) + "."; + vae_conversion_map_layer.emplace_back(sd_up_prefix, hf_up_prefix); + } + } + + // --- mid block (encoder + decoder) --- + for (int i = 0; i < 2; ++i) { + std::string hf_mid_res_prefix = "mid_block.resnets." + std::to_string(i) + "."; + std::string sd_mid_res_prefix = "mid.block_" + std::to_string(i + 1) + "."; + vae_conversion_map_layer.emplace_back(sd_mid_res_prefix, hf_mid_res_prefix); + } + } + + static const std::vector> vae_conversion_map_attn = { + {"norm.", "group_norm."}, + {"q.", "query."}, + {"k.", "key."}, + {"v.", "value."}, + {"proj_out.", "proj_attn."}, + }; + + static const std::vector> vae_extra_conversion_map = { + {"to_q", "q"}, + {"to_k", "k"}, + {"to_v", "v"}, + {"to_out.0", "proj_out"}, + }; + + std::string result = name; + + for (const auto& p : vae_conversion_map_base) { + size_t pos = result.find(p.second); + if (pos != std::string::npos) { + result.replace(pos, p.second.size(), p.first); + } + } + + for (const auto& p : vae_conversion_map_layer) { + size_t pos = result.find(p.second); + if (pos != std::string::npos) { + result.replace(pos, p.second.size(), p.first); + } + } + + if (name.find("attentions") != std::string::npos) { + for (const auto& p : vae_conversion_map_attn) { + size_t pos = result.find(p.second); + if (pos != std::string::npos) { + result.replace(pos, p.second.size(), p.first); + } + } + } + + if (result.find("mid.attn_1.") != std::string::npos) { + for (const auto& p : vae_extra_conversion_map) { + size_t pos = result.find(p.first); + if (pos != std::string::npos) { + result.replace(pos, p.first.size(), p.second); + } + } + } + + return result; +} + +std::string convert_first_stage_model_name(std::string name, std::string prefix) { + name = convert_diffusers_vae_to_original_sd1(name); + return name; +} + +std::string convert_pmid_name(const std::string& name) { + static std::unordered_map pmid_name_map = { + {"pmid.vision_model.visual_projection.weight", "pmid.visual_projection.weight"}, + }; + if (pmid_name_map.find(name) != pmid_name_map.end()) { + return pmid_name_map[name]; + } + return name; +} + +std::string convert_pmid_v2_name(const std::string& name) { + static std::unordered_map pmid_v2_name_map = { + {"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc1.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.3.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc2.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc1.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.3.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc2.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc1.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.3.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc2.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc1.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.3.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc2.weight"}, + {"pmid.qformer_perceiver.token_proj.0.bias", + "pmid.qformer_perceiver.token_proj.fc1.bias"}, + {"pmid.qformer_perceiver.token_proj.2.bias", + "pmid.qformer_perceiver.token_proj.fc2.bias"}, + {"pmid.qformer_perceiver.token_proj.0.weight", + "pmid.qformer_perceiver.token_proj.fc1.weight"}, + {"pmid.qformer_perceiver.token_proj.2.weight", + "pmid.qformer_perceiver.token_proj.fc2.weight"}, + }; + if (pmid_v2_name_map.find(name) != pmid_v2_name_map.end()) { + return pmid_v2_name_map[name]; + } + return name; +} + +std::string convert_sep_to_dot(std::string name) { + const std::vector protected_tokens = { + "self_attn", + "out_proj", + "q_proj", + "k_proj", + "v_proj", + "to_k", + "to_q", + "to_v", + "to_out", + "text_model", + "down_blocks", + "mid_block", + "up_block", + "proj_in", + "proj_out", + "transformer_blocks", + "single_transformer_blocks", + "diffusion_model", + "cond_stage_model", + "first_stage_model", + "conv_in", + "conv_out", + "lora_down", + "lora_up", + "diff_b", + "hada_w1_a", + "hada_w1_b", + "hada_w2_a", + "hada_w2_b", + "hada_t1", + "hada_t2", + ".lokr_w1", + ".lokr_w1_a", + ".lokr_w1_b", + ".lokr_w2", + ".lokr_w2_a", + ".lokr_w2_b", + "time_emb_proj", + "conv_shortcut", + "time_embedding", + "conv_norm_out", + "double_blocks", + "txt_attn", + "img_attn", + "input_blocks", + "output_blocks", + "middle_block", + "skip_connection", + "emb_layers", + "in_layers", + "out_layers", + "add_q_proj", + "add_k_proj", + "add_v_proj", + "add_out_proj", + "ff_context", + "norm_added_q", + "norm_added_v", + "to_add_out"}; + + // record the positions of underscores that should NOT be replaced + std::unordered_set protected_positions; + + for (const auto& token : protected_tokens) { + size_t start = 0; + while ((start = name.find(token, start)) != std::string::npos) { + size_t local_pos = token.find('_'); + while (local_pos != std::string::npos) { + protected_positions.insert(start + local_pos); + local_pos = token.find('_', local_pos + 1); + } + start += token.size(); + } + } + + for (size_t i = 0; i < name.size(); ++i) { + if (name[i] == '_' && !protected_positions.count(i)) { + name[i] = '.'; + } + } + + return name; +} + +std::string convert_tensor_name(std::string name, SDVersion version) { + bool is_lora = false; + bool is_lycoris_underline = false; + std::vector lora_prefix_vec = { + "lora.lora.", + "lora.lora_", + "lora.lycoris_", + "lora.lycoris.", + "lora.", + }; + for (const auto& prefix : lora_prefix_vec) { + if (starts_with(name, prefix)) { + is_lora = true; + name = name.substr(prefix.size()); + if (contains(prefix, "lycoris_")) { + is_lycoris_underline = true; + } + break; + } + } + // preprocess lora tensor name + if (is_lora) { + std::map lora_suffix_map = { + {".lora_down.weight", ".weight.lora_down"}, + {".lora_up.weight", ".weight.lora_up"}, + {".lora.down.weight", ".weight.lora_down"}, + {".lora.up.weight", ".weight.lora_up"}, + {"_lora.down.weight", ".weight.lora_down"}, + {"_lora.up.weight", ".weight.lora_up"}, + {".lora_A.weight", ".weight.lora_down"}, + {".lora_B.weight", ".weight.lora_up"}, + {".lora_A.default.weight", ".weight.lora_down"}, + {".lora_B.default.weight", ".weight.lora_up"}, + {".lora_linear", ".weight.alpha"}, + {".alpha", ".weight.alpha"}, + {".scale", ".weight.scale"}, + {".diff", ".weight.diff"}, + {".diff_b", ".bias.diff"}, + {".hada_w1_a", ".weight.hada_w1_a"}, + {".hada_w1_b", ".weight.hada_w1_b"}, + {".hada_w2_a", ".weight.hada_w2_a"}, + {".hada_w2_b", ".weight.hada_w2_b"}, + {".hada_t1", ".weight.hada_t1"}, + {".hada_t2", ".weight.hada_t2"}, + {".lokr_w1", ".weight.lokr_w1"}, + {".lokr_w1_a", ".weight.lokr_w1_a"}, + {".lokr_w1_b", ".weight.lokr_w1_b"}, + {".lokr_w2", ".weight.lokr_w2"}, + {".lokr_w2_a", ".weight.lokr_w2_a"}, + {".lokr_w2_b", ".weight.lokr_w2_b"}, + }; + + for (const auto& [old_suffix, new_suffix] : lora_suffix_map) { + if (ends_with(name, old_suffix)) { + name.replace(name.size() - old_suffix.size(), old_suffix.size(), new_suffix); + break; + } + } + + size_t pos = name.find(".processor"); + if (pos != std::string::npos) { + name.replace(pos, strlen(".processor"), ""); + } + + std::vector dit_prefix_vec = { + "transformer_blocks", + "single_transformer_blocks", + }; + for (const auto& prefix : dit_prefix_vec) { + if (starts_with(name, prefix)) { + name = "transformer." + name; + break; + } + } + + if (sd_version_is_unet(version) || is_lycoris_underline) { + name = convert_sep_to_dot(name); + } + } + + std::vector> prefix_map = { + {"diffusion_model.", "model.diffusion_model."}, + {"unet.", "model.diffusion_model."}, + {"transformer.", "model.diffusion_model."}, // dit + {"vae.", "first_stage_model."}, + {"text_encoder.", "cond_stage_model.transformer."}, + {"te.", "cond_stage_model.transformer."}, + {"text_encoder.2.", "cond_stage_model.1.transformer."}, + {"conditioner.embedders.0.open_clip.", "cond_stage_model."}, + // https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0 + {"conditioner.embedders.0.", "cond_stage_model."}, + {"conditioner.embedders.1.", "cond_stage_model.1."}, + // {"te2.text_model.encoder.layers.", "cond_stage_model.1.model.transformer.resblocks."}, + {"te2.", "cond_stage_model.1.transformer."}, + {"te1.", "cond_stage_model.transformer."}, + }; + + replace_with_prefix_map(name, prefix_map); + + // diffusion model + { + std::vector diffuison_model_prefix_vec = { + "model.diffusion_model.", + }; + for (const auto& prefix : diffuison_model_prefix_vec) { + if (starts_with(name, prefix)) { + name = convert_diffusion_model_name(name.substr(prefix.size()), prefix, version); + name = prefix + name; + break; + } + } + } + + // cond_stage_model + { + std::vector cond_stage_model_prefix_vec = { + "cond_stage_model.1.", + "cond_stage_model.", + "conditioner.embedders.", + "text_encoders.", + }; + for (const auto& prefix : cond_stage_model_prefix_vec) { + if (starts_with(name, prefix)) { + name = convert_cond_stage_model_name(name.substr(prefix.size()), prefix); + name = prefix + name; + break; + } + } + } + + // first_stage_model + { + std::vector first_stage_model_prefix_vec = { + "first_stage_model.", + "vae.", + }; + for (const auto& prefix : first_stage_model_prefix_vec) { + if (starts_with(name, prefix)) { + name = convert_first_stage_model_name(name.substr(prefix.size()), prefix); + name = prefix + name; + break; + } + } + } + + // pmid + { + if (starts_with(name, "pmid.")) { + name = convert_pmid_name(name); + } + if (starts_with(name, "pmid.qformer_perceiver")) { + name = convert_pmid_v2_name(name); + } + } + + // controlnet + { + if (starts_with(name, "control_model.")) { // for controlnet pth models + size_t pos = name.find('.'); + if (pos != std::string::npos) { + name = name.substr(pos + 1); + } + } + } + + if (is_lora) { + name = "lora." + name; + } + + return name; +} diff --git a/name_conversion.h b/name_conversion.h new file mode 100644 index 00000000..eb3d1a9b --- /dev/null +++ b/name_conversion.h @@ -0,0 +1,10 @@ +#ifndef __NAME_CONVERSTION_H__ +#define __NAME_CONVERSTION_H__ + +#include + +#include "model.h" + +std::string convert_tensor_name(std::string name, SDVersion version); + +#endif // __NAME_CONVERSTION_H__ \ No newline at end of file diff --git a/ordered_map.hpp b/ordered_map.hpp new file mode 100644 index 00000000..3fbdca5d --- /dev/null +++ b/ordered_map.hpp @@ -0,0 +1,177 @@ +#ifndef __ORDERED_MAP_HPP__ +#define __ORDERED_MAP_HPP__ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +template +class OrderedMap { +public: + using key_type = Key; + using mapped_type = T; + using value_type = std::pair; + using list_type = std::list; + using size_type = typename list_type::size_type; + using difference_type = typename list_type::difference_type; + using iterator = typename list_type::iterator; + using const_iterator = typename list_type::const_iterator; + +private: + list_type data_; + std::unordered_map index_; + +public: + // --- constructors --- + OrderedMap() = default; + + OrderedMap(std::initializer_list init) { + for (const auto& kv : init) + insert(kv); + } + + OrderedMap(const OrderedMap&) = default; + OrderedMap(OrderedMap&&) noexcept = default; + OrderedMap& operator=(const OrderedMap&) = default; + OrderedMap& operator=(OrderedMap&&) noexcept = default; + + // --- element access --- + T& at(const Key& key) { + auto it = index_.find(key); + if (it == index_.end()) + throw std::out_of_range("OrderedMap::at: key not found"); + return it->second->second; + } + + const T& at(const Key& key) const { + auto it = index_.find(key); + if (it == index_.end()) + throw std::out_of_range("OrderedMap::at: key not found"); + return it->second->second; + } + + T& operator[](const Key& key) { + auto it = index_.find(key); + if (it == index_.end()) { + data_.emplace_back(key, T{}); + auto iter = std::prev(data_.end()); + index_[key] = iter; + return iter->second; + } + return it->second->second; + } + + // --- iterators --- + iterator begin() noexcept { return data_.begin(); } + const_iterator begin() const noexcept { return data_.begin(); } + const_iterator cbegin() const noexcept { return data_.cbegin(); } + + iterator end() noexcept { return data_.end(); } + const_iterator end() const noexcept { return data_.end(); } + const_iterator cend() const noexcept { return data_.cend(); } + + // --- capacity --- + bool empty() const noexcept { return data_.empty(); } + size_type size() const noexcept { return data_.size(); } + + // --- modifiers --- + void clear() noexcept { + data_.clear(); + index_.clear(); + } + + std::pair insert(const value_type& value) { + auto it = index_.find(value.first); + if (it != index_.end()) { + return {it->second, false}; + } + data_.push_back(value); + auto iter = std::prev(data_.end()); + index_[value.first] = iter; + return {iter, true}; + } + + std::pair insert(value_type&& value) { + auto it = index_.find(value.first); + if (it != index_.end()) { + return {it->second, false}; + } + data_.push_back(std::move(value)); + auto iter = std::prev(data_.end()); + index_[iter->first] = iter; + return {iter, true}; + } + + void erase(const Key& key) { + auto it = index_.find(key); + if (it != index_.end()) { + data_.erase(it->second); + index_.erase(it); + } + } + + iterator erase(iterator pos) { + index_.erase(pos->first); + return data_.erase(pos); + } + + // --- lookup --- + size_type count(const Key& key) const { + return index_.count(key); + } + + iterator find(const Key& key) { + auto it = index_.find(key); + if (it == index_.end()) + return data_.end(); + return it->second; + } + + const_iterator find(const Key& key) const { + auto it = index_.find(key); + if (it == index_.end()) + return data_.end(); + return it->second; + } + + bool contains(const Key& key) const { + return index_.find(key) != index_.end(); + } + + // --- comparison --- + bool operator==(const OrderedMap& other) const { + return data_ == other.data_; + } + + bool operator!=(const OrderedMap& other) const { + return !(*this == other); + } + + template + std::pair emplace(Args&&... args) { + value_type value(std::forward(args)...); + auto it = index_.find(value.first); + if (it != index_.end()) { + return {it->second, false}; + } + data_.push_back(std::move(value)); + auto iter = std::prev(data_.end()); + index_[iter->first] = iter; + return {iter, true}; + } + + void swap(OrderedMap& other) noexcept { + data_.swap(other.data_); + index_.swap(other.index_); + } +}; + +#endif // __ORDERED_MAP_HPP__ \ No newline at end of file diff --git a/pmid.hpp b/pmid.hpp index 51e8fb76..70d8059c 100644 --- a/pmid.hpp +++ b/pmid.hpp @@ -578,7 +578,7 @@ struct PhotoMakerIDEmbed : public GGMLRunner { const std::string& file_path = "", const std::string& prefix = "") : file_path(file_path), GGMLRunner(backend, offload_params_to_cpu), model_loader(ml) { - if (!model_loader->init_from_file(file_path, prefix)) { + if (!model_loader->init_from_file_and_convert_name(file_path, prefix)) { load_failed = true; } } diff --git a/qwen_image.hpp b/qwen_image.hpp index ca3c84ac..87d2fb9b 100644 --- a/qwen_image.hpp +++ b/qwen_image.hpp @@ -644,7 +644,7 @@ namespace Qwen { ggml_type model_data_type = GGML_TYPE_Q8_0; ModelLoader model_loader; - if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) { + if (!model_loader.init_from_file_and_convert_name(file_path, "model.diffusion_model.")) { LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); return; } diff --git a/qwenvl.hpp b/qwenvl.hpp index 26d18623..0a914f6c 100644 --- a/qwenvl.hpp +++ b/qwenvl.hpp @@ -1342,7 +1342,7 @@ namespace Qwen { ggml_type model_data_type = GGML_TYPE_F16; ModelLoader model_loader; - if (!model_loader.init_from_file(file_path, "qwen2vl.")) { + if (!model_loader.init_from_file_and_convert_name(file_path, "qwen2vl.")) { LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); return; } diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 90f005fe..4cea83a1 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -278,6 +278,8 @@ class StableDiffusionGGML { } } + model_loader.convert_tensors_name(); + version = model_loader.get_sd_version(); if (version == VERSION_COUNT) { LOG_ERROR("get sd version from file failed: '%s'", SAFE_STR(sd_ctx_params->model_path)); @@ -569,13 +571,13 @@ class StableDiffusionGGML { version); } if (strlen(SAFE_STR(sd_ctx_params->photo_maker_path)) > 0) { - pmid_lora = std::make_shared(backend, sd_ctx_params->photo_maker_path, ""); + pmid_lora = std::make_shared(backend, sd_ctx_params->photo_maker_path, "", version); if (!pmid_lora->load_from_file(true, n_threads)) { LOG_WARN("load photomaker lora tensors from %s failed", sd_ctx_params->photo_maker_path); return false; } LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", sd_ctx_params->photo_maker_path); - if (!model_loader.init_from_file(sd_ctx_params->photo_maker_path, "pmid.")) { + if (!model_loader.init_from_file_and_convert_name(sd_ctx_params->photo_maker_path, "pmid.")) { LOG_WARN("loading stacked ID embedding from '%s' failed", sd_ctx_params->photo_maker_path); } else { stacked_id = true; @@ -609,7 +611,7 @@ class StableDiffusionGGML { ignore_tensors.insert("first_stage_model."); } if (stacked_id) { - ignore_tensors.insert("lora."); + ignore_tensors.insert("pmid.unet."); } if (vae_decode_only) { @@ -925,7 +927,7 @@ class StableDiffusionGGML { LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str()); return; } - LoraModel lora(backend, file_path, is_high_noise ? "model.high_noise_" : ""); + LoraModel lora(backend, file_path, is_high_noise ? "model.high_noise_" : "", version); if (!lora.load_from_file(false, n_threads)) { LOG_WARN("load lora tensors from %s failed", file_path.c_str()); return; diff --git a/t5.hpp b/t5.hpp index 89a60665..4cc8e124 100644 --- a/t5.hpp +++ b/t5.hpp @@ -1004,7 +1004,7 @@ struct T5Embedder { ggml_type model_data_type = GGML_TYPE_F16; ModelLoader model_loader; - if (!model_loader.init_from_file(file_path)) { + if (!model_loader.init_from_file_and_convert_name(file_path)) { LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); return; } diff --git a/tae.hpp b/tae.hpp index 14cdb578..568e409a 100644 --- a/tae.hpp +++ b/tae.hpp @@ -222,7 +222,7 @@ struct TinyAutoEncoder : public GGMLRunner { } ModelLoader model_loader; - if (!model_loader.init_from_file(file_path)) { + if (!model_loader.init_from_file_and_convert_name(file_path)) { LOG_ERROR("init taesd model loader from file failed: '%s'", file_path.c_str()); return false; } diff --git a/upscaler.cpp b/upscaler.cpp index 74048a1a..62c0d29a 100644 --- a/upscaler.cpp +++ b/upscaler.cpp @@ -42,7 +42,7 @@ struct UpscalerGGML { backend = ggml_backend_sycl_init(0); #endif ModelLoader model_loader; - if (!model_loader.init_from_file(esrgan_path)) { + if (!model_loader.init_from_file_and_convert_name(esrgan_path)) { LOG_ERROR("init model loader from file failed: '%s'", esrgan_path.c_str()); } model_loader.set_wtype_override(model_data_type); diff --git a/wan.hpp b/wan.hpp index 9720cc63..91a2e920 100644 --- a/wan.hpp +++ b/wan.hpp @@ -1271,7 +1271,7 @@ namespace WAN { vae->get_param_tensors(tensors, "first_stage_model"); ModelLoader model_loader; - if (!model_loader.init_from_file(file_path, "vae.")) { + if (!model_loader.init_from_file_and_convert_name(file_path, "vae.")) { LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); return; } @@ -2255,7 +2255,7 @@ namespace WAN { LOG_INFO("loading from '%s'", file_path.c_str()); ModelLoader model_loader; - if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) { + if (!model_loader.init_from_file_and_convert_name(file_path, "model.diffusion_model.")) { LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); return; }