@@ -443,6 +443,10 @@ class StableDiffusionGGML {
443443 diffusion_model->alloc_params_buffer ();
444444 diffusion_model->get_param_tensors (tensors);
445445
446+ if (sd_version_is_unet_edit (version)) {
447+ vae_decode_only = false ;
448+ }
449+
446450 if (high_noise_diffusion_model) {
447451 high_noise_diffusion_model->alloc_params_buffer ();
448452 high_noise_diffusion_model->get_param_tensors (tensors);
@@ -748,15 +752,15 @@ class StableDiffusionGGML {
748752 denoiser->scheduler ->version = version;
749753 break ;
750754 case SGM_UNIFORM:
751- LOG_INFO (" Running with SGM Uniform schedule" );
752- denoiser->scheduler = std::make_shared<SGMUniformSchedule>();
753- denoiser->scheduler ->version = version;
754- break ;
755+ LOG_INFO (" Running with SGM Uniform schedule" );
756+ denoiser->scheduler = std::make_shared<SGMUniformSchedule>();
757+ denoiser->scheduler ->version = version;
758+ break ;
755759 case SIMPLE:
756- LOG_INFO (" Running with Simple schedule" );
757- denoiser->scheduler = std::make_shared<SimpleSchedule>();
758- denoiser->scheduler ->version = version;
759- break ;
760+ LOG_INFO (" Running with Simple schedule" );
761+ denoiser->scheduler = std::make_shared<SimpleSchedule>();
762+ denoiser->scheduler ->version = version;
763+ break ;
760764 case SMOOTHSTEP:
761765 LOG_INFO (" Running with SmoothStep scheduler" );
762766 denoiser->scheduler = std::make_shared<SmoothStepSchedule>();
@@ -1053,7 +1057,7 @@ class StableDiffusionGGML {
10531057 ggml_tensor* denoise_mask = NULL ,
10541058 ggml_tensor* vace_context = NULL ,
10551059 float vace_strength = 1 .f) {
1056- if (shifted_timestep > 0 && !sd_version_is_sdxl (version)) {
1060+ if (shifted_timestep > 0 && !sd_version_is_sdxl (version)) {
10571061 LOG_WARN (" timestep shifting is only supported for SDXL models!" );
10581062 shifted_timestep = 0 ;
10591063 }
@@ -1127,7 +1131,7 @@ class StableDiffusionGGML {
11271131 } else {
11281132 timesteps_vec.assign (1 , t);
11291133 }
1130-
1134+
11311135 timesteps_vec = process_timesteps (timesteps_vec, init_latent, denoise_mask);
11321136 auto timesteps = vector_to_ggml_tensor (work_ctx, timesteps_vec);
11331137 std::vector<float > guidance_vec (1 , guidance.distilled_guidance );
@@ -2387,19 +2391,35 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
23872391 init_latent = generate_init_latent (sd_ctx, work_ctx, width, height);
23882392 }
23892393
2390- if (sd_img_gen_params->ref_images_count > 0 ) {
2394+ sd_guidance_params_t guidance = sd_img_gen_params->sample_params .guidance ;
2395+ std::vector<sd_image_t *> ref_images;
2396+ for (int i = 0 ; i < sd_img_gen_params->ref_images_count ; i++) {
2397+ ref_images.push_back (&sd_img_gen_params->ref_images [i]);
2398+ }
2399+
2400+ std::vector<uint8_t > empty_image_data;
2401+ sd_image_t empty_image = {(uint32_t )width, (uint32_t )height, 3 , nullptr };
2402+ if (ref_images.empty () && sd_version_is_unet_edit (sd_ctx->sd ->version )) {
2403+ LOG_WARN (" This model needs at least one reference image; using an empty reference" );
2404+ empty_image_data.resize (width * height * 3 );
2405+ ref_images.push_back (&empty_image);
2406+ empty_image.data = empty_image_data.data ();
2407+ guidance.img_cfg = 0 .f ;
2408+ }
2409+
2410+ if (ref_images.size () > 0 ) {
23912411 LOG_INFO (" EDIT mode" );
23922412 }
23932413
23942414 std::vector<ggml_tensor*> ref_latents;
2395- for (int i = 0 ; i < sd_img_gen_params-> ref_images_count ; i++) {
2415+ for (int i = 0 ; i < ref_images. size () ; i++) {
23962416 ggml_tensor* img = ggml_new_tensor_4d (work_ctx,
23972417 GGML_TYPE_F32,
2398- sd_img_gen_params-> ref_images [i]. width ,
2399- sd_img_gen_params-> ref_images [i]. height ,
2418+ ref_images[i]-> width ,
2419+ ref_images[i]-> height ,
24002420 3 ,
24012421 1 );
2402- sd_image_to_tensor (sd_img_gen_params-> ref_images [i], img);
2422+ sd_image_to_tensor (* ref_images[i], img);
24032423
24042424 ggml_tensor* latent = NULL ;
24052425 if (sd_ctx->sd ->use_tiny_autoencoder ) {
@@ -2437,7 +2457,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
24372457 SAFE_STR (sd_img_gen_params->prompt ),
24382458 SAFE_STR (sd_img_gen_params->negative_prompt ),
24392459 sd_img_gen_params->clip_skip ,
2440- sd_img_gen_params-> sample_params . guidance ,
2460+ guidance,
24412461 sd_img_gen_params->sample_params .eta ,
24422462 sd_img_gen_params->sample_params .shifted_timestep ,
24432463 width,
0 commit comments