ViT 论文
An Image is Worth 16x16 Words
观察 Vision Transformer 如何将图像分割成 patches 并转换为序列。
💡 ViT-B/16: 224×224 图像, 16×16 patches = 196 个 tokens
一句话定义:Vision Transformer (ViT) 将图像切分为 Patch 序列,然后用标准 Transformer 处理,证明了 Transformer 在视觉任务上可以媲美甚至超越 CNN。
| CNN | ViT |
|---|---|
| 用滑动窗口逐步扫描 | 把拼图打散成小块 |
| 局部到全局 | 所有小块同时看 |
| 像用放大镜看画 | 像把画切成格子一起分析 |
ViT 把图片当成”视觉句子”,每个 patch 就是一个”视觉单词”。
| 特性 | CNN | ViT |
|---|---|---|
| 归纳偏置 | 强(局部性、平移不变性) | 弱(更灵活) |
| 数据需求 | 少 | 多(需大规模预训练) |
| 全局建模 | 难(靠深层堆叠) | 强(自注意力) |
| 可扩展性 | 有限 | 强(scaling law) |
| 多模态 | 难统一 | 天然统一架构 |
┌─────────────────────────────────────────────────────────┐│ Vision Transformer │├─────────────────────────────────────────────────────────┤│ 输入图像 (224×224×3) ││ ↓ ││ 切分为 Patches (14×14 个 16×16 patches) ││ ↓ ││ Patch Embedding (线性投影为 D 维向量) ││ ↓ ││ 添加 [CLS] Token + Position Embedding ││ ↓ ││ Transformer Encoder × L 层 ││ ↓ ││ 取 [CLS] 的输出 → MLP Head → 分类 │└─────────────────────────────────────────────────────────┘将图像切分并投影:
与 NLP Transformer 完全相同:
取最后一层 [CLS] token 的表示,通过 MLP 输出类别概率。
from transformers import ViTForImageClassification, ViTImageProcessorfrom PIL import Imageimport requests
# 1. 加载预训练 ViTmodel = ViTForImageClassification.from_pretrained( 'google/vit-base-patch16-224')processor = ViTImageProcessor.from_pretrained( 'google/vit-base-patch16-224')
# 2. 加载图像url = 'https://example.com/cat.jpg'image = Image.open(requests.get(url, stream=True).raw)
# 3. 预处理inputs = processor(images=image, return_tensors="pt")
# 4. 推理outputs = model(**inputs)logits = outputs.logits
# 5. 获取预测类别predicted_class = logits.argmax(-1).item()print(f"Predicted: {model.config.id2label[predicted_class]}")import torchimport torch.nn as nn
class PatchEmbedding(nn.Module): """将图像切分为 patches 并投影"""
def __init__( self, img_size=224, patch_size=16, in_channels=3, embed_dim=768 ): super().__init__() self.num_patches = (img_size // patch_size) ** 2 # 196
# 用 Conv2d 同时完成切分和投影 self.proj = nn.Conv2d( in_channels, embed_dim, kernel_size=patch_size, stride=patch_size )
def forward(self, x): # x: [B, C, H, W] x = self.proj(x) # [B, D, H/P, W/P] x = x.flatten(2) # [B, D, N] x = x.transpose(1, 2) # [B, N, D] return x
class ViT(nn.Module): """Vision Transformer"""
def __init__( self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, dropout=0.1 ): super().__init__()
# Patch Embedding self.patch_embed = PatchEmbedding( img_size, patch_size, in_channels, embed_dim ) num_patches = self.patch_embed.num_patches
# CLS Token self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# Position Embedding self.pos_embed = nn.Parameter( torch.zeros(1, num_patches + 1, embed_dim) ) self.dropout = nn.Dropout(dropout)
# Transformer Encoder encoder_layer = nn.TransformerEncoderLayer( d_model=embed_dim, nhead=num_heads, dim_feedforward=int(embed_dim * mlp_ratio), dropout=dropout, activation='gelu', batch_first=True ) self.encoder = nn.TransformerEncoder(encoder_layer, depth)
# Classification Head self.norm = nn.LayerNorm(embed_dim) self.head = nn.Linear(embed_dim, num_classes)
# 初始化 nn.init.trunc_normal_(self.pos_embed, std=0.02) nn.init.trunc_normal_(self.cls_token, std=0.02)
def forward(self, x): B = x.shape[0]
# Patch Embedding x = self.patch_embed(x) # [B, N, D]
# 添加 CLS Token cls_tokens = self.cls_token.expand(B, -1, -1) x = torch.cat([cls_tokens, x], dim=1) # [B, N+1, D]
# 添加 Position Embedding x = x + self.pos_embed x = self.dropout(x)
# Transformer Encoder x = self.encoder(x)
# 分类 x = self.norm(x[:, 0]) # 取 CLS token x = self.head(x)
return x
# 使用model = ViT(num_classes=1000)img = torch.randn(1, 3, 224, 224)output = model(img) # [1, 1000]| 模型 | 特点 | 参数量 |
|---|---|---|
| ViT-B/16 | Base,patch=16 | 86M |
| ViT-L/16 | Large | 307M |
| ViT-H/14 | Huge,patch=14 | 632M |
| DeiT | 数据高效,知识蒸馏 | 86M |
| Swin | 层级结构,移动窗口 | 88M |
| CLIP ViT | 图文对比学习 | 400M |
| 数据量 | ViT | CNN (ResNet) |
|---|---|---|
| 小数据 (ImageNet 1M) | 稍差 | 更好 |
| 中数据 (ImageNet 21K) | 持平 | 持平 |
| 大数据 (JFT 300M+) | 更好 | 饱和 |
关键洞察: ViT 需要大规模预训练才能发挥优势。