Skip to content

Commit 4f01b37

Browse files
authored
Merge pull request #553 from modelscope/flex
Flex
2 parents e54c0a8 + cc63061 commit 4f01b37

File tree

4 files changed

+106
-5
lines changed

4 files changed

+106
-5
lines changed

diffsynth/configs/model_config.py

+1
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
9999
(None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"),
100100
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
101+
(None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"),
101102
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
102103
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
103104
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),

diffsynth/models/flux_dit.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -276,20 +276,22 @@ def forward(self, x, conditioning):
276276

277277

278278
class FluxDiT(torch.nn.Module):
279-
def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
    """Build the FLUX DiT backbone.

    Args:
        disable_guidance_embedder: if True, no guidance timestep embedder is
            created (``self.guidance_embedder`` stays ``None``) — used for
            checkpoints lacking ``guidance_embedder.*`` weights.
        input_dim: per-token input feature size fed to ``x_embedder``.
            64 for standard FLUX; 196 for Flex-style checkpoints whose input
            packs extra conditioning channels (see ``from_civitai``).
        num_blocks: number of joint (dual-stream) transformer blocks.
            19 for standard FLUX; Flex checkpoints use 8.
    """
    super().__init__()
    # Rotary position embedding over 3 axes (text + 2D image grid).
    self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
    self.time_embedder = TimestepEmbeddings(256, 3072)
    # Optional guidance-scale embedder, same shape as the timestep embedder.
    self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
    self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
    self.context_embedder = torch.nn.Linear(4096, 3072)
    # Input projection width is configurable so Flex's widened latents fit.
    self.x_embedder = torch.nn.Linear(input_dim, 3072)

    self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])
    self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])

    self.final_norm_out = AdaLayerNormContinuous(3072)
    # Output projection stays 64 regardless of input_dim: only the input is
    # widened with conditioning channels; the predicted latent is standard.
    self.final_proj_out = torch.nn.Linear(3072, 64)

    # Remembered so pipelines can detect Flex models (input_dim == 196).
    self.input_dim = input_dim
293295

294296

295297
def patchify(self, hidden_states):
@@ -738,5 +740,7 @@ def from_civitai(self, state_dict):
738740
pass
739741
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
740742
return state_dict_, {"disable_guidance_embedder": True}
743+
elif "blocks.8.attn.norm_k_a.weight" not in state_dict_:
744+
return state_dict_, {"input_dim": 196, "num_blocks": 8}
741745
else:
742746
return state_dict_

diffsynth/pipelines/flux_image.py

+49-2
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,34 @@ def prepare_infinite_you(self, id_image, controlnet_image, infinityou_guidance,
360360
return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width)
361361
else:
362362
return {}, controlnet_image
363+
364+
365+
def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, flex_control_strength=0.5, flex_control_stop=0.5, tiled=False, tile_size=64, tile_stride=32):
    """Assemble the extra conditioning tensors for Flex-style DiT models.

    Flex checkpoints take a widened input (``dit.input_dim == 196``) that
    packs, along the channel axis: masked inpaint latents, the inpaint mask,
    and control-image latents. For non-Flex models this returns ``{}``.

    Args:
        latents: noise latents, shape (B, C, H, W); used for device/shape
            defaults when an optional image/mask is not provided.
        flex_inpaint_image: optional image to inpaint around (VAE-encoded here).
        flex_inpaint_mask: optional mask; appears to be a PIL image — it is
            resized with a (width, height) tuple before preprocessing.
            NOTE(review): with a mask but no inpaint image, the zero latents
            are used as the inpaint background — confirm intended.
        flex_control_image: optional control image (VAE-encoded, then scaled
            by ``flex_control_strength``).
        flex_control_strength: multiplier applied to the control latents.
        flex_control_stop: fraction of the schedule after which the control
            channels are swapped for zeros (``flex_uncondition``).
        tiled / tile_size / tile_stride: forwarded to the VAE encoder.

    Returns:
        dict with ``flex_condition``, ``flex_uncondition`` and
        ``flex_control_stop_timestep`` for ``lets_dance_flux``, or ``{}``.
    """
    # Flex models are detected by their widened patch-embedding input size.
    if self.dit.input_dim != 196:
        return {}

    # Inpaint image latents: zeros when absent, else encode through the VAE.
    if flex_inpaint_image is None:
        flex_inpaint_image = torch.zeros_like(latents)
    else:
        flex_inpaint_image = self.preprocess_image(flex_inpaint_image).to(device=self.device, dtype=self.torch_dtype)
        flex_inpaint_image = self.encode_image(flex_inpaint_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

    # Inpaint mask: all-ones (repaint everything) when absent; otherwise
    # resize to the latent grid and map preprocessed [-1, 1] back to [0, 1].
    if flex_inpaint_mask is None:
        flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :]
    else:
        flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2]))
        flex_inpaint_mask = self.preprocess_image(flex_inpaint_mask).to(device=self.device, dtype=self.torch_dtype)
        flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2

    # Zero out the latents inside the region to be repainted.
    flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask)

    # Control latents: zeros when absent, else encode and scale by strength.
    if flex_control_image is None:
        flex_control_image = torch.zeros_like(latents)
    else:
        flex_control_image = self.preprocess_image(flex_control_image).to(device=self.device, dtype=self.torch_dtype)
        flex_control_image = self.encode_image(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength

    flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1)
    # Same packing but with zeroed control — used once control is "stopped".
    flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1)
    # Timestep (from the scheduler's decreasing list) after which control is dropped.
    flex_control_stop_timestep = self.scheduler.timesteps[int(flex_control_stop * (len(self.scheduler.timesteps) - 1))]
    return {"flex_condition": flex_condition, "flex_uncondition": flex_uncondition, "flex_control_stop_timestep": flex_control_stop_timestep}
