文章目录
DiT 与 MM‑DiT:扩散模型中的 Transformer 架构
本文档详细介绍了两种基于 Transformer 的扩散模型主干网络:DiT (Diffusion Transformer) 及其多模态扩展 MM‑DiT (Multi‑Modal Diffusion Transformer)。重点阐述它们的核心思想、数学形式和模块设计,不包含具体实现代码。
1. DiT:纯 Transformer 替代 U‑Net
1.1 背景与动机
传统扩散模型(如 DDPM、Stable Diffusion)使用 U‑Net 作为去噪网络。U‑Net 的归纳偏置(局部卷积、下采样/上采样、跳跃连接)限制了模型的可扩展性和对长距离依赖的建模能力。
DiT 提出用 Vision Transformer (ViT) 完全替代 U‑Net,证明 Transformer 在扩散生成任务中同样高效且易于扩展。
- 为什么要用Transformer?
- 当模型参数量、数据量、计算量增加时,Transformer 的性能通常可以稳定提升。对 diffusion model 来说,扩大 Transformer backbone 的规模,也可以显著提升生成质量。
- 卷积天然有局部归纳偏置。虽然 U-Net 也可以通过 downsampling 获得大感受野,但它本质上仍然更偏局部。Transformer 的 self-attention 可以直接建模任意两个 patch 之间的关系, 适合全局建模
- Transformer 更统一,DiT 的结构可以自然扩展到多模态生成, 图片,视频,文本,音频都可以看做是Token.
1.2 模型结构
DiT 的输入是带噪图像的潜在表示
z
∈
R
C
×
H
×
W
\mathbf{z} \in \mathbb{R}^{C \times H \times W}
z∈RC×H×W,输出是预测的噪声
ϵ
^
\hat{\boldsymbol{\epsilon}}
ϵ^(或速度
v
^
\hat{\mathbf{v}}
v^)。其处理流程如下:

