Skip to content

Commit 29596bd

Browse files
Small cosmos attention code refactor. (comfyanonymous#8530)
1 parent 803af1e commit 29596bd

File tree

1 file changed

+1
-5
lines changed

1 file changed

+1
-5
lines changed

comfy/ldm/cosmos/predict2.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H
7070
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
7171
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
7272
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
73-
result_B_S_HD = rearrange(
74-
optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, skip_output_reshape=True), "b h ... l -> b ... (h l)"
75-
)
76-
77-
return result_B_S_HD
73+
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
7874

7975

8076
class Attention(nn.Module):

0 commit comments

Comments
 (0)