Cosmos3 模型架构(一):Wan VAE(WAE)

Cosmos3 模型架构:Wan VAE(WAE)

本文描述 Cosmos3 使用的 Wan VAEAutoencoderKLWan / WanEncoder3d)的原理、内部结构、各 layer 连接关系与典型输入输出 shape。
源码:diffusers/src/diffusers/models/autoencoders/autoencoder_kl_wan.py


零、Wan VAE 原理概述

0.1 在整体栈中的角色

WAE(Wan VAE)是一个 3D 视频变分自编码器(KL-VAE):把高维像素视频压进 低维、低分辨率的 latent 时空网格,DiT 等生成模型在 latent 里做扩散 / flow matching,最后再 decode 回像素。

像素视频  ←──decode──  latent(紧凑语义 + 时空结构)  ←──DiT 去噪──  噪声
   │                           ▲
   └──────── encode ───────────┘

没有 VAE,DiT 就要直接在 [B, 3, T, H, W] 上工作,算力和显存都不可接受。WAE 是 像素世界与生成模型之间的 codec——类似 JPEG 之于图像 CNN,但是 可学习、3D、因果、面向生成 的版本。

0.2 抽象流程(Encode → Latent → Decode)

Decode

Latent

Encode

patchify(可选)

WanEncoder3d
因果 3D 卷积 + 逐级下采样

quant_conv

高斯参数 μ, logσ

z ~ N(μ,σ) 或 argmax(μ)
Pipeline 里常取 μ

post_quant_conv

WanDecoder3d
对称上采样

unpatchify(可选)

像素 [B,3,T,H,W]

重建像素

Encode 四步:

  1. patchify(Wan 2.2):2×2 空间 patch 并入通道,H, W 减半、通道变 12,相当于先做一层「无损重排」。
  2. Encoder:多层 3D 卷积金字塔,空间约 /8、时间约 /4,输出 z_dim×2 通道特征图。
  3. quant_conv:1×1×1 卷积,整理成 VAE 的高斯参数。
  4. 采样 / 取均值:训练时从 N(μ, σ) 采样;Cosmos3 推理通常 直接取 μ(argmax),再减 mean、除 std 归一化后交给 DiT。

Decode 完全对称:post_quant_conv → Decoder 上采样 → unpatchify → clamp 到 [-1, 1]

0.3 五条核心设计原理

(1)时空联合压缩,而非「逐帧 2D VAE」

WAE 用 3D 因果卷积 同时在 T、H、W 上下采样:

维度典型压缩比(相对原像素)机制
空间×16patchify ×2 + encoder ×8
时间×4后两个 down stage 做 temporal downsample

17 帧像素 → 5 步 latent,即 (T - 1) // 4 + 1。latent 每一步对应约 4 帧像素的信息,不是 1:1 帧对齐

(2)因果性:流式编码 + 自回归友好

所有时间卷积都是 因果的(只看过去帧)。配合 分块 encode(首块 1 帧,后续每块 4 帧)和 feat_cache,长视频不必一次灌进网络,chunk 间通过 cache 传递上下文,结果与整段等价。

这同时支持:用过去帧 encode 观测,用未来 latent 步做 rollout 预测——Cosmos3 Policy 的时间语义就建立在这上面。

(3)变分瓶颈:连续、平滑的 latent 空间

KL-VAE 训练目标大致是:

重建损失(像素接近原视频) + β · KL(q(z|x) ‖ N(0, I))

效果:latent 空间 连续、可插值,适合扩散 / flow matching;DiT 学的是「在这个空间里从噪声走向数据分布」。Pipeline 里对 latent 做 per-channel mean/std 归一化,是把 VAE 输出对齐到 DiT 训练时的数值范围。

(4)多尺度金字塔 + 低分辨率 attention

Encoder 是经典 U-Net encoder 半支

  • dim_mult 控制 stage 数通道宽度(越深越宽)。
  • 每个 stage:若干 ResBlock + 下采样(Wan 2.2 用带 AvgDown3D shortcut 的 DownBlock)。
  • MidBlock 在最深层加一次 逐帧 2D self-attention,在低分辨率上捕获全局空间关系,成本可控。

Decoder 结构对称,保证 encode/decode 可逆(有损但高质量重建)。

(5)Wan 2.2 residual down block
Wan 2.1Wan 2.2(Cosmos3)
Down stageResBlock 串行 + Resample主路 + AvgDown3D shortcut 相加
作用标准 VAE 金字塔shortcut 保留 pooling 信息,主路学残差修正

这是架构版本差异(is_residual=True),必须与 checkpoint 匹配。详见 §三、关键配置参数