363391

364392

365393
@torch.no_grad()
@@ -398,6 +426,12 @@ def __call__(
398426
# InfiniteYou
399427
infinityou_id_image=None,
400428
infinityou_guidance=1.0,
429+
# Flex
430+
flex_inpaint_image=None,
431+
flex_inpaint_mask=None,
432+
flex_control_image=None,
433+
flex_control_strength=0.5,
434+
flex_control_stop=0.5,
401435
# TeaCache
402436
tea_cache_l1_thresh=None,
403437
# Tile
@@ -436,6 +470,9 @@ def __call__(
436470

437471
# ControlNets
438472
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
473+
474+
# Flex
475+
flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, **tiler_kwargs)
439476

440477
# TeaCache
441478
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
@@ -449,7 +486,7 @@ def __call__(
449486
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
450487
dit=self.dit, controlnet=self.controlnet,
451488
hidden_states=latents, timestep=timestep,
452-
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs
489+
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **flex_kwargs,
453490
)
454491
noise_pred_posi = self.control_noise_via_local_prompts(
455492
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
@@ -466,7 +503,7 @@ def __call__(
466503
noise_pred_nega = lets_dance_flux(
467504
dit=self.dit, controlnet=self.controlnet,
468505
hidden_states=latents, timestep=timestep,
469-
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs,
506+
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **flex_kwargs,
470507
)
471508
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
472509
else:
@@ -602,6 +639,9 @@ def lets_dance_flux(
602639
ipadapter_kwargs_list={},
603640
id_emb=None,
604641
infinityou_guidance=None,
642+
flex_condition=None,
643+
flex_uncondition=None,
644+
flex_control_stop_timestep=None,
605645
tea_cache: TeaCache = None,
606646
**kwargs
607647
):
@@ -652,6 +692,13 @@ def flux_forward_fn(hl, hr, wl, wr):
652692
controlnet_res_stack, controlnet_single_res_stack = controlnet(
653693
controlnet_frames, **controlnet_extra_kwargs
654694
)
695+
696+
# Flex
697+
if flex_condition is not None:
698+
if timestep.tolist()[0] >= flex_control_stop_timestep:
699+
hidden_states = torch.concat([hidden_states, flex_condition], dim=1)
700+
else:
701+
hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1)
655702

656703
if image_ids is None:
657704
image_ids = dit.prepare_image_ids(hidden_states)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""Example: text-to-image, inpainting, and canny control with Flex.2-preview.

Reconstructed from the commit's new-file diff (scrape markers removed).
Requires the diffsynth package, downloaded FLUX.1-dev components, and the
ostris/Flex.2-preview checkpoint; runs on CUDA in bfloat16.
"""
import torch
from diffsynth import ModelManager, FluxImagePipeline, download_models
from diffsynth.controlnets.processors import Annotator
import numpy as np
from PIL import Image


# Fetch the FLUX.1-dev text encoders / VAE; the Flex DiT weights are loaded
# from a local path below.
download_models(["FLUX.1-dev"])
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
    "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
    "models/FLUX/FLUX.1-dev/text_encoder_2",
    "models/FLUX/FLUX.1-dev/ae.safetensors",
    "models/ostris/Flex.2-preview/Flex.2-preview.safetensors",
])
pipe = FluxImagePipeline.from_model_manager(model_manager)

# 1) Plain text-to-image generation.
image = pipe(
    prompt="portrait of a beautiful Asian girl, long hair, red t-shirt, sunshine, beach",
    num_inference_steps=50, embedded_guidance=3.5,
    seed=0,
)
image.save("image_1.jpg")

# 2) Inpainting: white region of the mask is repainted.
# NOTE(review): JPEG is lossy, so the saved mask is approximate; the in-memory
# mask is what the pipeline consumes.
mask = np.zeros((1024, 1024, 3), dtype=np.uint8)
mask[200:400, 400:700] = 255
mask = Image.fromarray(mask)
mask.save("image_mask.jpg")

inpaint_image = image

image = pipe(
    prompt="portrait of a beautiful Asian girl with sunglasses, long hair, red t-shirt, sunshine, beach",
    num_inference_steps=50, embedded_guidance=3.5,
    flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask,
    seed=4,
)
image.save("image_2.jpg")

# 3) Structure control: regenerate with a canny edge map of the last result.
control_image = Annotator("canny")(image)
control_image.save("image_control.jpg")

image = pipe(
    prompt="portrait of a beautiful Asian girl with sunglasses, long hair, yellow t-shirt, sunshine, beach",
    num_inference_steps=50, embedded_guidance=3.5,
    flex_control_image=control_image,
    seed=4,
)
image.save("image_3.jpg")

0 commit comments

Comments (0)