DiT和MM-DiT详解

DiT 与 MM‑DiT:扩散模型中的 Transformer 架构

本文档详细介绍了两种基于 Transformer 的扩散模型主干网络:DiT (Diffusion Transformer) 及其多模态扩展 MM‑DiT (Multi‑Modal Diffusion Transformer)。重点阐述它们的核心思想、数学形式和模块设计,不包含具体实现代码。


1. DiT:纯 Transformer 替代 U‑Net

1.1 背景与动机

传统扩散模型(如 DDPM、Stable Diffusion)使用 U‑Net 作为去噪网络。U‑Net 的归纳偏置(局部卷积、下采样/上采样、跳跃连接)限制了模型的可扩展性和对长距离依赖的建模能力。
DiT 提出用 Vision Transformer (ViT) 完全替代 U‑Net,证明 Transformer 在扩散生成任务中同样高效且易于扩展。

  • 为什么要用Transformer?
  1. 当模型参数量、数据量、计算量增加时,Transformer 的性能通常可以稳定提升。对 diffusion model 来说,扩大 Transformer backbone 的规模,也可以显著提升生成质量。
  2. 卷积天然有局部归纳偏置。虽然 U-Net 也可以通过 downsampling 获得大感受野,但它本质上仍然更偏局部。Transformer 的 self-attention 可以直接建模任意两个 patch 之间的关系, 适合全局建模
  3. Transformer 更统一,DiT 的结构可以自然扩展到多模态生成, 图片,视频,文本,音频都可以看做是Token.

1.2 模型结构

DiT 的输入是带噪图像的潜在表示 z ∈ R C × H × W \mathbf{z} \in \mathbb{R}^{C \times H \times W} zRC×H×W,输出是预测的噪声 ϵ ^ \hat{\boldsymbol{\epsilon}} ϵ^(或速度 v ^ \hat{\mathbf{v}} v^)。其处理流程如下:
在这里插入图片描述

1.2.1 Patchify 与位置编码

将输入图像切分为大小为 p × p p \times p p×p 的 patches,展平为 token 序列:

x ∈ R N × D , N = H p ⋅ W p , D = p 2 ⋅ C \mathbf{x} \in \mathbb{R}^{N \times D}, \quad N = \frac{H}{p} \cdot \frac{W}{p}, \quad D = p^2 \cdot C xRN×D,N=pHpW,D=p2C

其中 D D D 是每个 patch 展平后的维度,随后通过线性投影映射到 Transformer 的隐藏维度 d model d_{\text{model}} dmodel
加入可学习或者二维的的位置编码 E pos ∈ R N × d model \mathbf{E}_{\text{pos}} \in \mathbb{R}^{N \times d_{\text{model}}} EposRN×dmodel

x ← x + E pos \mathbf{x} \leftarrow \mathbf{x} + \mathbf{E}_{\text{pos}} xx+Epos

1.2.2 条件注入:自适应层归一化 (AdaLN)

扩散模型需要输入时间步 t t t 和可选的类别标签 c c c(或其他条件)。DiT 采用 AdaLN 将条件信息注入每个 Transformer 块。
普通 LayerNorm 是:LN(x)

adaptive LayerNorm 是:adaLN(x, c) = scale© * LN(x) + shift©

首先,通过 MLP 从条件嵌入 c cond \mathbf{c}_{\text{cond}} ccond c cond = t e m b + y e m b \mathbf{c}_{\text{cond}} = t_{emb} + y_{emb} ccond=temb+yemb)回归出六组调制参数:

( γ 1 , β 1 , α 1 , γ 2 , β 2 , α 2 ) = MLP ( c cond ) (\gamma_1, \beta_1, \alpha_1, \gamma_2, \beta_2, \alpha_2) = \text{MLP}(\mathbf{c}_{\text{cond}}) (γ1,β1,α1,γ2,β2,α2)=MLP(ccond)

其中 γ , β , α ∈ R d model \gamma, \beta, \alpha \in \mathbb{R}^{d_{\text{model}}} γ,β,αRdmodel。在自注意力和前馈网络之前,先对输入进行层归一化,再应用仿射变换:

