Skip to content

Vision Transformer (ViT)


观察 Vision Transformer 如何将图像分割成 patches 并转换为序列。

👁️ Vision Transformer 可视化

原始图像
224×224 像素
Patches
196
序列长度
197
Patch 嵌入维度
768
ViT-B 维度
768
📖 ViT 核心思想
  • 图像 → 序列: 将图像切分成 patches,类比 NLP 中的 tokens
  • [CLS] Token: 聚合全局信息,用于最终分类
  • 位置编码: 可学习的位置嵌入(与 Transformer 类似)
  • 预训练: 在大规模数据集 (JFT-300M) 上预训练效果最佳

💡 ViT-B/16: 224×224 图像, 16×16 patches = 196 个 tokens


一句话定义:Vision Transformer (ViT) 将图像切分为 Patch 序列,然后用标准 Transformer 处理,证明了 Transformer 在视觉任务上可以媲美甚至超越 CNN。

  • Patch: 图像的小块(如 16×16 像素)
  • Patch Embedding: 将 patch 线性投影为向量
  • CLS Token: 用于分类的特殊 token

CNNViT
用滑动窗口逐步扫描把拼图打散成小块
局部到全局所有小块同时看
像用放大镜看画像把画切成格子一起分析

ViT 把图片当成”视觉句子”,每个 patch 就是一个”视觉单词”。


特性CNNViT
归纳偏置强(局部性、平移不变性)弱(更灵活)
数据需求多(需大规模预训练)
全局建模难(靠深层堆叠)强(自注意力)
可扩展性有限强(scaling law)
多模态难统一天然统一架构
  1. 统一架构: 文本、图像、音频都用 Transformer
  2. 多模态: CLIP、GPT-4V、Gemini 的基础
  3. Scaling: 更大模型 = 更好效果

┌─────────────────────────────────────────────────────────┐
│ Vision Transformer │
├─────────────────────────────────────────────────────────┤
│ 输入图像 (224×224×3) │
│ ↓ │
│ 切分为 Patches (14×14 个 16×16 patches) │
│ ↓ │
│ Patch Embedding (线性投影为 D 维向量) │
│ ↓ │
│ 添加 [CLS] Token + Position Embedding │
│ ↓ │
│ Transformer Encoder × L 层 │
│ ↓ │
│ 取 [CLS] 的输出 → MLP Head → 分类 │
└─────────────────────────────────────────────────────────┘

将图像切分并投影:

Patch Embedding
z0=[xcls;xp1E;xp2E;;xpNE]+Eposz_0 = [x_{\text{cls}}; x_p^1 E; x_p^2 E; \ldots; x_p^N E] + E_{\text{pos}}
  • xpiRP2Cx_p^i \in \mathbb{R}^{P^2 \cdot C}: 第 i 个 patch 展平后的向量
  • ER(P2C)×DE \in \mathbb{R}^{(P^2 \cdot C) \times D}: 投影矩阵
  • EposE_{\text{pos}}: 可学习的位置编码
  • N=HW/P2N = HW/P^2: patch 数量 (如 224/16 = 14, N = 196)

与 NLP Transformer 完全相同:

Transformer Block
zl=MSA(LN(zl1))+zl1zl=MLP(LN(zl))+zlz'_l = \text{MSA}(\text{LN}(z_{l-1})) + z_{l-1} \\ z_l = \text{MLP}(\text{LN}(z'_l)) + z'_l
  • MSA: Multi-head Self-Attention
  • LN: Layer Normalization
  • MLP: Feed-Forward Network
分类头
y=MLP(zL0)y = \text{MLP}(z_L^0)

取最后一层 [CLS] token 的表示,通过 MLP 输出类别概率。


from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image
import requests
# 1. 加载预训练 ViT
model = ViTForImageClassification.from_pretrained(
'google/vit-base-patch16-224'
)
processor = ViTImageProcessor.from_pretrained(
'google/vit-base-patch16-224'
)
# 2. 加载图像
url = 'https://example.com/cat.jpg'
image = Image.open(requests.get(url, stream=True).raw)
# 3. 预处理
inputs = processor(images=image, return_tensors="pt")
# 4. 推理
outputs = model(**inputs)
logits = outputs.logits
# 5. 获取预测类别
predicted_class = logits.argmax(-1).item()
print(f"Predicted: {model.config.id2label[predicted_class]}")

模型特点参数量
ViT-B/16Base,patch=1686M
ViT-L/16Large307M
ViT-H/14Huge,patch=14632M
DeiT数据高效,知识蒸馏86M
Swin层级结构,移动窗口88M
CLIP ViT图文对比学习400M

数据量ViTCNN (ResNet)
小数据 (ImageNet 1M)稍差更好
中数据 (ImageNet 21K)持平持平
大数据 (JFT 300M+)更好饱和

关键洞察: ViT 需要大规模预训练才能发挥优势。




ViT 论文

An Image is Worth 16x16 Words

阅读

CLIP 论文

图文对比学习

阅读

timm 库

PyTorch Image Models

GitHub