位置编码 (Positional Encoding)
📌 核心定义 (What)
Section titled “📌 核心定义 (What)”一句话定义:位置编码 (Positional Encoding) 是给每个 Token 注入位置信息的技术。因为 Transformer 的 Attention 机制本身不感知顺序,需要额外的位置信号。
没有位置编码,“我爱你” 和 “你爱我” 对模型来说完全一样!
🏠 生活类比 (Analogy)
Section titled “🏠 生活类比 (Analogy)”📬 “信封地址”
Section titled “📬 “信封地址””想象你收到一堆散落的信件:
- Attention 机制:能看清每封信的内容,但信件散落一地
- 位置编码:每封信都有编号(第 1 封、第 2 封…)
- 组合效果:既知道内容,又知道顺序,才能正确理解
无位置编码:[我] [爱] [你] [北京] → 模型不知道谁爱谁有位置编码:[我₁] [爱₂] [你₃] [北京₄] → 模型知道"我"在前,"你"在后💻 JS 开发者类比
Section titled “💻 JS 开发者类比”// 数组天然有索引(位置信息)const words = ["我", "爱", "你"];words[0]; // "我" - 第 0 位置words[2]; // "你" - 第 2 位置
// 但 Attention 把词当作"无序集合"处理const wordSet = new Set(["我", "爱", "你"]);// Set 中没有顺序概念!
// 位置编码 = 给每个元素加上位置标签const wordsWithPosition = words.map((word, index) => ({ word, position: index, // 或更复杂:position: encodePosition(index)}));// [{ word: "我", position: 0 }, { word: "爱", position: 1 }, ...]📍 交互演示:位置编码 (Interactive)
Section titled “📍 交互演示:位置编码 (Interactive)”观察正弦/余弦位置编码的波形模式,理解为什么每个位置都有唯一的编码。
📍 位置编码可视化
速度:
PE(pos, 2i) = sin(pos / 100002i/d)
PE(pos, 2i+1) = cos(pos / 100002i/d)
0
4
8
12
16
20
24
28
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
-1+1
💡 观察要点:
- 低维度 (左侧): 变化频率高,编码局部位置
- 高维度 (右侧): 变化频率低,编码全局位置
- 每行唯一: 每个位置有独特的编码
- 可外推: 正弦/余弦可推广到未见过的位置
💡 点击位置/维度查看详情 | 蓝色=-1, 红色=+1
🎯 为什么需要它 (Why)
Section titled “🎯 为什么需要它 (Why)”RNN vs Transformer
Section titled “RNN vs Transformer”| 模型 | 如何获得位置信息 |
|---|---|
| RNN | 天然有序:,后面的状态依赖前面 |
| Transformer | 并行计算:所有词同时处理,无内置顺序 |
RNN: [我] → [爱] → [你] (串行,天然有序) ↓ ↓ ↓ h₁ → h₂ → h₃
Transformer: [我] [爱] [你] (并行,需要手动加位置) ↓ ↓ ↓ Self-Attention(不区分位置)没有位置编码会怎样?
Section titled “没有位置编码会怎样?”# 模型会认为这两句话完全相同!sentence1 = "猫 吃 鱼"sentence2 = "鱼 吃 猫"
# Attention 只看"有哪些词",不看"词在哪"# 结果:语义完全混乱📊 数学原理 (Math)
Section titled “📊 数学原理 (Math)”原始 Transformer 的正弦位置编码
Section titled “原始 Transformer 的正弦位置编码”正弦位置编码公式
- : 位置索引(0, 1, 2, …)
- : 维度索引(0, 1, 2, …, d/2)
- : 嵌入维度(如 512)
- 偶数维度用 ,奇数维度用
为什么用正弦/余弦?
Section titled “为什么用正弦/余弦?”- 有界:值永远在 ,不会爆炸
- 周期性:不同频率的波能编码不同尺度的位置关系
- 相对位置: 可以由 线性变换得到
位置 0: [sin(0), cos(0), sin(0), cos(0), ...]位置 1: [sin(1/10000⁰), cos(1/10000⁰), sin(1/10000²), ...]位置 2: [sin(2/10000⁰), cos(2/10000⁰), sin(2/10000²), ...]...
低频维度(i 大):变化慢,捕捉长距离高频维度(i 小):变化快,捕捉短距离💻 代码实现 (Code)
Section titled “💻 代码实现 (Code)”import torchimport math
def positional_encoding(seq_len, d_model): """ 生成正弦位置编码
Args: seq_len: 序列长度 d_model: 嵌入维度
Returns: pe: [seq_len, d_model] 的位置编码矩阵 """ pe = torch.zeros(seq_len, d_model) position = torch.arange(0, seq_len).unsqueeze(1).float() # [seq_len, 1]
# 计算分母:10000^(2i/d) div_term = torch.exp( torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) ) # [d_model/2]
# 偶数维度:sin pe[:, 0::2] = torch.sin(position * div_term) # 奇数维度:cos pe[:, 1::2] = torch.cos(position * div_term)
return pe
# 使用示例seq_len = 100d_model = 512
pe = positional_encoding(seq_len, d_model)print(pe.shape) # [100, 512]
# 实际使用:加到词嵌入上# word_embeddings: [batch, seq_len, d_model]# output = word_embeddings + pe[:seq_len, :]import matplotlib.pyplot as pltimport numpy as np
# 生成位置编码pe = positional_encoding(100, 512).numpy()
# 可视化热力图plt.figure(figsize=(15, 5))plt.imshow(pe, aspect='auto', cmap='RdBu')plt.xlabel('Embedding Dimension')plt.ylabel('Position')plt.title('Sinusoidal Positional Encoding')plt.colorbar()plt.show()
# 你会看到:# - 低维度(左边):高频波动,区分相邻位置# - 高维度(右边):低频波动,区分远距离位置// 概念实现:正弦位置编码function positionalEncoding(position, dModel) { const pe = new Array(dModel);
for (let i = 0; i < dModel; i++) { const divTerm = Math.pow(10000, (2 * Math.floor(i / 2)) / dModel);
if (i % 2 === 0) { // 偶数维度:sin pe[i] = Math.sin(position / divTerm); } else { // 奇数维度:cos pe[i] = Math.cos(position / divTerm); } }
return pe;}
// 示例const pe0 = positionalEncoding(0, 8); // 位置 0 的编码const pe1 = positionalEncoding(1, 8); // 位置 1 的编码
console.log('Position 0:', pe0.map(x => x.toFixed(3)));console.log('Position 1:', pe1.map(x => x.toFixed(3)));
// 实际使用:词向量 + 位置编码// finalEmbedding[i] = wordEmbedding[i] + positionalEncoding(i, dModel)🔧 现代位置编码演进
Section titled “🔧 现代位置编码演进”1. 可学习位置编码 (Learnable)
Section titled “1. 可学习位置编码 (Learnable)”BERT、GPT 系列使用。位置编码作为可训练参数。
# PyTorch 实现class LearnablePositionalEncoding(nn.Module): def __init__(self, max_len, d_model): super().__init__() # 位置编码是可学习的参数 self.pe = nn.Parameter(torch.randn(max_len, d_model))
def forward(self, x): seq_len = x.size(1) return x + self.pe[:seq_len]| 优点 | 缺点 |
|---|---|
| 针对任务优化 | 无法处理超出训练长度的序列 |
| 简单直观 | 需要更多参数 |
2. RoPE (Rotary Position Embedding) ⭐
Section titled “2. RoPE (Rotary Position Embedding) ⭐”LLaMA、Qwen、Mistral 等主流模型使用。旋转位置编码,支持长度外推。
核心思想:将位置信息编码为旋转角度, 和 的点积自动包含相对位置 的信息。
import torchimport torch.nn as nnimport math
class RotaryPositionalEmbedding(nn.Module): """RoPE 完整实现(参考 LLaMA)"""
def __init__(self, dim: int, max_seq_len: int = 4096, base: int = 10000): super().__init__() self.dim = dim self.max_seq_len = max_seq_len self.base = base
# 预计算频率 inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq)
# 预计算 cos 和 sin 缓存 self._build_cache(max_seq_len)
def _build_cache(self, seq_len: int): """预计算位置编码缓存""" positions = torch.arange(seq_len, dtype=self.inv_freq.dtype) freqs = torch.outer(positions, self.inv_freq) # [seq_len, dim/2]
# 扩展到完整维度:[seq_len, dim] emb = torch.cat([freqs, freqs], dim=-1)
self.register_buffer("cos_cached", emb.cos()) self.register_buffer("sin_cached", emb.sin())
def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int): """ Args: q, k: [batch, num_heads, seq_len, head_dim] seq_len: 当前序列长度 Returns: q_rot, k_rot: 应用 RoPE 后的 Q 和 K """ cos = self.cos_cached[:seq_len].unsqueeze(0).unsqueeze(0) sin = self.sin_cached[:seq_len].unsqueeze(0).unsqueeze(0)
q_rot = (q * cos) + (self._rotate_half(q) * sin) k_rot = (k * cos) + (self._rotate_half(k) * sin)
return q_rot, k_rot
@staticmethod def _rotate_half(x: torch.Tensor) -> torch.Tensor: """将向量分成两半并旋转:[x1, x2] -> [-x2, x1]""" x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] return torch.cat([-x2, x1], dim=-1)
# 使用示例rope = RotaryPositionalEmbedding(dim=64, max_seq_len=4096)
batch, heads, seq_len, head_dim = 2, 8, 512, 64q = torch.randn(batch, heads, seq_len, head_dim)k = torch.randn(batch, heads, seq_len, head_dim)
q_rot, k_rot = rope(q, k, seq_len)print(f"Q shape: {q_rot.shape}") # [2, 8, 512, 64]# RoPE 的数学本质:2D 旋转矩阵
# 对于位置 m 的向量 [x1, x2],应用旋转角度 θm:# [cos(θm), -sin(θm)] [x1]# [sin(θm), cos(θm)] × [x2]
# 关键性质:# q_m · k_n = f(q, k, m-n)# 点积只依赖相对位置 (m-n),而非绝对位置!
# 不同维度使用不同频率:# θ_i = position / (base^(2i/d))# 低频维度捕捉长距离关系,高频维度捕捉短距离关系3. ALiBi (Attention with Linear Biases)
Section titled “3. ALiBi (Attention with Linear Biases)”BLOOM 使用。不修改嵌入,而是在 Attention 分数上加偏置。
# 核心思想:距离越远,Attention 分数惩罚越大# attention_score[i][j] -= m * |i - j|# m 是每个 head 的斜率参数| 方法 | 外推能力 | 计算效率 | 使用模型 |
|---|---|---|---|
| 正弦 (Sinusoidal) | 理论无限 | ⭐⭐⭐ | 原始 Transformer |
| 可学习 (Learnable) | ❌ 有限 | ⭐⭐⭐ | BERT, GPT-2 |
| RoPE | ✅ 良好 | ⭐⭐ | LLaMA, Qwen |
| ALiBi | ✅ 最佳 | ⭐⭐⭐ | BLOOM |
⚠️ 常见误区 (Pitfalls)
Section titled “⚠️ 常见误区 (Pitfalls)”🔗 相关概念
Section titled “🔗 相关概念”- Transformer 架构 - 位置编码的应用场景
- Attention 机制 - 为什么需要位置编码
- 词嵌入 - 位置编码加到词嵌入上
- 三角函数 - 正弦/余弦数学基础
📚 延伸资源
Section titled “📚 延伸资源”- The Illustrated Transformer - Positional Encoding
- RoPE 论文 - Rotary Position Embedding
- ALiBi 论文 - Train Short, Test Long