
Commit 347710f

feat: support applying LoRA at runtime (leejet#969)
1 parent 59ebdf0 commit 347710f

21 files changed: +896 −222 lines

clip.hpp

Lines changed: 1 addition & 1 deletion
@@ -936,7 +936,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
                          size_t max_token_idx = 0,
                          bool return_pooled   = false,
                          int clip_skip        = -1) {
-        struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
+        struct ggml_cgraph* gf = new_graph_custom(2048);

         input_ids = to_backend(input_ids);

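The same one-line change recurs in control.hpp, esrgan.hpp, and flux.hpp below: runners stop calling ggml_new_graph/ggml_new_graph_custom directly and route graph creation through a new GGMLRunner::new_graph_custom helper. A minimal sketch of what such a helper might look like, assuming it simply centralizes graph creation on the runner's compute context; the weight-adapter member and its method name are assumptions, not shown in this diff:

    // Sketch only: one place to build compute graphs, so a runtime weight
    // adapter can enlarge the node budget for the extra LoRA mul_mat/scale/add
    // nodes it splices into the graph.
    struct ggml_cgraph* new_graph_custom(size_t graph_size) {
        if (weight_adapter) {                                  // assumed member
            graph_size += weight_adapter->extra_graph_size();  // assumed hook
        }
        return ggml_new_graph_custom(compute_ctx, graph_size, /*grads=*/false);
    }
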
common.hpp

Lines changed: 8 additions & 18 deletions
@@ -182,31 +182,21 @@ class GEGLU : public UnaryBlock {
     int64_t dim_in;
     int64_t dim_out;

-    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {
-        enum ggml_type wtype      = get_type(prefix + "proj.weight", tensor_storage_map, GGML_TYPE_F32);
-        enum ggml_type bias_wtype = GGML_TYPE_F32;
-        params["proj.weight"]     = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2);
-        params["proj.bias"]       = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2);
-    }
-
 public:
     GEGLU(int64_t dim_in, int64_t dim_out)
-        : dim_in(dim_in), dim_out(dim_out) {}
+        : dim_in(dim_in), dim_out(dim_out) {
+        blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim_in, dim_out * 2));
+    }

     struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
         // x: [ne3, ne2, ne1, dim_in]
         // return: [ne3, ne2, ne1, dim_out]
-        struct ggml_tensor* w = params["proj.weight"];
-        struct ggml_tensor* b = params["proj.bias"];
-
-        auto x_w    = ggml_view_2d(ctx->ggml_ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], 0);                        // [dim_out, dim_in]
-        auto x_b    = ggml_view_1d(ctx->ggml_ctx, b, b->ne[0] / 2, 0);                                            // [dim_out,]
-        auto gate_w = ggml_view_2d(ctx->ggml_ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], w->nb[1] * w->ne[1] / 2);  // [dim_out, dim_in]
-        auto gate_b = ggml_view_1d(ctx->ggml_ctx, b, b->ne[0] / 2, b->nb[0] * b->ne[0] / 2);                      // [dim_out,]
+        auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);

-        auto x_in = x;
-        x         = ggml_ext_linear(ctx->ggml_ctx, x_in, x_w, x_b);        // [ne3, ne2, ne1, dim_out]
-        auto gate = ggml_ext_linear(ctx->ggml_ctx, x_in, gate_w, gate_b);  // [ne3, ne2, ne1, dim_out]
+        x          = proj->forward(ctx, x);  // [ne3, ne2, ne1, dim_out*2]
+        auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0);
+        x          = x_vec[0];   // [ne3, ne2, ne1, dim_out]
+        auto gate  = x_vec[1];   // [ne3, ne2, ne1, dim_out]

         gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);

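The point of this refactor for the commit: once the projection is an ordinary Linear block instead of raw params, its weight travels the same code path as every other linear layer, which is presumably where the runtime LoRA adapter hooks in. The new forward() relies on ggml_ext_chunk to split the fused [.., dim_out*2] projection; its definition is not in this excerpt, so the following is only a sketch of plausible semantics (a dim-0 split of a non-quantized tensor into zero-copy views; the dim-0-only restriction is an assumption):

    // Sketch: split x into `chunks` equal views along dim 0.
    std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_context* ctx,
                                                    struct ggml_tensor* x,
                                                    int chunks,
                                                    int dim) {
        GGML_ASSERT(dim == 0 && x->ne[0] % chunks == 0);  // sketch: dim 0 only
        const int64_t step = x->ne[0] / chunks;
        std::vector<struct ggml_tensor*> out;
        for (int i = 0; i < chunks; i++) {
            // byte offset assumes a non-quantized row layout
            out.push_back(ggml_view_4d(ctx, x, step, x->ne[1], x->ne[2], x->ne[3],
                                       x->nb[1], x->nb[2], x->nb[3],
                                       (size_t)i * step * ggml_type_size(x->type)));
        }
        return out;
    }
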
conditioner.hpp

Lines changed: 41 additions & 0 deletions
@@ -34,6 +34,7 @@ struct Conditioner {
     virtual void free_params_buffer()                                                   = 0;
     virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
     virtual size_t get_params_buffer_size()                                             = 0;
+    virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {}
     virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
                                                                                           int n_threads,
                                                                                           const ConditionerParams& conditioner_params) {
@@ -108,6 +109,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
         return buffer_size;
     }

+    void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
+        text_model->set_weight_adapter(adapter);
+        if (sd_version_is_sdxl(version)) {
+            text_model2->set_weight_adapter(adapter);
+        }
+    }
+
     bool load_embedding(std::string embd_name, std::string embd_path, std::vector<int32_t>& bpe_tokens) {
         // the order matters
         ModelLoader model_loader;
@@ -764,6 +772,18 @@ struct SD3CLIPEmbedder : public Conditioner {
         return buffer_size;
     }

+    void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
+        if (clip_l) {
+            clip_l->set_weight_adapter(adapter);
+        }
+        if (clip_g) {
+            clip_g->set_weight_adapter(adapter);
+        }
+        if (t5) {
+            t5->set_weight_adapter(adapter);
+        }
+    }
+
     std::vector<std::pair<std::vector<int>, std::vector<float>>> tokenize(std::string text,
                                                                           size_t max_length = 0,
                                                                           bool padding      = false) {
@@ -1160,6 +1180,15 @@ struct FluxCLIPEmbedder : public Conditioner {
         return buffer_size;
     }

+    void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
+        if (clip_l) {
+            clip_l->set_weight_adapter(adapter);
+        }
+        if (t5) {
+            t5->set_weight_adapter(adapter);
+        }
+    }
+
     std::vector<std::pair<std::vector<int>, std::vector<float>>> tokenize(std::string text,
                                                                           size_t max_length = 0,
                                                                           bool padding      = false) {
@@ -1400,6 +1429,12 @@ struct T5CLIPEmbedder : public Conditioner {
         return buffer_size;
     }

+    void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
+        if (t5) {
+            t5->set_weight_adapter(adapter);
+        }
+    }
+
     std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> tokenize(std::string text,
                                                                                   size_t max_length = 0,
                                                                                   bool padding      = false) {
@@ -1589,6 +1624,12 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
         return buffer_size;
     }

+    void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
+        if (qwenvl) {
+            qwenvl->set_weight_adapter(adapter);
+        }
+    }
+
     std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
                                                               size_t max_length           = 0,
                                                               size_t system_prompt_length = 0,

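Every embedder simply fans the adapter out to the text-model runners it owns. The WeightAdapter type itself is not part of this excerpt; below is a sketch of the shape such an interface would need under the at_runtime model described in docs/lora.md further down (the method name and signature are assumptions):

    // Hypothetical interface: at graph-build time the adapter can substitute
    // a weight tensor with a node that has the LoRA delta fused into the
    // graph (conceptually W' = W + scale * (B @ A)), leaving the stored
    // weights, including quantized ones, untouched.
    struct WeightAdapter {
        virtual ~WeightAdapter() = default;
        // Return a replacement node for `weight`, or `weight` unchanged
        // when no LoRA patch matches `tensor_name`.
        virtual struct ggml_tensor* patch_weight(struct ggml_context* ctx,
                                                 struct ggml_tensor* weight,
                                                 const std::string& tensor_name) = 0;
    };
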
control.hpp

Lines changed: 1 addition & 1 deletion
@@ -380,7 +380,7 @@ struct ControlNet : public GGMLRunner {
                                     struct ggml_tensor* timesteps,
                                     struct ggml_tensor* context,
                                     struct ggml_tensor* y = nullptr) {
-        struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, CONTROL_NET_GRAPH_SIZE, false);
+        struct ggml_cgraph* gf = new_graph_custom(CONTROL_NET_GRAPH_SIZE);

         x = to_backend(x);
         if (guided_hint_cached) {

diffusion_model.hpp

Lines changed: 23 additions & 2 deletions
@@ -35,8 +35,9 @@ struct DiffusionModel {
     virtual void free_compute_buffer()                                                  = 0;
     virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
     virtual size_t get_params_buffer_size()                                             = 0;
-    virtual int64_t get_adm_in_channels()                                               = 0;
-    virtual void set_flash_attn_enabled(bool enabled)                                   = 0;
+    virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {}
+    virtual int64_t get_adm_in_channels()                                               = 0;
+    virtual void set_flash_attn_enabled(bool enabled)                                   = 0;
 };

 struct UNetModel : public DiffusionModel {
@@ -73,6 +74,10 @@ struct UNetModel : public DiffusionModel {
         return unet.get_params_buffer_size();
     }

+    void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
+        unet.set_weight_adapter(adapter);
+    }
+
     int64_t get_adm_in_channels() override {
         return unet.unet.adm_in_channels;
     }
@@ -130,6 +135,10 @@ struct MMDiTModel : public DiffusionModel {
         return mmdit.get_params_buffer_size();
     }

+    void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
+        mmdit.set_weight_adapter(adapter);
+    }
+
     int64_t get_adm_in_channels() override {
         return 768 + 1280;
     }
@@ -188,6 +197,10 @@ struct FluxModel : public DiffusionModel {
         return flux.get_params_buffer_size();
     }

+    void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
+        flux.set_weight_adapter(adapter);
+    }
+
     int64_t get_adm_in_channels() override {
         return 768;
     }
@@ -251,6 +264,10 @@ struct WanModel : public DiffusionModel {
         return wan.get_params_buffer_size();
     }

+    void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
+        wan.set_weight_adapter(adapter);
+    }
+
     int64_t get_adm_in_channels() override {
         return 768;
     }
@@ -313,6 +330,10 @@ struct QwenImageModel : public DiffusionModel {
         return qwen_image.get_params_buffer_size();
     }

+    void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
+        qwen_image.set_weight_adapter(adapter);
+    }
+
     int64_t get_adm_in_channels() override {
         return 768;
     }

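The diffusion models mirror the conditioners: a no-op default on the base struct, plus a one-line delegation in each concrete model. The call site that fans a single adapter out to both halves is not in this excerpt; presumably something like the following happens in the stable-diffusion context object once at_runtime LoRAs are loaded (the function and member names here are hypothetical):

    // Hypothetical wiring: one adapter, every weight-owning runner.
    void set_lora_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {
        cond_stage_model->set_weight_adapter(adapter);  // Conditioner
        diffusion_model->set_weight_adapter(adapter);   // DiffusionModel
    }
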
docs/lora.md

Lines changed: 12 additions & 35 deletions
@@ -12,38 +12,15 @@ Here's a simple example:

 `../models/marblesh.safetensors` or `../models/marblesh.ckpt` will be applied to the model

-# Support matrix
-
-> ℹ️ CUDA `get_rows` support is defined here:
-> [ggml-org/ggml/src/ggml-cuda/getrows.cu#L156](https://github.com/ggml-org/ggml/blob/7dee1d6a1e7611f238d09be96738388da97c88ed/src/ggml-cuda/getrows.cu#L156)
-> Currently only the basic types + Q4/Q5/Q8 are implemented. K-quants are **not** supported.
-
-NOTE: The other backends may have different support.
-
-| Quant / Type | CUDA | Vulkan |
-|--------------|------|--------|
-| F32          | ✔️   | ✔️     |
-| F16          | ✔️   | ✔️     |
-| BF16         | ✔️   | ✔️     |
-| I32          | ✔️   |        |
-| Q4_0         | ✔️   | ✔️     |
-| Q4_1         | ✔️   | ✔️     |
-| Q5_0         | ✔️   | ✔️     |
-| Q5_1         | ✔️   | ✔️     |
-| Q8_0         | ✔️   | ✔️     |
-| Q2_K         |      |        |
-| Q3_K         |      |        |
-| Q4_K         |      |        |
-| Q5_K         |      |        |
-| Q6_K         |      |        |
-| Q8_K         |      |        |
-| IQ1_S        |      | ✔️     |
-| IQ1_M        |      | ✔️     |
-| IQ2_XXS      |      | ✔️     |
-| IQ2_XS       |      | ✔️     |
-| IQ2_S        |      | ✔️     |
-| IQ3_XXS      |      | ✔️     |
-| IQ3_S        |      | ✔️     |
-| IQ4_XS       |      | ✔️     |
-| IQ4_NL       |      | ✔️     |
-| MXFP4        |      | ✔️     |
+# LoRA Apply Mode
+
+There are two ways to apply LoRA: **immediately** and **at_runtime**. You can select the mode with the `--lora-apply-mode` parameter.
+
+By default, the mode is selected automatically:
+
+* If the model weights contain any quantized parameters, the **at_runtime** mode is used;
+* Otherwise, the **immediately** mode is used.
+
+The **immediately** mode may have precision and compatibility issues with quantized parameters, but it usually offers faster inference and, in some cases, lower memory usage.
+In contrast, the **at_runtime** mode provides better compatibility and higher precision, but inference may be slower and memory usage may be higher in some cases.
+

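A usage example to make the flag concrete. The `--lora-apply-mode` flag is added by this commit; the binary path, `--lora-model-dir`, and the `<lora:name:strength>` prompt syntax follow this repo's other docs, and the model paths are placeholders:

    ./build/bin/sd -m ../models/v1-5-pruned-emaonly.safetensors \
        --lora-model-dir ../models \
        -p "a lovely cat<lora:marblesh:1>" \
        --lora-apply-mode at_runtime
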
esrgan.hpp

Lines changed: 1 addition & 1 deletion
@@ -344,7 +344,7 @@ struct ESRGAN : public GGMLRunner {
         if (!rrdb_net)
             return nullptr;
         constexpr int kGraphNodes = 1 << 16;  // 65k
-        struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, kGraphNodes, /*grads*/ false);
+        struct ggml_cgraph* gf = new_graph_custom(kGraphNodes);
         x = to_backend(x);

         auto runner_ctx = get_context();

examples/cli/README.md

Lines changed: 6 additions & 0 deletions
@@ -99,6 +99,12 @@ Options:
   --sampling-method                  sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing,
                                      tcd] (default: euler for Flux/SD3/Wan, euler_a otherwise)
   --prediction                       prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow]
+  --lora-apply-mode                  the way to apply LoRA, one of [auto, immediately, at_runtime], default is auto. In auto mode, if the model weights
+                                     contain any quantized parameters, the at_runtime mode will be used; otherwise,
+                                     immediately will be used. The immediately mode may have precision and
+                                     compatibility issues with quantized parameters, but it usually offers faster inference
+                                     speed and, in some cases, lower memory usage. The at_runtime mode, on the other
+                                     hand, is exactly the opposite.
   --scheduler                        denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple], default:
                                      discrete
   --skip-layers                      layers to skip for SLG steps (default: [7,8,9])

examples/cli/main.cpp

Lines changed: 26 additions & 1 deletion
@@ -137,7 +137,8 @@ struct SDParams {
     int chroma_t5_mask_pad = 1;
     float flow_shift       = INFINITY;

-    prediction_t prediction = DEFAULT_PRED;
+    prediction_t prediction           = DEFAULT_PRED;
+    lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO;

     sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
     bool force_sdxl_vae_conv_scale       = false;
@@ -209,6 +210,7 @@ void print_params(SDParams params) {
     printf("    high_noise_sample_params: %s\n", SAFE_STR(high_noise_sample_params_str));
     printf("    moe_boundary:      %.3f\n", params.moe_boundary);
     printf("    prediction:        %s\n", sd_prediction_name(params.prediction));
+    printf("    lora_apply_mode:   %s\n", sd_lora_apply_mode_name(params.lora_apply_mode));
     printf("    flow_shift:        %.2f\n", params.flow_shift);
     printf("    strength(img2img): %.2f\n", params.strength);
     printf("    rng:               %s\n", sd_rng_type_name(params.rng_type));
@@ -926,6 +928,20 @@ void parse_args(int argc, const char** argv, SDParams& params) {
         return 1;
     };

+    auto on_lora_apply_mode_arg = [&](int argc, const char** argv, int index) {
+        if (++index >= argc) {
+            return -1;
+        }
+        const char* arg        = argv[index];
+        params.lora_apply_mode = str_to_lora_apply_mode(arg);
+        if (params.lora_apply_mode == LORA_APPLY_MODE_COUNT) {
+            fprintf(stderr, "error: invalid lora apply mode %s\n",
+                    arg);
+            return -1;
+        }
+        return 1;
+    };
+
     auto on_sample_method_arg = [&](int argc, const char** argv, int index) {
         if (++index >= argc) {
             return -1;
@@ -1123,6 +1139,14 @@ void parse_args(int argc, const char** argv, SDParams& params) {
          "--prediction",
          "prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow]",
          on_prediction_arg},
+        {"",
+         "--lora-apply-mode",
+         "the way to apply LoRA, one of [auto, immediately, at_runtime], default is auto. "
+         "In auto mode, if the model weights contain any quantized parameters, the at_runtime mode will be used; otherwise, immediately will be used. "
+         "The immediately mode may have precision and compatibility issues with quantized parameters, "
+         "but it usually offers faster inference speed and, in some cases, lower memory usage. "
+         "The at_runtime mode, on the other hand, is exactly the opposite.",
+         on_lora_apply_mode_arg},
         {"",
          "--scheduler",
          "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple], default: discrete",
@@ -1738,6 +1762,7 @@ int main(int argc, const char* argv[]) {
                                   params.wtype,
                                   params.rng_type,
                                   params.prediction,
+                                  params.lora_apply_mode,
                                   params.offload_params_to_cpu,
                                   params.clip_on_cpu,
                                   params.control_net_cpu,

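The diff references lora_apply_mode_t, LORA_APPLY_AUTO, LORA_APPLY_MODE_COUNT, and str_to_lora_apply_mode() without showing their definitions. A minimal sketch of what they likely look like; the two explicit enumerator names are assumptions, everything else is taken from the diff:

    #include <cstring>

    // Sketch only: enumerator order matches the [auto, immediately, at_runtime]
    // list in the help text; LORA_APPLY_MODE_COUNT doubles as the sentinel the
    // parser checks for invalid input.
    enum lora_apply_mode_t {
        LORA_APPLY_AUTO,
        LORA_APPLY_IMMEDIATELY,  // assumed name
        LORA_APPLY_AT_RUNTIME,   // assumed name
        LORA_APPLY_MODE_COUNT,
    };

    lora_apply_mode_t str_to_lora_apply_mode(const char* s) {
        static const char* names[LORA_APPLY_MODE_COUNT] = {"auto", "immediately", "at_runtime"};
        for (int i = 0; i < LORA_APPLY_MODE_COUNT; i++) {
            if (strcmp(s, names[i]) == 0) {
                return (lora_apply_mode_t)i;
            }
        }
        return LORA_APPLY_MODE_COUNT;  // unknown string -> invalid
    }
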
flux.hpp

Lines changed: 1 addition & 1 deletion
@@ -1243,7 +1243,7 @@ namespace Flux {
                          bool increase_ref_index      = false,
                          std::vector<int> skip_layers = {}) {
         GGML_ASSERT(x->ne[3] == 1);
-        struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
+        struct ggml_cgraph* gf = new_graph_custom(FLUX_GRAPH_SIZE);

         struct ggml_tensor* mod_index_arange = nullptr;
         struct ggml_tensor* dct              = nullptr;  // for chroma radiance
