Skip to content

Commit d6a2137

Browse files
Support Cosmos predict2 image to video models. (comfyanonymous#8535)
Use the CosmosPredict2ImageToVideoLatent node.
1 parent 53e8d81 commit d6a2137

File tree

3 files changed

+77
-5
lines changed

3 files changed

+77
-5
lines changed

comfy/model_base.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,9 +1014,30 @@ def extra_conds(self, **kwargs):
10141014
if cross_attn is not None:
10151015
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
10161016

1017+
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
1018+
if denoise_mask is not None:
1019+
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
1020+
10171021
out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None))
10181022
return out
10191023

1024+
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
1025+
if denoise_mask is None:
1026+
return timestep
1027+
condition_video_mask_B_1_T_1_1 = denoise_mask.mean(dim=[1, 3, 4], keepdim=True)
1028+
c_noise_B_1_T_1_1 = 0.0 * (1.0 - condition_video_mask_B_1_T_1_1) + timestep.reshape(timestep.shape[0], 1, 1, 1, 1) * condition_video_mask_B_1_T_1_1
1029+
out = c_noise_B_1_T_1_1.squeeze(dim=[1, 3, 4])
1030+
return out
1031+
1032+
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
1033+
sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1))
1034+
sigma_noise_augmentation = 0 #TODO
1035+
if sigma_noise_augmentation != 0:
1036+
latent_image = latent_image + noise
1037+
latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
1038+
sigma = (sigma / (sigma + 1))
1039+
return latent_image / (1.0 - sigma)
1040+
10201041
class Lumina2(BaseModel):
10211042
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
10221043
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)

comfy/model_detection.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -441,11 +441,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
441441
dit_config["rope_h_extrapolation_ratio"] = 4.0
442442
dit_config["rope_w_extrapolation_ratio"] = 4.0
443443
dit_config["rope_t_extrapolation_ratio"] = 1.0
444-
elif dit_config["in_channels"] == 17:
445-
dit_config["extra_per_block_abs_pos_emb"] = False
446-
dit_config["rope_h_extrapolation_ratio"] = 3.0
447-
dit_config["rope_w_extrapolation_ratio"] = 3.0
448-
dit_config["rope_t_extrapolation_ratio"] = 1.0
444+
elif dit_config["in_channels"] == 17: # img to video
445+
if dit_config["model_channels"] == 2048:
446+
dit_config["extra_per_block_abs_pos_emb"] = False
447+
dit_config["rope_h_extrapolation_ratio"] = 3.0
448+
dit_config["rope_w_extrapolation_ratio"] = 3.0
449+
dit_config["rope_t_extrapolation_ratio"] = 1.0
450+
elif dit_config["model_channels"] == 5120:
451+
dit_config["rope_h_extrapolation_ratio"] = 2.0
452+
dit_config["rope_w_extrapolation_ratio"] = 2.0
453+
dit_config["rope_t_extrapolation_ratio"] = 0.8333333333333334
449454

450455
dit_config["extra_h_extrapolation_ratio"] = 1.0
451456
dit_config["extra_w_extrapolation_ratio"] = 1.0

comfy_extras/nodes_cosmos.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch
33
import comfy.model_management
44
import comfy.utils
5+
import comfy.latent_formats
56

67

78
class EmptyCosmosLatentVideo:
@@ -75,8 +76,53 @@ def encode(self, vae, width, height, length, batch_size, start_image=None, end_i
7576
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
7677
return (out_latent,)
7778

79+
class CosmosPredict2ImageToVideoLatent:
80+
@classmethod
81+
def INPUT_TYPES(s):
82+
return {"required": {"vae": ("VAE", ),
83+
"width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
84+
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
85+
"length": ("INT", {"default": 93, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
86+
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
87+
},
88+
"optional": {"start_image": ("IMAGE", ),
89+
"end_image": ("IMAGE", ),
90+
}}
91+
92+
93+
RETURN_TYPES = ("LATENT",)
94+
FUNCTION = "encode"
95+
96+
CATEGORY = "conditioning/inpaint"
97+
98+
def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None):
99+
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
100+
if start_image is None and end_image is None:
101+
out_latent = {}
102+
out_latent["samples"] = latent
103+
return (out_latent,)
104+
105+
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
106+
107+
if start_image is not None:
108+
latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1)
109+
latent[:, :, :latent_temp.shape[-3]] = latent_temp
110+
mask[:, :, :latent_temp.shape[-3]] *= 0.0
111+
112+
if end_image is not None:
113+
latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0)
114+
latent[:, :, -latent_temp.shape[-3]:] = latent_temp
115+
mask[:, :, -latent_temp.shape[-3]:] *= 0.0
116+
117+
out_latent = {}
118+
latent_format = comfy.latent_formats.Wan21()
119+
latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
120+
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
121+
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
122+
return (out_latent,)
78123

79124
NODE_CLASS_MAPPINGS = {
80125
"EmptyCosmosLatentVideo": EmptyCosmosLatentVideo,
81126
"CosmosImageToVideoLatent": CosmosImageToVideoLatent,
127+
"CosmosPredict2ImageToVideoLatent": CosmosPredict2ImageToVideoLatent,
82128
}

0 commit comments

Comments (0)