0.4 信息如何被「压扁」

原像素空间                          Latent 空间
┌─────────────────┐                ┌──┐
│ 3 ch            │   Encoder      │48│  ← 更少通道
│ T=17 帧         │  ──────────►   │T'=5│ ← 更少时间步
│ H×W 全分辨率    │                │H/16×W/16│ ← 更小空间
└─────────────────┘                └──┘
     高维、冗余大                      低维、DiT 友好

DiT 在 latent 上做 patchify(如 2×2)后再进 Transformer,是 第二层 token 化;VAE 的 patchify 是 进入 encoder 前 的通道重排,两者不要混。

0.5 在 Cosmos3 中的角色

阶段WAE 干什么
Policy 预处理首帧(+ padding 窗口)→ encode → z₀ 作为视觉 condition
DiT 去噪同一 latent 空间 里联合预测未来 latent 步 + action
可视化未来 latent → decode → rollout 视频

WAE 不参与扩散迭代本身;只在 encode/decode 边界出现。

0.6 与经典 2D SD VAE 的对比

SD 2D VAEWan VAE
输入单张图视频片段
卷积2D3D 因果
时间/4 压缩,因果分块
Latent4 ch, H/816–48 ch, T/4, H/16
推理一次 forwardchunk + cache 流式

一、在 AutoencoderKLWan 中的位置

Cosmos3 在调用 WanEncoder3d 之前 可能先做 patchify(Wan 2.2):

像素视频 [B, 3, T, H, W]
    │  patchify(patch_size=2)  ← 仅 Wan 2.2 / Cosmos3
    ▼
Encoder 输入 [B, 12, T, H/2, W/2]
    │  WanEncoder3d
    ▼