AdaLN ( h , γ , β ) = γ ⊙ LayerNorm ( h ) + β \text{AdaLN}(h, \gamma, \beta) = \gamma \odot \text{LayerNorm}(h) + \beta AdaLN(h,γ,β)=γLayerNorm(h)+β

残差连接时使用 α \alpha α 进行缩放:

Output = Input + α ⊙ SubLayer ( AdaLN ( h , γ , β ) ) \text{Output} = \text{Input} + \alpha \odot \text{SubLayer}(\text{AdaLN}(h, \gamma, \beta)) Output=Input+αSubLayer(AdaLN(h,γ,β))

条件嵌入 c cond \mathbf{c}_{\text{cond}} ccond 由时间步嵌入和类别嵌入求和得到:

c cond = MLP time ( t ) + MLP class ( c ) \mathbf{c}_{\text{cond}} = \text{MLP}_{\text{time}}(t) + \text{MLP}_{\text{class}}(c) ccond=MLPtime(t)+MLPclass(c)

DiT 论文中最重要的训练技巧之一是 adaLN-Zero

它的思想是:

初始化时,让每个 Transformer block 对输入几乎没有影响,模型一开始接近 identity function。
具体做法是,把调制层最后一层初始化为 0,使得:

shift ≈ 0, scale ≈ 0, gate ≈ 0

因此一开始:

x = x + 0 * attention(…)
x = x + 0 * mlp(…)

也就是:

output ≈ input
这样子训练更稳定。

1.2.3 Transformer 块堆叠

每个 DiT 块包含多头自注意力和前馈网络,均使用 AdaLN 进行条件化。堆叠 L L L 个这样的块后,最后经过层归一化和一个线性层,将每个 token 映射回 patch 空间,并重组为图像。

1.2.4 输出层与训练目标

输出层将每个 token 的维度 d model d_{\text{model}} dmodel 映射回 p 2 ⋅ C p^2 \cdot C p2C,然后通过 rearrange 操作恢复为 C × H × W C \times H \times W C×H×W 的噪声预测:

ϵ ^ = Linear ( LayerNorm ( x last ) ) \hat{\boldsymbol{\epsilon}} = \text{Linear}(\text{LayerNorm}(\mathbf{x}_{\text{last}})) ϵ^=Linear(LayerNorm(xlast))

训练时使用简单的均方误差损失:

L = E z 0 , ϵ , t [ ∥ ϵ − ϵ ^ ( z t , t , c ) ∥ 2 2 ] \mathcal{L} = \mathbb{E}_{\mathbf{z}_0, \boldsymbol{\epsilon}, t} \left[ \| \boldsymbol{\epsilon} - \hat{\boldsymbol{\epsilon}}(\mathbf{z}_t, t, c) \|_2^2 \right] L=Ez0,ϵ,t[ϵϵ^(zt,t,c)22]

1.3 可扩展性实验

DiT 提供了四种模型尺寸:DiT‑SDiT‑BDiT‑LDiT‑XL(参数量从 33M 到 675M)。实验表明,增大模型尺寸和 token 数量(减小 patch size)能稳定提升生成质量(FID 下降),在 ImageNet 256×256 上达到当时的 SOTA。

DiT 论文中常见命名类似:DiT-XL/2,DiT-L/4,DiT-B/8

这里:XL / L / B 表示模型规模:

  • B Base
  • L Large
  • XL Extra Large

后面的 /2、/4、/8 表示 patch size。

代码

  • DiT 的核心结构
import torch 
import torch.nn as nn
class DiTBlock(nn.Module):
	def __init__(self,hidden_size,num_heads):
		super().__init__()
		self.norm1=nn.LayerNorm(hidden_size,elementwise_affine=False)
		self.attn=Attention(hidden_size,num_heads)
		self.norm2=nn.LayerNorm(hidden_size,elementwise_affine=False)
		self.mlp=MLP(hidden_size)
		self.adaLN_modulation=nn.Sequential(
			nn.SiLU(),
			nn.Linear(hidden_size,6*hidden_size)
		)
	def forward(self,x,c):
		shift_msa,scale_msa,gate_msa,\
		shift_mlp,scale_mlp,gate_mlp= self.adaLN_modulation(c).chunk(6,dim=1)
		x=x +gate_msa.unsqueeze(1)*self.attn(modulate(self.norm1(x),shift_msa,scale_msa))
		x=x+game_mlp.unsqueeze(1)*self.mlp(modulate(self.norm2(x),shift_mlp,scale_mlp))
		return x
  • DiT 的Final Layer
