1. Vision Transformer(ViT)基础概念
我第一次接触ViT是在2020年,当时Google Research那篇《An Image is Worth 16x16 Words》的论文让我眼前一亮。传统CNN统治计算机视觉领域这么多年,终于有人用纯Transformer结构打破了这种垄断。简单来说,ViT就是把NLP领域的Transformer直接搬到了图像处理上,但这里面有几个关键创新点:
- 图像分块处理:把一张224x224的图片切成16x16的小块(共196个patch),每个patch展平后就是一个"视觉单词"
- 位置编码:因为Transformer本身没有空间位置概念,需要额外添加位置信息
- CLS令牌:借鉴BERT的做法,添加一个特殊标记用于最终分类
我当时在ImageNet上测试ViT-Base模型时发现,当训练数据足够大(比如JFT-300M这种量级),ViT的表现能轻松超过同规模的ResNet。不过在小数据集上,由于缺乏CNN那种局部性归纳偏置(inductive bias),ViT的表现会打折扣。
2. 环境准备与依赖安装
在开始构建ViT之前,我们需要准备好开发环境。我推荐使用Python 3.8+和PyTorch 1.10+的组合,这是目前最稳定的配置。以下是具体步骤:
# 创建conda环境(推荐)
conda create -n vit python=3.8 -y
conda activate vit
# 安装PyTorch
pip install torch torchvision torchaudio
# 安装其他依赖
pip install einops numpy matplotlib tqdm
这里特别要提一下einops这个库,它能让张量操作代码更加清晰。比如传统的permute和reshape操作可以写成更直观的形式:
from einops import rearrange
# 传统写法
x = x.permute(0, 2, 1, 3).contiguous().view(batch, -1, dim)
# 使用einops
x = rearrange(x, 'b h w c -> b (h w) c')
3. 构建ViT核心模块
3.1 多头注意力机制
Transformer的核心就是注意力机制。下面这个实现我优化过多次,应该是目前PyTorch下效率较高的版本:
import torch
import torch.nn as nn
from torch import einsum
class MultiHeadAttention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5

2万+

被折叠的 条评论
为什么被折叠?