1.2.1 Patchify 与位置编码
将输入图像切分为大小为 p × p p \times p p×p 的 patches,展平为 token 序列:
x ∈ R N × D , N = H p ⋅ W p , D = p 2 ⋅ C \mathbf{x} \in \mathbb{R}^{N \times D}, \quad N = \frac{H}{p} \cdot \frac{W}{p}, \quad D = p^2 \cdot C x∈RN×D,N=pH⋅pW,D=p2⋅C
其中
D
D
D 是每个 patch 展平后的维度,随后通过线性投影映射到 Transformer 的隐藏维度
d
model
d_{\text{model}}
dmodel。
加入可学习或者二维的的位置编码
E
pos
∈
R
N
×
d
model
\mathbf{E}_{\text{pos}} \in \mathbb{R}^{N \times d_{\text{model}}}
Epos∈RN×dmodel:
x ← x + E pos \mathbf{x} \leftarrow \mathbf{x} + \mathbf{E}_{\text{pos}} x←x+Epos
1.2.2 条件注入:自适应层归一化 (AdaLN)
扩散模型需要输入时间步
t
t
t 和可选的类别标签
c
c
c(或其他条件)。DiT 采用 AdaLN 将条件信息注入每个 Transformer 块。
普通 LayerNorm 是:LN(x)
adaptive LayerNorm 是:adaLN(x, c) = scale© * LN(x) + shift©
首先,通过 MLP 从条件嵌入 c cond \mathbf{c}_{\text{cond}} ccond ( c cond = t e m b + y e m b \mathbf{c}_{\text{cond}} = t_{emb} + y_{emb} ccond=temb+yemb)回归出六组调制参数:
( γ 1 , β 1 , α 1 , γ 2 , β 2 , α 2 ) = MLP ( c cond ) (\gamma_1, \beta_1, \alpha_1, \gamma_2, \beta_2, \alpha_2) = \text{MLP}(\mathbf{c}_{\text{cond}}) (γ1,β1,α1,γ2,β2,α2)=MLP(ccond)
其中 γ , β , α ∈ R d model \gamma, \beta, \alpha \in \mathbb{R}^{d_{\text{model}}} γ,β,α∈Rdmodel。在自注意力和前馈网络之前,先对输入进行层归一化,再应用仿射变换:
AdaLN ( h , γ , β ) = γ ⊙ LayerNorm ( h ) + β \text{AdaLN}(h, \gamma, \beta) = \gamma \odot \text{LayerNorm}(h) + \beta AdaLN(h,γ,β)=γ⊙LayerNorm(h)+β
残差连接时使用 α \alpha α 进行缩放:
Output = Input + α ⊙ SubLayer ( AdaLN ( h , γ , β ) ) \text{Output} = \text{Input} + \alpha \odot \text{SubLayer}(\text{AdaLN}(h, \gamma, \beta)) Output=Input+α⊙SubLayer(AdaLN(h,γ,β))
条件嵌入 c cond \mathbf{c}_{\text{cond}} ccond 由时间步嵌入和类别嵌入求和得到:
c cond = MLP time ( t ) + MLP class ( c ) \mathbf{c}_{\text{cond}} = \text{MLP}_{\text{time}}(t) + \text{MLP}_{\text{class}}(c) ccond=MLPtime(t)+MLPclass(c)
DiT 论文中最重要的训练技巧之一是 adaLN-Zero。
它的思想是:
初始化时,让每个 Transformer block 对输入几乎没有影响,模型一开始接近 identity function。
具体做法是,把调制层最后一层初始化为 0,使得:
shift ≈ 0, scale ≈ 0, gate ≈ 0
因此一开始:
x = x + 0 * attention(…)
x = x + 0 * mlp(…)
也就是:
output ≈ input
这样子训练更稳定。
1.2.3 Transformer 块堆叠
每个 DiT 块包含多头自注意力和前馈网络,均使用 AdaLN 进行条件化。堆叠 L L L 个这样的块后,最后经过层归一化和一个线性层,将每个 token 映射回 patch 空间,并重组为图像。
1.2.4 输出层与训练目标
输出层将每个 token 的维度
d
model
d_{\text{model}}
dmodel 映射回
p
2
⋅
C
p^2 \cdot C
p2⋅C,然后通过 rearrange 操作恢复为
C
×
H
×
W
C \times H \times W
C×H×W 的噪声预测:
ϵ ^ = Linear ( LayerNorm ( x last ) ) \hat{\boldsymbol{\epsilon}} = \text{Linear}(\text{LayerNorm}(\mathbf{x}_{\text{last}})) ϵ^=Linear(LayerNorm(xlast))
训练时使用简单的均方误差损失:
L = E z 0 , ϵ , t [ ∥ ϵ − ϵ ^ ( z t , t , c ) ∥ 2 2 ] \mathcal{L} = \mathbb{E}_{\mathbf{z}_0, \boldsymbol{\epsilon}, t} \left[ \| \boldsymbol{\epsilon} - \hat{\boldsymbol{\epsilon}}(\mathbf{z}_t, t, c) \|_2^2 \right] L=Ez0,ϵ,t[∥ϵ−ϵ^(zt,t,c)∥22]
1.3 可扩展性实验
DiT 提供了四种模型尺寸:DiT‑S、DiT‑B、DiT‑L、DiT‑XL(参数量从 33M 到 675M)。实验表明,增大模型尺寸和 token 数量(减小 patch size)能稳定提升生成质量(FID 下降),在 ImageNet 256×256 上达到当时的 SOTA。
DiT 论文中常见命名类似:DiT-XL/2,DiT-L/4,DiT-B/8
这里:XL / L / B 表示模型规模:
- B Base
- L Large
- XL Extra Large
后面的 /2、/4、/8 表示 patch size。
代码
- DiT 的核心结构
import torch
import torch.nn as nn
class DiTBlock(nn.Module):
def __init__(self,hidden_size,num_heads):
super().__init__()
self.norm1=nn.LayerNorm(hidden_size,elementwise_affine=False)
self.attn=Attention(hidden_size,num_heads)
self.norm2=nn.LayerNorm(hidden_size,elementwise_affine=False)
self.mlp=MLP(hidden_size)
self.adaLN_modulation=nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size,6*hidden_size)
)
def forward(self,x,c):
shift_msa,scale_msa,gate_msa,\
shift_mlp,scale_mlp,gate_mlp= self.adaLN_modulation(c).chunk(6,dim=1)
x=x +gate_msa.unsqueeze(1)*self.attn(modulate(self.norm1(x),shift_msa,scale_msa))
x=x+game_mlp.unsqueeze(1)*self.mlp(modulate(self.norm2(x),shift_mlp,scale_mlp))
return x
- DiT 的Final Layer
class FinalLayer(nn.Module):
def __init__(self, hidden_size, patch_size, out_channels):
super().__init__()
self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
self.linear = nn.Linear(
hidden_size,
patch_size * patch_size * out_channels
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 2 * hidden_size)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm(x), shift, scale)
x = self.linear(x)
return x
- DiT Forward 流程
def forward(x,t,y):
# x: (B,C,H,W)
x=patch_embed(x)
x=x+pos_embed
t=timestep_embed(t)
y=lable_embed(y)
c=t+y
for block in blocks:
x=block(x,c)
x=final_layer(x,c)
x=unpatchify(x)
return x
- 训练目标
x0= batch["image_or_latent"]
noise=torch.randn_like(x0)
t=torch.randint(
0,
num_train_timesteps,
(x0.shape[0],),
device=x0.device
)
xt=noise_scheduler.add_noise(x0,noise,t)
noise_pred=dit(xt,t,condition)
loss=F.mse_loss(noise_pred,noise)
DiT里面的classifier-free guidance
DiT 也可以使用 CFG,和 U-Net diffusion 一样。
训练时会随机 drop condition:
if random.random() < drop_prob:
y = null_label
推理时分别预测:
eps_uncond = model(x_t, t, null_condition)
eps_cond = model(x_t, t, condition)
然后组合:
eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
含义是:
guidance_scale 越大,越听 condition;
guidance_scale 太大,图像可能过饱和、失真、多样性下降。
DiT 和Unet 的核心区别
- 输入表示不同
U-Net 输入是 feature map:B × C × H × W
DiT 输入是 token sequence:B × N × D
- 归纳偏置不同
U-Net 有强图像归纳偏置:
DiT 的归纳偏置弱一些,但可扩展性更强:
- 条件注入方式不同
U-Net 通常通过:
time embedding 加到 ResBlock
cross-attention 注入文本条件
2. MM‑DiT:多模态扩散 Transformer
2.1 动机
DiT 仅支持单模态输入(图像)和简单的类别条件。对于文本到图像生成,需要模型能同时理解文本和图像两种模态,并实现深层融合。
MM‑DiT 由 Stability AI 在 Stable Diffusion 3 中提出,它将 DiT 扩展为双流 Transformer,让文本和图像 token 在每一层进行双向交互。