class FinalLayer(nn.Module):
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.linear = nn.Linear(
            hidden_size,
            patch_size * patch_size * out_channels
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm(x), shift, scale)
        x = self.linear(x)
        return x
  • DiT Forward 流程
def forward(x,t,y):
	# x: (B,C,H,W)
	x=patch_embed(x)
	x=x+pos_embed
	t=timestep_embed(t)
	y=lable_embed(y)
	c=t+y
	for block in blocks:
		x=block(x,c)
	x=final_layer(x,c)
	x=unpatchify(x)
	return x
  • 训练目标
x0= batch["image_or_latent"]
noise=torch.randn_like(x0)
t=torch.randint(
	0,
	num_train_timesteps,
	(x0.shape[0],),
	device=x0.device
)

xt=noise_scheduler.add_noise(x0,noise,t)
noise_pred=dit(xt,t,condition)
loss=F.mse_loss(noise_pred,noise)

DiT里面的classifier-free guidance

DiT 也可以使用 CFG,和 U-Net diffusion 一样。

训练时会随机 drop condition:

if random.random() < drop_prob:
    y = null_label

推理时分别预测:

eps_uncond = model(x_t, t, null_condition)
eps_cond = model(x_t, t, condition)

然后组合:

eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)

含义是:
guidance_scale 越大,越听 condition;
guidance_scale 太大,图像可能过饱和、失真、多样性下降。

DiT 和Unet 的核心区别

  • 输入表示不同

U-Net 输入是 feature map:B × C × H × W

DiT 输入是 token sequence:B × N × D

  • 归纳偏置不同

U-Net 有强图像归纳偏置:
DiT 的归纳偏置弱一些,但可扩展性更强:

  • 条件注入方式不同
    U-Net 通常通过:
    time embedding 加到 ResBlock
    cross-attention 注入文本条件

2. MM‑DiT:多模态扩散 Transformer

2.1 动机

DiT 仅支持单模态输入(图像)和简单的类别条件。对于文本到图像生成,需要模型能同时理解文本和图像两种模态,并实现深层融合。
MM‑DiT 由 Stability AI 在 Stable Diffusion 3 中提出,它将 DiT 扩展为双流 Transformer,让文本和图像 token 在每一层进行双向交互。

在这里插入图片描述

2.2 输入表示

  • 图像流:与 DiT 相同,将带噪图像的潜在表示 z \mathbf{z} z 切分为 patches,得到 token 序列 X I ∈ R N I × d model \mathbf{X}_I \in \mathbb{R}^{N_I \times d_{\text{model}}} XIRNI×dmodel

假设 noisy latent 是:

x_t: [B, C, H, W]

经过 patching 后,会被切成 patch tokens:

patches: [B, N_img, patch_dim]

其中:
N_img = H / patch_size × W / patch_size

然后经过 Linear 投影到 Transformer hidden dimension:

x = Linear(patches) # [B, N_img, D]

再加上 positional embedding:
x = x + pos_embed
x 就是送入 MM-DiT blocks 的 image tokens


  • 文本流:将文本 prompt 通过预训练编码器(如 CLIP 或 T5)编码为 token 特征,再线性投影到相同维度 d model d_{\text{model}} dmodel,得到 X T ∈ R N T × d model \mathbf{X}_T \in \mathbb{R}^{N_T \times d_{\text{model}}} XTRNT×dmodel.

文本分支:Caption 经过三个 text encoder: CLIP-G/14, CLIP-L/14, T5 XXL

CLIP 系列 encoder 通常提供两类信息:

  1. token-level embedding
  2. pooled embedding

