@@ -360,6 +360,34 @@ def prepare_infinite_you(self, id_image, controlnet_image, infinityou_guidance,
360
360
return self .infinityou_processor .prepare_infinite_you (self .image_proj_model , id_image , controlnet_image , infinityou_guidance , height , width )
361
361
else :
362
362
return {}, controlnet_image
363
+
364
+
365
def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, flex_control_strength=0.5, flex_control_stop=0.5, tiled=False, tile_size=64, tile_stride=32):
    """Build the extra conditioning inputs used by the Flex branch of the DiT.

    The conditioning is only produced when ``self.dit.input_dim == 196``
    (presumably the input width of the Flex model variant -- confirm against
    the DiT definition). For any other model an empty dict is returned.

    Args:
        latents: Latent tensor indexed as 4-D ``(batch, channel, dim2, dim3)``.
            It provides the shape, dtype and device of the default (all-zero /
            all-one) conditions.
        flex_inpaint_image: Optional image to inpaint. It is passed through
            ``self.preprocess_image`` and ``self.encode_image``. Defaults to
            an all-zero tensor shaped like ``latents``.
        flex_inpaint_mask: Optional mask image. It must expose
            ``resize((latents.shape[3], latents.shape[2]))`` (presumably a PIL
            image, whose ``resize`` takes ``(width, height)`` -- confirm).
            Defaults to an all-one single-channel mask.
        flex_control_image: Optional control image, encoded like the inpaint
            image and scaled by ``flex_control_strength``. Defaults to an
            all-zero tensor shaped like ``latents``.
        flex_control_strength: Multiplier applied to the encoded control image.
        flex_control_stop: Fraction of the scheduler's timestep list that
            selects ``flex_control_stop_timestep``. The derived index is
            clamped to the valid range, so values below 0 select the first
            timestep and values above 1 select the last one.
        tiled: Forwarded to ``self.encode_image``.
        tile_size: Forwarded to ``self.encode_image``.
        tile_stride: Forwarded to ``self.encode_image``.

    Returns:
        dict: ``{}`` when the DiT is not the 196-input variant, otherwise a
        dict with the keys ``"flex_condition"``, ``"flex_uncondition"`` and
        ``"flex_control_stop_timestep"``.
    """
    if self.dit.input_dim != 196:
        return {}

    # Inpaint source: encoded user image, or zeros when nothing is supplied.
    if flex_inpaint_image is None:
        flex_inpaint_image = torch.zeros_like(latents)
    else:
        flex_inpaint_image = self.preprocess_image(flex_inpaint_image).to(device=self.device, dtype=self.torch_dtype)
        flex_inpaint_image = self.encode_image(flex_inpaint_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

    # Inpaint mask: a single channel at the latent resolution.
    if flex_inpaint_mask is None:
        flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :]
    else:
        flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2]))
        flex_inpaint_mask = self.preprocess_image(flex_inpaint_mask).to(device=self.device, dtype=self.torch_dtype)
        # Keep the first channel and shift/scale it; presumably this maps the
        # preprocessed [-1, 1] range onto [0, 1] -- confirm preprocess_image.
        flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2

    # Zero out the inpaint image wherever the mask is 1.
    flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask)

    # Control signal: encoded user image scaled by the strength, or zeros.
    if flex_control_image is None:
        flex_control_image = torch.zeros_like(latents)
    else:
        flex_control_image = self.preprocess_image(flex_control_image).to(device=self.device, dtype=self.torch_dtype)
        flex_control_image = self.encode_image(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength

    # The "uncondition" keeps the inpaint inputs but blanks the control signal.
    flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1)
    flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1)

    # Clamp the index: without this, a fraction above 1 raises IndexError and
    # a negative fraction silently wraps around to the end of the schedule.
    last_index = len(self.scheduler.timesteps) - 1
    stop_index = max(0, min(int(flex_control_stop * last_index), last_index))
    flex_control_stop_timestep = self.scheduler.timesteps[stop_index]

    return {
        "flex_condition": flex_condition,
        "flex_uncondition": flex_uncondition,
        "flex_control_stop_timestep": flex_control_stop_timestep,
    }