[B, z_dim×2, T', H'', W'']     # conv_out 输出 z_dim×2(如 96)
    │  quant_conv (1×1×1)
    ▼
高斯参数 [B, z_dim×2, T', H'', W'']

时间 / 空间压缩比(相对原始像素):

维度公式配置项
时间T' = (T - 1) // 4 + 1scale_factor_temporal = 4
空间H'' = H // 16, W'' = W // 16patchify /2 × encoder /8

Policy 示例(Wan 2.1 风格,in_channels=3,无 patch):T=17, H=480, W=640

阶段Shape
输入[1, 3, 17, 480, 640]
conv_in 后[1, 96, 17, 480, 640]
DownBlock-0 后[1, 96, 17, 240, 320]
DownBlock-1 后[1, 192, 9, 120, 160]
DownBlock-2 后[1, 384, 5, 60, 80]
DownBlock-3 后[1, 384, 5, 60, 80]
conv_out 后[1, 32, 5, 60, 80]

二、WanEncoder3d 整体拓扑

Cosmos3 使用 Wan 2.2 residual 路径is_residual=True)。

Cosmos3 Wan 2.2 典型配置:

参数
base_dim160
dim_mult[1, 2, 4, 4]
z_dim48(conv_out 输出 z_dim×2 = 96
in_channels12(patchify 后)
num_res_blocks2
temperal_downsample[False, True, True]
is_residualTrue

Head

Tail

norm_out (RMSNorm)

SiLU

conv_out
WanCausalConv3d(3×3×3)
640 → z_dim×2 (=96)

Mid Block

WanResidualBlock 640→640

WanAttentionBlock 640
逐帧 2D self-attn

WanResidualBlock 640→640

DownBlock-3: 640→640

2× WanResidualBlock
640→640→640
无下采样

DownBlock-2: 320→640

2× WanResidualBlock
320→640→640

WanResample downsample3d
空间 /2 + 时间 /2

AvgDown3D shortcut
factor_t=2, factor_s=2

+

DownBlock-1: 160→320

2× WanResidualBlock
160→320→320

WanResample downsample3d
空间 /2 + 时间 /2

AvgDown3D shortcut
factor_t=2, factor_s=2

+

DownBlock-0: 160→160

2× WanResidualBlock
160→160→160

WanResample downsample2d
空间 /2, 时间不变

AvgDown3D shortcut
factor_s=2

+

Input x
[B, C_in, T, H, W]
C_in=12 (patch后) 或 3

conv_in
WanCausalConv3d(3×3×3)
C_in → 160

Output
[B, 96, T', H/8, W/8]
相对 encoder 输入

通道维度计算:

dims = [base_dim * m for m in [1] + dim_mult]
# Cosmos3: base_dim=160, dim_mult=[1,2,4,4]
# → dims = [160, 160, 320, 640, 640]

# Policy (Wan 2.1 风格): base_dim=96
# → dims = [96, 96, 192, 384, 384]

4 个 WanResidualDownBlock 对应 (dims[i] → dims[i+1])i = 0, 1, 2, 3


三、关键配置参数

3.1 dim_mult:通道倍率与 stage 数量

dim_mult每个下采样 stage 相对 base_dim(代码里参数名 dim)的通道倍率列表,同时决定 encoder 有多少个 down stage。

源码:

dims = [base_dim * m for m in [1] + dim_mult]
# 前缀 [1] 表示 conv_in 输出通道 = base_dim × 1

base_dim=160, dim_mult=[1, 2, 4, 4] 为例:

索引计算通道 C对应模块
160×1160conv_in 输出
stage 0160×1160DownBlock-0:160→160
stage 1160×2320DownBlock-1:160→320
stage 2160×4640DownBlock-2:320→640
stage 3160×4640DownBlock-3:640→640

要点:

  • len(dim_mult) = DownBlock 个数(默认 4),也等于 空间下采样 stage 数(最后一级通常不再下采样分辨率,只做通道变换)。
  • dim_mult[i] 越大,该 stage 输出通道越多,表达能力越强,参数量与计算量也越大。
  • 末尾两个 4, 4 表示 最后两级保持 640 通道不再扩宽,只在最深层做特征提炼(DownBlock-3 无 Resample)。
  • dims 相邻两项 (dims[i], dims[i+1]) 就是第 i 个 DownBlock 的 (in_dim, out_dim)

temperal_downsample 的关系:temperal_downsample 长度通常等于 len(dim_mult),逐 stage 控制该级是否做 时间 下采样;空间 下采样由 down_flag=(i != len(dim_mult)-1) 决定(除最后一级外每级 /2)。

Decoder 侧对称:dims = [base_dim * m for m in [dim_mult[-1]] + dim_mult[::-1]],通道随上采样逐级收窄。

3.2 is_residual:Wan 2.1 vs 2.2 架构开关

is_residual 决定 down/up stage 用哪套模块,不是 ResBlock 内部 x + h 那条小残差。

is_residual=False(Wan 2.1)is_residual=True(Wan 2.2 / Cosmos3)
Encoder stageResBlock × NWanResample(串行扁平列表)一个 WanResidualDownBlock(主路 + shortcut)
Decoder stageWanUpBlockWanResidualUpBlock(对称,用 DupUp3D
下采样 shortcutAvgDown3D 与主路输出相加

如何选择: 必须与 checkpoint 权重一致,不能随意切换。

模型is_residual
Wan 2.1 VAEFalseAutoencoderKLWan 默认)
Wan 2.2 VAE / Cosmos3True(见 convert_wan_to_diffusers.pyvae22_diffusers_config

两种拓扑的 state_dict key 与层结构不同,混用会导致 load 失败或 silent wrong。这是 架构版本开关,不是推理超参。

3.3 num_res_blocks vs DownBlock 个数

容易混淆的两个「层数」:

概念由什么决定默认值含义
DownBlock 个数len(dim_mult)4分辨率 stage 数;每 stage 最多一次空间 /2
每 DownBlock 内 ResBlock 数num_res_blocks2该 stage 主路上叠多少个 WanResidualBlock
# WanResidualDownBlock 内部
for _ in range(num_res_blocks):
    resnets.append(WanResidualBlock(in_dim, out_dim, dropout))

num_res_blocks 只影响 每个 stage 内的卷积深度;改 dim_mult 的长度或数值则改变 stage 总数、通道宽度、下采样次数

3.4 WanResidualDownBlock 的作用

Wan 2.2 将每个 down stage 封装为一个 带 parallel shortcut 的下采样块

输入 x [B, C_in, T, H, W]
         │
    x_copy ──────────────────────► AvgDown3D ──────────────┐
         │                    (pool + 通道重组)              │
         │                                                  (+)
         ▼                                                   │
    ResBlock × num_res_blocks                                │
         ▼                                                   │
    [WanResample]  ← 仅 down_flag=True                       │
         └──────────────────────────────────────────────────┘
                              │
                              ▼
                    输出 [B, C_out, T', H', W']
  • 主路(learned path)WanResidualBlock 做 3D 特征变换,末尾 WanResample 做可学习的时空下采样。
  • Shortcut(AvgDown3D:无卷积,对输入做时空分组平均 + 通道重组,输出 shape 与主路对齐;factor_s=2 空间 /2factor_t=2 时间 /2
  • 相加return x + self.avg_shortcut(x_copy) — 类似 ResNet / SDXL VAE 的 down block,shortcut 提供恒等信息的「安全通路」,主路主要学 与 pooling 的残差,训练更稳、重建通常更好。

AvgDown3D 核心逻辑(autoencoder_kl_wan.py):按 factor_t × factor_s² 分组,通道扩成 C×factor 再 group mean 到 C_out

Decoder 对称:WanResidualUpBlock + DupUp3D(repeat + reshape 上采样)。


四、子模块内部连接

4.1 WanResidualBlock

每个 DownBlock 内含 2 个 ResBlock:

x [B,C,T,H,W]

conv_shortcut
1×1×1 (C≠C'时)

RMSNorm → SiLU

CausalConv3d 3×3×3
C → C'

RMSNorm → SiLU → Dropout

CausalConv3d 3×3×3
C' → C'

+

out [B,C',T,H,W]

4.2 WanResidualDownBlock

主路径 + shortcut 残差相加:

x_copy ──────────────────────────► AvgDown3D ──┐
  │                                             ├──► (+) ──► out
  └──► ResBlock × num_res_blocks ──► [Downsample] ──┘

源码(autoencoder_kl_wan.py):

return x + self.avg_shortcut(x_copy)

4.3 WanResample 下采样

mode空间时间内部操作
downsample2d/2不变逐帧 ZeroPad2d + Conv2d stride=2
downsample3d/2/2先做 2d down,再 CausalConv3d(3,1,1) stride=(2,1,1) 合并相邻帧

各 DownBlock 的下采样配置(temperal_downsample=[False, True, True]):

Block通道downsample mode空间时间
0160→160downsample2d/2不变
1160→320downsample3d/2/2
2320→640downsample3d/2/2
3640→640不变不变

4.4 WanMidBlock

ResBlock(640→640)
  → WanAttentionBlock(640)   # [B,C,T,H,W] 展成 (B×T) 个 [C,H,W] 做单头 SDPA
  → ResBlock(640→640)

4.5 WanCausalConv3d

  • 在时间维上因果:padding 只在过去帧侧
  • 推理时配合 feat_cache 跨 chunk 传递上下文

五、各 Stage Shape 对照表

Cosmos3 Wan 2.2 为例:原始像素 [B, 3, 17, 512, 512]

Stage模块CTHW备注
patchify 后1217256256相对原图像 /2
conv_inCausalConv3d16017256256
DownBlock-0Res×2 + down2d16017128128仅空间 /2
DownBlock-1Res×2 + down3d320~96464空间+时间 /2
DownBlock-2Res×2 + down3d640~53232空间+时间 /2
DownBlock-3Res×2640~53232无下采样
mid_blockRes+Attn+Res640~53232
conv_outCausalConv3d96~53232z_dim×2=96

相对原始像素T'≈5H'=512/16=32W'=512/16=32


六、因果分块编码(实际推理路径)

AutoencoderKLWan._encode 不是一次性喂 17 帧,而是因果分块 + feat_cache

iter_ = 1 + (num_frame - 1) // 4
for i in range(iter_):
    if i == 0:
        out = self.encoder(x[:, :, :1, :, :], feat_cache=..., feat_idx=...)
    else:
        out_ = self.encoder(
            x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
            feat_cache=..., feat_idx=...,
        )
        out = torch.cat([out, out_], 2)
Chunk-0: [B,C, 1,H,W]  ──encoder──► [B,96, 1,h,w]
Chunk-1: [B,C, 4,H,W]  ──encoder──► [B,96, 1,h,w]  ─┐
Chunk-2: [B,C, 4,H,W]  ──encoder──► [B,96, 1,h,w]  ─┼─ cat(dim=2) ──► [B,96, 5,h,w]
Chunk-3: [B,C, 4,H,W]  ──encoder──► [B,96, 1,h,w]  ─┤
Chunk-4: [B,C, 4,H,W]  ──encoder──► [B,96, 1,h,w]  ─┘

每个 WanCausalConv3d 在时间维上只看过去帧;feat_cache 在 chunk 间传递上下文,保证与整段视频一次 encode 等价。


七、参数量(参考)

配置WanEncoder3d 参数量
Policy 风格 base_dim=96, z_dim=1653.6M
Cosmos3 Wan 2.2 base_dim=160, z_dim=48~150M

完整 VAE(encoder + decoder + quant conv)Policy 风格约 127–132M


八、关键源码索引

组件文件行号(约)
WanEncoder3dautoencoder_kl_wan.py509–628
WanResidualDownBlock同上473–506
WanResidualBlock同上315–386
WanResample同上224–312
WanMidBlock同上434–470
WanCausalConv3d同上131–173
AutoencoderKLWan._encode同上1133–1158
patchify同上917–937
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

self-motivation

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值