这表示两个 CLIP encoder 各自产生一组 token embedding,典型长度是 77 个 token。

可以理解为:

clip_g_tokens: [B, 77, D_g]
clip_l_tokens: [B, 77, D_l]

文本 token 最终会被处理成较高维的 channel 表示,然后通过 Linear 投影到 MMDiT 内部使用的 hidden dimension。

T5 XXL 主要提供更强的自然语言理解能力,尤其是长文本、复杂语义、属性关系、空间关系等。图中 T5 XXL 输出也进入中间的 token embedding 路径。

text_tokens = concat([
    clip_g_token_embeds,
    clip_l_token_embeds,
    t5_token_embeds
], dim=sequence_dim or feature_dim)

c = Linear(text_tokens)

c 可以理解为送入 MM-DiT blocks 的 context tokens / caption tokens。

CLIP 的 pooled output 通常表示整句话的全局语义,例如: 整体风格,主体类别,图像全局语义

pooled feature 经过 MLP 后,会和 timestep encoding 分支融合,得到 y
可以把 y 理解为: 全局调制条件 = 文本全局语义 + 当前 timestep 信息
高噪声阶段:模型更关注整体布局、语义和大结构
低噪声阶段:模型更关注纹理、边缘、文字、细节


  • 时间步条件
    timestep 先通过 sinusoidal positional encoding 变成连续向量, 再经过 MLP, 然后与 pooled text condition 融合

  • 总结

x:image tokens,来自 noised latent
c:caption / context tokens,来自 text encoders
y:global modulation vector,来自 pooled text + timestep
x 负责图像 latent token 的更新
c 负责文本 token 的更新
y 负责调制 attention / MLP / LayerNorm

2.3 MM‑DiT 块结构

每个 MM‑DiT 块同时处理两个流,包含三个子层,每个子层都使用 AdaLN 注入条件:

2.3.1 独立自注意力

对图像流和文本流分别应用自注意力,不跨流:

X I ← X I + SelfAttn ( AdaLN ( X I , γ I 1 , β I 1 ) ) ⋅ α I 1 \mathbf{X}_I \leftarrow \mathbf{X}_I + \text{SelfAttn}(\text{AdaLN}(\mathbf{X}_I, \gamma_{I1}, \beta_{I1})) \cdot \alpha_{I1} XIXI+SelfAttn(AdaLN(XI,γI1,βI1))αI1
X T ← X T + SelfAttn ( AdaLN ( X T , γ T 1 , β T 1 ) ) ⋅ α T 1 \mathbf{X}_T \leftarrow \mathbf{X}_T + \text{SelfAttn}(\text{AdaLN}(\mathbf{X}_T, \gamma_{T1}, \beta_{T1})) \cdot \alpha_{T1} XTXT+SelfAttn(AdaLN(XT,γT1,βT1))αT1

2.3.2 双向交叉注意力

这是 MM‑DiT 的核心创新。两个流互为键/值,实现信息的深度融合:

  • 图像作为查询,文本作为键/值

    X I ← X I + CrossAttn ( AdaLN ( X I ) , X T , X T ) ⋅ α I 2 \mathbf{X}_I \leftarrow \mathbf{X}_I + \text{CrossAttn}(\text{AdaLN}(\mathbf{X}_I), \mathbf{X}_T, \mathbf{X}_T) \cdot \alpha_{I2} XIXI+CrossAttn(AdaLN(XI),XT,XT)αI2

  • 文本作为查询,图像作为键/值

    X T ← X T + CrossAttn ( AdaLN ( X T ) , X I , X I ) ⋅ α T 2 \mathbf{X}_T \leftarrow \mathbf{X}_T + \text{CrossAttn}(\text{AdaLN}(\mathbf{X}_T), \mathbf{X}_I, \mathbf{X}_I) \cdot \alpha_{T2} XTXT+CrossAttn(AdaLN(XT),XI,XI)αT2

这种对称设计使得文本能够“看到”图像当前的状态,从而动态调整文本表示(例如,关注与图像相关的词语)。

2.3.3 独立前馈网络