363
391
364
392
365
393
@torch .no_grad ()
@@ -398,6 +426,12 @@ def __call__(
398
426
# InfiniteYou
399
427
infinityou_id_image = None ,
400
428
infinityou_guidance = 1.0 ,
429
+ # Flex
430
+ flex_inpaint_image = None ,
431
+ flex_inpaint_mask = None ,
432
+ flex_control_image = None ,
433
+ flex_control_strength = 0.5 ,
434
+ flex_control_stop = 0.5 ,
401
435
# TeaCache
402
436
tea_cache_l1_thresh = None ,
403
437
# Tile
@@ -436,6 +470,9 @@ def __call__(
436
470
437
471
# ControlNets
438
472
controlnet_kwargs_posi , controlnet_kwargs_nega , local_controlnet_kwargs = self .prepare_controlnet (controlnet_image , masks , controlnet_inpaint_mask , tiler_kwargs , enable_controlnet_on_negative )
473
+
474
+ # Flex
475
+ flex_kwargs = self .prepare_flex_kwargs (latents , flex_inpaint_image , flex_inpaint_mask , flex_control_image , ** tiler_kwargs )
439
476
440
477
# TeaCache
441
478
tea_cache_kwargs = {"tea_cache" : TeaCache (num_inference_steps , rel_l1_thresh = tea_cache_l1_thresh ) if tea_cache_l1_thresh is not None else None }
@@ -449,7 +486,7 @@ def __call__(
449
486
inference_callback = lambda prompt_emb_posi , controlnet_kwargs : lets_dance_flux (
450
487
dit = self .dit , controlnet = self .controlnet ,
451
488
hidden_states = latents , timestep = timestep ,
452
- ** prompt_emb_posi , ** tiler_kwargs , ** extra_input , ** controlnet_kwargs , ** ipadapter_kwargs_list_posi , ** eligen_kwargs_posi , ** tea_cache_kwargs , ** infiniteyou_kwargs
489
+ ** prompt_emb_posi , ** tiler_kwargs , ** extra_input , ** controlnet_kwargs , ** ipadapter_kwargs_list_posi , ** eligen_kwargs_posi , ** tea_cache_kwargs , ** infiniteyou_kwargs , ** flex_kwargs ,
453
490
)
454
491
noise_pred_posi = self .control_noise_via_local_prompts (
455
492
prompt_emb_posi , prompt_emb_locals , masks , mask_scales , inference_callback ,
@@ -466,7 +503,7 @@ def __call__(
466
503
noise_pred_nega = lets_dance_flux (
467
504
dit = self .dit , controlnet = self .controlnet ,
468
505
hidden_states = latents , timestep = timestep ,
469
- ** prompt_emb_nega , ** tiler_kwargs , ** extra_input , ** controlnet_kwargs_nega , ** ipadapter_kwargs_list_nega , ** eligen_kwargs_nega , ** infiniteyou_kwargs ,
506
+ ** prompt_emb_nega , ** tiler_kwargs , ** extra_input , ** controlnet_kwargs_nega , ** ipadapter_kwargs_list_nega , ** eligen_kwargs_nega , ** infiniteyou_kwargs , ** flex_kwargs ,
470
507
)
471
508
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega )
472
509
else :
@@ -602,6 +639,9 @@ def lets_dance_flux(
602
639
ipadapter_kwargs_list = {},
603
640
id_emb = None ,
604
641
infinityou_guidance = None ,
642
+ flex_condition = None ,
643
+ flex_uncondition = None ,
644
+ flex_control_stop_timestep = None ,
605
645
tea_cache : TeaCache = None ,
606
646
** kwargs
607
647
):
@@ -652,6 +692,13 @@ def flux_forward_fn(hl, hr, wl, wr):
652
692
controlnet_res_stack , controlnet_single_res_stack = controlnet (
653
693
controlnet_frames , ** controlnet_extra_kwargs
654
694
)
695
+
696
+ # Flex
697
+ if flex_condition is not None :
698
+ if timestep .tolist ()[0 ] >= flex_control_stop_timestep :
699
+ hidden_states = torch .concat ([hidden_states , flex_condition ], dim = 1 )
700
+ else :
701
+ hidden_states = torch .concat ([hidden_states , flex_uncondition ], dim = 1 )
655
702
656
703
if image_ids is None :
657
704
image_ids = dit .prepare_image_ids (hidden_states )
0 commit comments