2.2 输入表示
- 图像流:与 DiT 相同,将带噪图像的潜在表示 z \mathbf{z} z 切分为 patches,得到 token 序列 X I ∈ R N I × d model \mathbf{X}_I \in \mathbb{R}^{N_I \times d_{\text{model}}} XI∈RNI×dmodel。
假设 noisy latent 是:
x_t: [B, C, H, W]
经过 patching 后,会被切成 patch tokens:
patches: [B, N_img, patch_dim]
其中:
N_img = H / patch_size × W / patch_size
然后经过 Linear 投影到 Transformer hidden dimension:
x = Linear(patches) # [B, N_img, D]
再加上 positional embedding:
x = x + pos_embed
x 就是送入 MM-DiT blocks 的 image tokens
- 文本流:将文本 prompt 通过预训练编码器(如 CLIP 或 T5)编码为 token 特征,再线性投影到相同维度 d model d_{\text{model}} dmodel,得到 X T ∈ R N T × d model \mathbf{X}_T \in \mathbb{R}^{N_T \times d_{\text{model}}} XT∈RNT×dmodel.
文本分支:Caption 经过三个 text encoder: CLIP-G/14, CLIP-L/14, T5 XXL
CLIP 系列 encoder 通常提供两类信息:
- token-level embedding
- pooled embedding
这表示两个 CLIP encoder 各自产生一组 token embedding,典型长度是 77 个 token。
可以理解为:
clip_g_tokens: [B, 77, D_g]
clip_l_tokens: [B, 77, D_l]
文本 token 最终会被处理成较高维的 channel 表示,然后通过 Linear 投影到 MMDiT 内部使用的 hidden dimension。
T5 XXL 主要提供更强的自然语言理解能力,尤其是长文本、复杂语义、属性关系、空间关系等。图中 T5 XXL 输出也进入中间的 token embedding 路径。
text_tokens = concat([
clip_g_token_embeds,
clip_l_token_embeds,
t5_token_embeds
], dim=sequence_dim or feature_dim)
c = Linear(text_tokens)
c 可以理解为送入 MM-DiT blocks 的 context tokens / caption tokens。
CLIP 的 pooled output 通常表示整句话的全局语义,例如: 整体风格,主体类别,图像全局语义
pooled feature 经过 MLP 后,会和 timestep encoding 分支融合,得到 y
可以把 y 理解为: 全局调制条件 = 文本全局语义 + 当前 timestep 信息
高噪声阶段:模型更关注整体布局、语义和大结构
低噪声阶段:模型更关注纹理、边缘、文字、细节
-
时间步条件:
timestep 先通过 sinusoidal positional encoding 变成连续向量, 再经过 MLP, 然后与 pooled text condition 融合 -
总结
x:image tokens,来自 noised latent
c:caption / context tokens,来自 text encoders
y:global modulation vector,来自 pooled text + timestep
x 负责图像 latent token 的更新
c 负责文本 token 的更新
y 负责调制 attention / MLP / LayerNorm
2.3 MM‑DiT 块结构
每个 MM‑DiT 块同时处理两个流,包含三个子层,每个子层都使用 AdaLN 注入条件:
2.3.1 独立自注意力
对图像流和文本流分别应用自注意力,不跨流:
X
I
←
X
I
+
SelfAttn
(
AdaLN
(
X
I
,
γ
I
1
,
β
I
1
)
)
⋅
α
I
1
\mathbf{X}_I \leftarrow \mathbf{X}_I + \text{SelfAttn}(\text{AdaLN}(\mathbf{X}_I, \gamma_{I1}, \beta_{I1})) \cdot \alpha_{I1}
XI←XI+SelfAttn(AdaLN(XI,γI1,βI1))⋅αI1
X
T
←
X
T
+
SelfAttn
(
AdaLN
(
X
T
,
γ
T
1
,
β
T
1
)
)
⋅
α
T
1
\mathbf{X}_T \leftarrow \mathbf{X}_T + \text{SelfAttn}(\text{AdaLN}(\mathbf{X}_T, \gamma_{T1}, \beta_{T1})) \cdot \alpha_{T1}
XT←XT+SelfAttn(AdaLN(XT,γT1,βT1))⋅αT1
2.3.2 双向交叉注意力
这是 MM‑DiT 的核心创新。两个流互为键/值,实现信息的深度融合:
-
图像作为查询,文本作为键/值:
X I ← X I + CrossAttn ( AdaLN ( X I ) , X T , X T ) ⋅ α I 2 \mathbf{X}_I \leftarrow \mathbf{X}_I + \text{CrossAttn}(\text{AdaLN}(\mathbf{X}_I), \mathbf{X}_T, \mathbf{X}_T) \cdot \alpha_{I2} XI←XI+CrossAttn(AdaLN(XI),XT,XT)⋅αI2
-
文本作为查询,图像作为键/值:
X T ← X T + CrossAttn ( AdaLN ( X T ) , X I , X I ) ⋅ α T 2 \mathbf{X}_T \leftarrow \mathbf{X}_T + \text{CrossAttn}(\text{AdaLN}(\mathbf{X}_T), \mathbf{X}_I, \mathbf{X}_I) \cdot \alpha_{T2} XT←XT+CrossAttn(AdaLN(XT),XI,XI)⋅αT2
这种对称设计使得文本能够“看到”图像当前的状态,从而动态调整文本表示(例如,关注与图像相关的词语)。
2.3.3 独立前馈网络
最后,两个流各自经过一个前馈网络(FFN):
X
I
←
X
I
+
FFN
(
AdaLN
(
X
I
,
γ
I
3
,
β
I
3
)
)
⋅
α
I
3
\mathbf{X}_I \leftarrow \mathbf{X}_I + \text{FFN}(\text{AdaLN}(\mathbf{X}_I, \gamma_{I3}, \beta_{I3})) \cdot \alpha_{I3}
XI←XI+FFN(AdaLN(XI,γI3,βI3))⋅αI3
X
T
←
X
T
+
FFN
(
AdaLN
(
X
T
,
γ
T
3
,
β
T
3
)
)
⋅
α
T
3
\mathbf{X}_T \leftarrow \mathbf{X}_T + \text{FFN}(\text{AdaLN}(\mathbf{X}_T, \gamma_{T3}, \beta_{T3})) \cdot \alpha_{T3}
XT←XT+FFN(AdaLN(XT,γT3,βT3))⋅αT3
2.4 输出与训练
经过 L L L 个 MM‑DiT 块后,仅保留图像流的 token 序列 X I \mathbf{X}_I XI,通过线性层预测噪声(或速度):
ϵ ^ = Linear ( LayerNorm ( X I ) ) \hat{\boldsymbol{\epsilon}} = \text{Linear}(\text{LayerNorm}(\mathbf{X}_I)) ϵ^=Linear(LayerNorm(XI))
文本流的最终状态被丢弃,不作为输出。训练目标与 DiT 相同(预测噪声的 MSE 损失)。在推理时,可以使用 Classifier‑Free Guidance 对文本条件进行强度调节。
2.5 MM-DiT
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# ============================================================
# 1. 工具函数:modulate
# ============================================================
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
"""
AdaLN-style modulation.
Args:
x: [B, N, D]
shift: [B, D]
scale: [B, D]
Returns:
[B, N, D]
"""
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
# ============================================================
# 2. RMSNorm:可选的 Q/K norm
# ============================================================
class RMSNorm(nn.Module):
"""
Root Mean Square LayerNorm.
常用于 QK normalization,稳定 attention logits。
"""
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
super().__init__()
self.eps = eps
self.elementwise_affine = elementwise_affine
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
else:
self.register_parameter("weight", None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [..., D]
rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
x = x / rms
if self.weight is not None:
x = x * self.weight
return x
# ============================================================
# 3. MLP
# ============================================================
class MLP(nn.Module):
"""
Transformer FFN / MLP.
"""
def __init__(
self,
dim: int,
hidden_dim: Optional[int] = None,
dropout: float = 0.0,
):
super().__init__()
hidden_dim = hidden_dim or dim * 4
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(approximate="tanh"),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
# ============================================================
# 4. Multi-Head Attention:接受已经算好的 Q/K/V
# ============================================================
class JointAttention(nn.Module):
"""
Joint Attention for MMDiT.
输入:
q_c, k_c, v_c: text tokens 的 Q/K/V
q_x, k_x, v_x: image tokens 的 Q/K/V
做法:
1. concat text/image Q/K/V
2. scaled dot-product attention
3. split 回 text/image 两部分
"""
def __init__(
self,
dim: int,
num_heads: int,
qk_norm: bool = True,
attn_dropout: float = 0.0,
proj_dropout: float = 0.0,
):
super().__init__()
assert dim % num_heads == 0, "dim must be divisible by num_heads"
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.attn_dropout = attn_dropout
self.proj_dropout = nn.Dropout(proj_dropout)
if qk_norm:
self.q_norm_c = RMSNorm(self.head_dim)
self.k_norm_c = RMSNorm(self.head_dim)
self.q_norm_x = RMSNorm(self.head_dim)
self.k_norm_x = RMSNorm(self.head_dim)
else:
self.q_norm_c = nn.Identity()
self.k_norm_c = nn.Identity()
self.q_norm_x = nn.Identity()
self.k_norm_x = nn.Identity()
def _reshape_heads(self, t: torch.Tensor) -> torch.Tensor:
"""
[B, N, D] -> [B, H, N, Dh]
"""
B, N, D = t.shape
t = t.view(B, N, self.num_heads, self.head_dim)
t = t.transpose(1, 2)
return t
def _merge_heads(self, t: torch.Tensor) -> torch.Tensor:
"""
[B, H, N, Dh] -> [B, N, D]
"""
B, H, N, Dh = t.shape
t = t.transpose(1, 2).contiguous()
t = t.view(B, N, H * Dh)
return t
def forward(
self,
q_c: torch.Tensor,
k_c: torch.Tensor,
v_c: torch.Tensor,
q_x: torch.Tensor,
k_x: torch.Tensor,
v_x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
q_c, k_c, v_c: [B, N_text, D]
q_x, k_x, v_x: [B, N_image, D]
attn_mask:
可选 attention mask。
如果使用 PyTorch SDPA,推荐形状可广播到 [B, H, N_total, N_total]。
Returns:
out_c: [B, N_text, D]
out_x: [B, N_image, D]
"""
B, N_text, D = q_c.shape
N_image = q_x.shape[1]
# [B, N, D] -> [B, H, N, Dh]
q_c = self._reshape_heads(q_c)
k_c = self._reshape_heads(k_c)
v_c = self._reshape_heads(v_c)
q_x = self._reshape_heads(q_x)
k_x = self._reshape_heads(k_x)
v_x = self._reshape_heads(v_x)
# Optional QK norm
q_c = self.q_norm_c(q_c)
k_c = self.k_norm_c(k_c)
q_x = self.q_norm_x(q_x)
k_x = self.k_norm_x(k_x)
# Joint attention:
# [text tokens, image tokens]
q = torch.cat([q_c, q_x], dim=2) # [B, H, N_text + N_image, Dh]
k = torch.cat([k_c, k_x], dim=2)
v = torch.cat([v_c, v_x], dim=2)
# PyTorch 2.0 scaled_dot_product_attention
out = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attn_mask,
dropout_p=self.attn_dropout if self.training else 0.0,
is_causal=False,
)
# split back
out_c = out[:, :, :N_text, :]
out_x = out[:, :, N_text:N_text + N_image, :]
out_c = self._merge_heads(out_c) # [B, N_text, D]
out_x = self._merge_heads(out_x) # [B, N_image, D]
out_c = self.proj_dropout(out_c)
out_x = self.proj_dropout(out_x)
return out_c, out_x
# ============================================================
# 5. MM-DiT Block
# ============================================================
class MMDiTBlock(nn.Module):
"""
一个完整的 MM-DiT Block。
输入:
c: text / caption tokens, [B, N_text, D]
x: image / latent tokens, [B, N_image, D]
y: global condition, [B, D]
通常来自 timestep embedding + pooled text embedding
输出:
c: updated text tokens
x: updated image tokens
核心结构:
text branch:
LN -> AdaLN modulation -> QKV
image branch:
LN -> AdaLN modulation -> QKV
concat Q/K/V -> joint attention -> split
text branch:
output proj -> gate -> residual
LN -> AdaLN modulation -> MLP -> gate -> residual
image branch:
output proj -> gate -> residual
LN -> AdaLN modulation -> MLP -> gate -> residual
"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
qk_norm: bool = True,
dropout: float = 0.0,
attn_dropout: float = 0.0,
norm_eps: float = 1e-6,
):
super().__init__()
self.dim = dim
self.num_heads = num_heads
mlp_hidden_dim = int(dim * mlp_ratio)
# ----------------------------------------------------
# y -> modulation parameters
# 每个分支 6 组参数:
# shift_msa, scale_msa, gate_msa,
# shift_mlp, scale_mlp, gate_mlp
# text 和 image 分支各自一套
# ----------------------------------------------------
self.adaLN_modulation_c = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 6 * dim, bias=True),
)
self.adaLN_modulation_x = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 6 * dim, bias=True),
)
# ----------------------------------------------------
# Attention pre-norm
# ----------------------------------------------------
self.norm1_c = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=False)
self.norm1_x = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=False)
# ----------------------------------------------------
# Separate QKV projections
# text 和 image 不共享 QKV 参数
# ----------------------------------------------------
self.qkv_c = nn.Linear(dim, 3 * dim, bias=qkv_bias)
self.qkv_x = nn.Linear(dim, 3 * dim, bias=qkv_bias)
self.joint_attn = JointAttention(
dim=dim,
num_heads=num_heads,
qk_norm=qk_norm,
attn_dropout=attn_dropout,
proj_dropout=dropout,
)
# ----------------------------------------------------
# Separate output projections
# ----------------------------------------------------
self.proj_c = nn.Linear(dim, dim)
self.proj_x = nn.Linear(dim, dim)
self.proj_drop_c = nn.Dropout(dropout)
self.proj_drop_x = nn.Dropout(dropout)
# ----------------------------------------------------
# MLP sub-layer
# ----------------------------------------------------
self.norm2_c = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=False)
self.norm2_x = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=False)
self.mlp_c = MLP(dim, hidden_dim=mlp_hidden_dim, dropout=dropout)
self.mlp_x = MLP(dim, hidden_dim=mlp_hidden_dim, dropout=dropout)
self.initialize_weights()
def initialize_weights(self):
"""
简化版初始化。
注意:为了模拟 AdaLN-Zero 的稳定训练特点,
modulation 最后一层可以 zero init,
这样 block 初始时近似 residual identity。
"""
nn.init.zeros_(self.adaLN_modulation_c[-1].weight)
nn.init.zeros_(self.adaLN_modulation_c[-1].bias)
nn.init.zeros_(self.adaLN_modulation_x[-1].weight)
nn.init.zeros_(self.adaLN_modulation_x[-1].bias)
def forward(
self,
c: torch.Tensor,
x: torch.Tensor,
y: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
c: [B, N_text, D]
x: [B, N_image, D]
y: [B, D]
attn_mask: optional attention mask
Returns:
c: [B, N_text, D]
x: [B, N_image, D]
"""
# ====================================================
# 1. 由 y 生成 text/image 两套 modulation 参数
# ====================================================
(
shift_msa_c,
scale_msa_c,
gate_msa_c,
shift_mlp_c,
scale_mlp_c,
gate_mlp_c,
) = self.adaLN_modulation_c(y).chunk(6, dim=-1)
(
shift_msa_x,
scale_msa_x,
gate_msa_x,
shift_mlp_x,
scale_mlp_x,
gate_mlp_x,
) = self.adaLN_modulation_x(y).chunk(6, dim=-1)
# ====================================================
# 2. Attention 子层:LN + Modulation
# ====================================================
c_attn_in = modulate(self.norm1_c(c), shift_msa_c, scale_msa_c)
x_attn_in = modulate(self.norm1_x(x), shift_msa_x, scale_msa_x)
# ====================================================
# 3. Separate QKV
# ====================================================
q_c, k_c, v_c = self.qkv_c(c_attn_in).chunk(3, dim=-1)
q_x, k_x, v_x = self.qkv_x(x_attn_in).chunk(3, dim=-1)
# ====================================================
# 4. Joint Attention
# ====================================================
out_c, out_x = self.joint_attn(
q_c=q_c,
k_c=k_c,
v_c=v_c,
q_x=q_x,
k_x=k_x,
v_x=v_x,
attn_mask=attn_mask,
)
# ====================================================
# 5. Attention output projection + gate + residual
# ====================================================
c = c + gate_msa_c.unsqueeze(1) * self.proj_drop_c(self.proj_c(out_c))
x = x + gate_msa_x.unsqueeze(1) * self.proj_drop_x(self.proj_x(out_x))
# ====================================================
# 6. MLP 子层:LN + Modulation + MLP + gate + residual
# ====================================================
c_mlp_in = modulate(self.norm2_c(c), shift_mlp_c, scale_mlp_c)
x_mlp_in = modulate(self.norm2_x(x), shift_mlp_x, scale_mlp_x)
c = c + gate_mlp_c.unsqueeze(1) * self.mlp_c(c_mlp_in)
x = x + gate_mlp_x.unsqueeze(1) * self.mlp_x(x_mlp_in)
return c, x
# ============================================================
# 6. 简单测试
# ============================================================
if __name__ == "__main__":
torch.manual_seed(0)
B = 2
N_text = 154
N_image = 1024
D = 1152
num_heads = 16
c = torch.randn(B, N_text, D)
x = torch.randn(B, N_image, D)
y = torch.randn(B, D)
block = MMDiTBlock(
dim=D,
num_heads=num_heads,
mlp_ratio=4.0,
qkv_bias=True,
qk_norm=True,
dropout=0.0,
attn_dropout=0.0,
)
c_out, x_out = block(c, x, y)
print("c_out:", c_out.shape)
print("x_out:", x_out.shape)
assert c_out.shape == c.shape
assert x_out.shape == x.shape
2.6 关键优势
- 深度双向融合:不同于传统 cross‑attention(图像单向看文本),MM‑DiT 让文本也能关注图像,使得文本特征在生成过程中不断优化,尤其有助于处理复杂提示(如空间关系、多主体、属性绑定)。
- 长文本支持:自注意力机制天然支持变长序列,可处理数百甚至上千个文本 token,提升文字渲染能力。
- 可扩展性:与 DiT 类似,MM‑DiT 的参数量可从 800M 扩展到 8B,性能持续提升。
3. 总结对比
| 特性 | DiT | MM‑DiT |
|---|---|---|
| 输入模态 | 单模态(图像) + 可选条件(类别) | 多模态(图像 + 文本) |
| 条件融合 | AdaLN(将条件调制到层归一化参数) | AdaLN + 双向交叉注意力 |
| 自注意力 | 仅图像 token | 图像 token 和文本 token 各自独立自注意力 |
| 跨模态交互 | 无(条件仅通过 AdaLN 注入) | 每一层都进行图像↔文本的双向交叉注意力 |
| 主要应用 | 无条件/类别条件图像生成 | 文本到图像生成,后续可扩展为视频生成 |
| 代表工作 | DiT (2023) | Stable Diffusion 3 (2024) |
两者共同推动了扩散模型向 Transformer 统一架构 的演进,DiT 奠定了技术基础,MM‑DiT 则将其成功应用于工业级多模态生成场景。
参考文献
- Peebles, W., & Xie, S. (2023). Scalable Diffusion Models with Transformers. ICCV 2023.
- Esser, P., et al. (2024). Scaling Rectified Flow Transformers for High-Resolution Image Synthesis. *arXiv:2403.03206 *.
509

被折叠的 条评论
为什么被折叠?