最后,两个流各自经过一个前馈网络(FFN):

X I ← X I + FFN ( AdaLN ( X I , γ I 3 , β I 3 ) ) ⋅ α I 3 \mathbf{X}_I \leftarrow \mathbf{X}_I + \text{FFN}(\text{AdaLN}(\mathbf{X}_I, \gamma_{I3}, \beta_{I3})) \cdot \alpha_{I3} XIXI+FFN(AdaLN(XI,γI3,βI3))αI3
X T ← X T + FFN ( AdaLN ( X T , γ T 3 , β T 3 ) ) ⋅ α T 3 \mathbf{X}_T \leftarrow \mathbf{X}_T + \text{FFN}(\text{AdaLN}(\mathbf{X}_T, \gamma_{T3}, \beta_{T3})) \cdot \alpha_{T3} XTXT+FFN(AdaLN(XT,γT3,βT3))αT3

2.4 输出与训练

经过 L L L 个 MM‑DiT 块后,仅保留图像流的 token 序列 X I \mathbf{X}_I XI,通过线性层预测噪声(或速度):

ϵ ^ = Linear ( LayerNorm ( X I ) ) \hat{\boldsymbol{\epsilon}} = \text{Linear}(\text{LayerNorm}(\mathbf{X}_I)) ϵ^=Linear(LayerNorm(XI))

文本流的最终状态被丢弃,不作为输出。训练目标与 DiT 相同(预测噪声的 MSE 损失)。在推理时,可以使用 Classifier‑Free Guidance 对文本条件进行强度调节。

2.5 MM-DiT

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


# ============================================================
# 1. 工具函数:modulate
# ============================================================

def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    AdaLN-style modulation.

    Args:
        x:     [B, N, D]
        shift: [B, D]
        scale: [B, D]

    Returns:
        [B, N, D]
    """
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


# ============================================================
# 2. RMSNorm:可选的 Q/K norm
# ============================================================

class RMSNorm(nn.Module):
    """
    Root Mean Square LayerNorm.
    常用于 QK normalization,稳定 attention logits。
    """

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.register_parameter("weight", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [..., D]
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        x = x / rms
        if self.weight is not None:
            x = x * self.weight
        return x


# ============================================================
# 3. MLP
# ============================================================

class MLP(nn.Module):
    """
    Transformer FFN / MLP.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: Optional[int] = None,
        dropout: float = 0.0,
    ):
        super().__init__()
        hidden_dim = hidden_dim or dim * 4

        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(approximate="tanh"),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


# ============================================================
# 4. Multi-Head Attention:接受已经算好的 Q/K/V
# ============================================================

class JointAttention(nn.Module):
    """
    Joint Attention for MMDiT.

    输入:
        q_c, k_c, v_c: text tokens 的 Q/K/V
        q_x, k_x, v_x: image tokens 的 Q/K/V

    做法:
        1. concat text/image Q/K/V
        2. scaled dot-product attention
        3. split 回 text/image 两部分
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        qk_norm: bool = True,
        attn_dropout: float = 0.0,
        proj_dropout: float = 0.0,
    ):
        super().__init__()

        assert dim % num_heads == 0, "dim must be divisible by num_heads"

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.attn_dropout = attn_dropout
        self.proj_dropout = nn.Dropout(proj_dropout)

        if qk_norm:
            self.q_norm_c = RMSNorm(self.head_dim)
            self.k_norm_c = RMSNorm(self.head_dim)
            self.q_norm_x = RMSNorm(self.head_dim)
            self.k_norm_x = RMSNorm(self.head_dim)
        else:
            self.q_norm_c = nn.Identity()
            self.k_norm_c = nn.Identity()
            self.q_norm_x = nn.Identity()
            self.k_norm_x = nn.Identity()

    def _reshape_heads(self, t: torch.Tensor) -> torch.Tensor:
        """
        [B, N, D] -> [B, H, N, Dh]
        """
        B, N, D = t.shape
        t = t.view(B, N, self.num_heads, self.head_dim)
        t = t.transpose(1, 2)
        return t

    def _merge_heads(self, t: torch.Tensor) -> torch.Tensor:
        """
        [B, H, N, Dh] -> [B, N, D]
        """
        B, H, N, Dh = t.shape
        t = t.transpose(1, 2).contiguous()
        t = t.view(B, N, H * Dh)
        return t

    def forward(
        self,
        q_c: torch.Tensor,
        k_c: torch.Tensor,
        v_c: torch.Tensor,
        q_x: torch.Tensor,
        k_x: torch.Tensor,
        v_x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            q_c, k_c, v_c: [B, N_text, D]
            q_x, k_x, v_x: [B, N_image, D]
            attn_mask:
                可选 attention mask。
                如果使用 PyTorch SDPA,推荐形状可广播到 [B, H, N_total, N_total]。

        Returns:
            out_c: [B, N_text, D]
            out_x: [B, N_image, D]
        """
        B, N_text, D = q_c.shape
        N_image = q_x.shape[1]

        # [B, N, D] -> [B, H, N, Dh]
        q_c = self._reshape_heads(q_c)
        k_c = self._reshape_heads(k_c)
        v_c = self._reshape_heads(v_c)

        q_x = self._reshape_heads(q_x)
        k_x = self._reshape_heads(k_x)
        v_x = self._reshape_heads(v_x)

        # Optional QK norm
        q_c = self.q_norm_c(q_c)
        k_c = self.k_norm_c(k_c)
        q_x = self.q_norm_x(q_x)
        k_x = self.k_norm_x(k_x)

        # Joint attention:
        # [text tokens, image tokens]
        q = torch.cat([q_c, q_x], dim=2)  # [B, H, N_text + N_image, Dh]
        k = torch.cat([k_c, k_x], dim=2)
        v = torch.cat([v_c, v_x], dim=2)

        # PyTorch 2.0 scaled_dot_product_attention
        out = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=self.attn_dropout if self.training else 0.0,
            is_causal=False,
        )

        # split back
        out_c = out[:, :, :N_text, :]
        out_x = out[:, :, N_text:N_text + N_image, :]

        out_c = self._merge_heads(out_c)  # [B, N_text, D]
        out_x = self._merge_heads(out_x)  # [B, N_image, D]

        out_c = self.proj_dropout(out_c)
        out_x = self.proj_dropout(out_x)

        return out_c, out_x


# ============================================================
# 5. MM-DiT Block
# ============================================================

class MMDiTBlock(nn.Module):
    """
    一个完整的 MM-DiT Block。

    输入:
        c: text / caption tokens, [B, N_text, D]
        x: image / latent tokens, [B, N_image, D]
        y: global condition, [B, D]
           通常来自 timestep embedding + pooled text embedding

    输出:
        c: updated text tokens
        x: updated image tokens

    核心结构:
        text branch:
            LN -> AdaLN modulation -> QKV
        image branch:
            LN -> AdaLN modulation -> QKV

        concat Q/K/V -> joint attention -> split

        text branch:
            output proj -> gate -> residual
            LN -> AdaLN modulation -> MLP -> gate -> residual

        image branch:
            output proj -> gate -> residual
            LN -> AdaLN modulation -> MLP -> gate -> residual
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_norm: bool = True,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        norm_eps: float = 1e-6,
    ):
        super().__init__()

        self.dim = dim
        self.num_heads = num_heads
        mlp_hidden_dim = int(dim * mlp_ratio)

        # ----------------------------------------------------
        # y -> modulation parameters
        # 每个分支 6 组参数:
        #   shift_msa, scale_msa, gate_msa,
        #   shift_mlp, scale_mlp, gate_mlp
        # text 和 image 分支各自一套
        # ----------------------------------------------------
        self.adaLN_modulation_c = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True),
        )

        self.adaLN_modulation_x = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim, bias=True),
        )

        # ----------------------------------------------------
        # Attention pre-norm
        # ----------------------------------------------------
        self.norm1_c = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=False)
        self.norm1_x = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=False)

        # ----------------------------------------------------
        # Separate QKV projections
        # text 和 image 不共享 QKV 参数
        # ----------------------------------------------------
        self.qkv_c = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.qkv_x = nn.Linear(dim, 3 * dim, bias=qkv_bias)

        self.joint_attn = JointAttention(
            dim=dim,
            num_heads=num_heads,
            qk_norm=qk_norm,
            attn_dropout=attn_dropout,
            proj_dropout=dropout,
        )

        # ----------------------------------------------------
        # Separate output projections
        # ----------------------------------------------------
        self.proj_c = nn.Linear(dim, dim)
        self.proj_x = nn.Linear(dim, dim)

        self.proj_drop_c = nn.Dropout(dropout)
        self.proj_drop_x = nn.Dropout(dropout)

        # ----------------------------------------------------
        # MLP sub-layer
        # ----------------------------------------------------
        self.norm2_c = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=False)
        self.norm2_x = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=False)

        self.mlp_c = MLP(dim, hidden_dim=mlp_hidden_dim, dropout=dropout)
        self.mlp_x = MLP(dim, hidden_dim=mlp_hidden_dim, dropout=dropout)

        self.initialize_weights()

    def initialize_weights(self):
        """
        简化版初始化。
        注意:为了模拟 AdaLN-Zero 的稳定训练特点,
        modulation 最后一层可以 zero init,
        这样 block 初始时近似 residual identity。
        """
        nn.init.zeros_(self.adaLN_modulation_c[-1].weight)
        nn.init.zeros_(self.adaLN_modulation_c[-1].bias)

        nn.init.zeros_(self.adaLN_modulation_x[-1].weight)
        nn.init.zeros_(self.adaLN_modulation_x[-1].bias)

    def forward(
        self,
        c: torch.Tensor,
        x: torch.Tensor,
        y: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            c: [B, N_text, D]
            x: [B, N_image, D]
            y: [B, D]
            attn_mask: optional attention mask

        Returns:
            c: [B, N_text, D]
            x: [B, N_image, D]
        """

        # ====================================================
        # 1. 由 y 生成 text/image 两套 modulation 参数
        # ====================================================
        (
            shift_msa_c,
            scale_msa_c,
            gate_msa_c,
            shift_mlp_c,
            scale_mlp_c,
            gate_mlp_c,
        ) = self.adaLN_modulation_c(y).chunk(6, dim=-1)

        (
            shift_msa_x,
            scale_msa_x,
            gate_msa_x,
            shift_mlp_x,
            scale_mlp_x,
            gate_mlp_x,
        ) = self.adaLN_modulation_x(y).chunk(6, dim=-1)

        # ====================================================
        # 2. Attention 子层:LN + Modulation
        # ====================================================
        c_attn_in = modulate(self.norm1_c(c), shift_msa_c, scale_msa_c)
        x_attn_in = modulate(self.norm1_x(x), shift_msa_x, scale_msa_x)

        # ====================================================
        # 3. Separate QKV
        # ====================================================
        q_c, k_c, v_c = self.qkv_c(c_attn_in).chunk(3, dim=-1)
        q_x, k_x, v_x = self.qkv_x(x_attn_in).chunk(3, dim=-1)

        # ====================================================
        # 4. Joint Attention
        # ====================================================
        out_c, out_x = self.joint_attn(
            q_c=q_c,
            k_c=k_c,
            v_c=v_c,
            q_x=q_x,
            k_x=k_x,
            v_x=v_x,
            attn_mask=attn_mask,
        )

        # ====================================================
        # 5. Attention output projection + gate + residual
        # ====================================================
        c = c + gate_msa_c.unsqueeze(1) * self.proj_drop_c(self.proj_c(out_c))
        x = x + gate_msa_x.unsqueeze(1) * self.proj_drop_x(self.proj_x(out_x))

        # ====================================================
        # 6. MLP 子层:LN + Modulation + MLP + gate + residual
        # ====================================================
        c_mlp_in = modulate(self.norm2_c(c), shift_mlp_c, scale_mlp_c)
        x_mlp_in = modulate(self.norm2_x(x), shift_mlp_x, scale_mlp_x)

        c = c + gate_mlp_c.unsqueeze(1) * self.mlp_c(c_mlp_in)
        x = x + gate_mlp_x.unsqueeze(1) * self.mlp_x(x_mlp_in)

        return c, x


# ============================================================
# 6. 简单测试
# ============================================================

if __name__ == "__main__":
    torch.manual_seed(0)

    B = 2
    N_text = 154
    N_image = 1024
    D = 1152
    num_heads = 16

    c = torch.randn(B, N_text, D)
    x = torch.randn(B, N_image, D)
    y = torch.randn(B, D)

    block = MMDiTBlock(
        dim=D,
        num_heads=num_heads,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_norm=True,
        dropout=0.0,
        attn_dropout=0.0,
    )

    c_out, x_out = block(c, x, y)

    print("c_out:", c_out.shape)
    print("x_out:", x_out.shape)

    assert c_out.shape == c.shape
    assert x_out.shape == x.shape

2.6 关键优势

  • 深度双向融合:不同于传统 cross‑attention(图像单向看文本),MM‑DiT 让文本也能关注图像,使得文本特征在生成过程中不断优化,尤其有助于处理复杂提示(如空间关系、多主体、属性绑定)。
  • 长文本支持:自注意力机制天然支持变长序列,可处理数百甚至上千个文本 token,提升文字渲染能力。
  • 可扩展性:与 DiT 类似,MM‑DiT 的参数量可从 800M 扩展到 8B,性能持续提升。

3. 总结对比

特性DiTMM‑DiT
输入模态单模态(图像) + 可选条件(类别)多模态(图像 + 文本)
条件融合AdaLN(将条件调制到层归一化参数)AdaLN + 双向交叉注意力
自注意力仅图像 token图像 token 和文本 token 各自独立自注意力
跨模态交互无(条件仅通过 AdaLN 注入)每一层都进行图像↔文本的双向交叉注意力
主要应用无条件/类别条件图像生成文本到图像生成,后续可扩展为视频生成
代表工作DiT (2023)Stable Diffusion 3 (2024)

两者共同推动了扩散模型向 Transformer 统一架构 的演进,DiT 奠定了技术基础,MM‑DiT 则将其成功应用于工业级多模态生成场景。


参考文献

  • Peebles, W., & Xie, S. (2023). Scalable Diffusion Models with Transformers. ICCV 2023.
  • Esser, P., et al. (2024). Scaling Rectified Flow Transformers for High-Resolution Image Synthesis. *arXiv:2403.03206 *.
于2024年4月-2025年9月期间,研究团队在贵州习水国家级自然保护区制定39条样线,涵盖灌木林、常绿阔叶林、针叶林、常绿落叶阔叶混交林、针阔混交林等不同植被类型,每条样线分春夏秋冬4个季节采集样品,用真菌采集软件记录经纬度、海拔、采集地点、时间、生境等信息,使用佳能相机(R6 mark Ⅱ)对大型真菌进行拍照,并采集标本,标本存放于贵州省生物研究所大型真菌标本馆(HGAMF)。 通过形态学初步鉴定,结合分子生物学最终鉴定,参考已]报道的中国毒蘑菇名录开展毒蘑菇的认定。 调查到保护区内有毒真菌7目25科64种,导致中毒的主要类型有急性肾衰竭型、神经精神型胃肠炎型。最终形成贵州习水国家级自然保护区大型有毒真菌图片数据集,它由以下2个部分组成。 (1)附件1包含78张原始照片(.JPG),照片名字包括了大型有毒真菌的拉丁名中文名,若无中文名的直接用拉丁名。 (2)附件2是一个压缩文件,包含了2张工作表,其中一张表是大型有毒真菌39条样线的信息,另一张表是大型有毒真菌的中毒类型。 照片采用佳能相机R6 mark Ⅱ拍摄,物种鉴定通过多种文献核实,并经两位以上专家鉴定确认。该数据集可为研究地及周边的普通人识别有毒大型真菌提供参考,通过及时的图片对比,能有效避免误采误食大型有毒真菌,同时为因误食大型真菌可能引发的身体损伤进行了总结,能为患者及时治疗提供参考。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

zyw2002

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值