Skip to content

Flash Attention


对比标准 Attention 和 Flash Attention 的内存使用和计算模式。

⚡ Flash Attention 可视化

Attention 矩阵 (N×N = 8×8)
4×4 = 16
标准 Attention
O(N²) = 64
存储完整 8×8 矩阵
Flash Attention
O(B²) = 8
只存储 2×2
内存减少
87.5%
进度
0%
当前块
(0, 0)
📖 Flash Attention 核心思想
🐢 标准方法
  • 计算完整 N×N 矩阵
  • 存储在 HBM (慢)
  • 内存: O(N²)
⚡ Flash Attention
  • 分块计算,在 SRAM 中
  • Online Softmax 算法
  • 内存: O(N) 线性!

💡 Flash Attention 2 在 GPT-2 上加速 7.6x


一句话定义:Flash Attention 是一种IO 感知的精确注意力算法,通过分块计算内核融合,将 Attention 的内存复杂度从 O(n²) 降到 O(n),同时速度提升 2-4 倍。

关键点:Flash Attention 不是近似算法,输出与标准 Attention 完全相同,只是计算方式更高效。

版本演进

  • Flash Attention 1 (2022-05-27):arXiv 2205.14135,首次提出,2-4x 加速
  • Flash Attention 2 (2023-07-17):arXiv 2307.08691,进一步优化,2x 加速
  • Flash Attention 3 (2024-07-11):arXiv 2407.08608,H100 优化,1.5-2x 加速

想象你要整理一个巨大的仓库(GPU 内存):

方式类比问题
标准 Attention把所有箱子搬到大厅(HBM),一件件整理大厅放不下(内存爆炸)
Flash Attention一次只搬一小批到工作台(SRAM),整理完再搬下一批工作台小但快,分批处理
标准 Attention:
GPU HBM (慢但大) ←→ 存储整个 n×n 注意力矩阵
Flash Attention:
GPU SRAM (快但小) ←→ 只存储 block_size × block_size 的小块
分块计算,逐块累加,永不存储完整矩阵

// 标准 Attention:一次性计算整个矩阵(内存爆炸)
function standardAttention(Q, K, V) {
// 问题:n=100k 时,scores 需要 100k × 100k = 10B 元素!
const scores = matmul(Q, transpose(K)); // O(n²) 内存
const weights = softmax(scores); // O(n²) 内存
return matmul(weights, V);
}
// Flash Attention:分块计算,流式处理
function flashAttention(Q, K, V, blockSize = 256) {
const n = Q.length;
const output = zeros(n);
// 关键:永远不存储完整的 n×n 矩阵
for (let i = 0; i < n; i += blockSize) {
const Q_block = Q.slice(i, i + blockSize);
// 对每个 Q block,遍历所有 K, V blocks
let blockOutput = zeros(blockSize);
let blockMax = -Infinity;
let blockSum = 0;
for (let j = 0; j < n; j += blockSize) {
const K_block = K.slice(j, j + blockSize);
const V_block = V.slice(j, j + blockSize);
// 只计算一个小块的 attention
const localScores = matmul(Q_block, transpose(K_block));
// 在线 softmax:边算边更新(不需要完整矩阵)
[blockOutput, blockMax, blockSum] = onlineSoftmaxUpdate(
blockOutput, blockMax, blockSum,
localScores, V_block
);
}
output.set(i, blockOutput);
}
return output;
}
// 类似 Stream 处理:数据流过,不全部加载到内存
// Node.js: fs.createReadStream() 而不是 fs.readFileSync()

┌─────────────────────────────────────────────────────┐
│ GPU 内存层级 │
├─────────────────────────────────────────────────────┤
│ SRAM (片上) │ 快 (19 TB/s) │ 小 (20 MB) │
│ HBM (显存) │ 慢 (3 TB/s) │ 大 (40-80 GB) │
│ CPU RAM │ 很慢 │ 很大 │
└─────────────────────────────────────────────────────┘
标准 Attention 的问题:
Q, K, V 在 HBM → 计算 scores → 写回 HBM → 读取做 softmax → 写回 HBM
太多 HBM 读写!(IO 成为瓶颈)
Flash Attention 的解决:
分块加载到 SRAM → 计算 + softmax 一次完成 → 只写最终结果到 HBM
大幅减少 IO!
序列长度标准 Attention 内存Flash Attention 内存
2k16 MB可忽略
8k256 MB可忽略
32k4 GB~数 MB
128k64 GB ❌ (OOM)~数 MB ✅

标准 Attention
Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

问题:需要存储 QKTQK^Tn×nn \times n 矩阵)

Flash Attention 的核心:在线 Softmax

Section titled “Flash Attention 的核心:在线 Softmax”
数值稳定的 Softmax
softmax(x)i=eximjexjm,m=max(x)\text{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}, \quad m = \max(x)

关键洞察:可以分块计算,边算边更新最大值 mm 和累加和

# 在线 Softmax 更新(Flash Attention 核心)
def online_softmax_update(old_output, old_max, old_sum, new_scores, new_V):
"""
old_output: 之前块的加权和
old_max: 之前块的最大值
old_sum: 之前块的 exp 和
new_scores: 当前块的 QK^T 分数
new_V: 当前块的 V
"""
new_max = max(new_scores)
# 更新全局最大值
global_max = max(old_max, new_max)
# 调整之前的结果(因为最大值变了)
correction = exp(old_max - global_max)
old_output = old_output * correction
old_sum = old_sum * correction
# 计算当前块的贡献
new_exp = exp(new_scores - global_max)
new_sum = sum(new_exp)
new_output = matmul(new_exp, new_V)
# 合并
total_sum = old_sum + new_sum
total_output = old_output + new_output
return total_output, global_max, total_sum

# PyTorch 2.0+ 原生支持
import torch
import torch.nn.functional as F
# 方式 1:使用 scaled_dot_product_attention(自动选择最优实现)
q = torch.randn(1, 8, 4096, 64, device='cuda') # [batch, heads, seq, dim]
k = torch.randn(1, 8, 4096, 64, device='cuda')
v = torch.randn(1, 8, 4096, 64, device='cuda')
# PyTorch 会自动使用 Flash Attention(如果可用)
with torch.backends.cuda.sdp_kernel(
enable_flash=True, # 启用 Flash Attention
enable_math=False, # 禁用标准实现
enable_mem_efficient=False
):
output = F.scaled_dot_product_attention(q, k, v)
print(output.shape) # [1, 8, 4096, 64]

序列长度标准 AttentionFlash Attention 2加速比
5121.0x1.5x1.5x
20481.0x2.5x2.5x
81921.0x4.0x4.0x
16384OOM ❌可运行 ✅
模型批大小标准 AttentionFlash Attention
LLaMA-7B114 GB10 GB
LLaMA-13B126 GB18 GB
LLaMA-70B1OOM48 GB ✅

# ✅ 推荐使用的场景
- 长序列训练/推理 (>2k tokens)
- 显存受限的环境
- 需要最大化吞吐量
# ⚠️ 不适用的场景
- CPU 推理(Flash Attention 仅支持 CUDA
- 需要返回注意力权重矩阵(Flash Attention 不存储)
- 非常短的序列 (<512),收益不明显
# 1. 数据类型要求
# Flash Attention 要求 FP16 或 BF16
q = q.half() # 或 q.bfloat16()
# 2. 头维度要求
# head_dim 必须是 8 的倍数,推荐 64 或 128
# 3. Causal Mask
# 使用 causal=True 比传递 mask 矩阵更高效
output = flash_attn_func(q, k, v, causal=True) # ✅
output = flash_attn_func(q, k, v, attn_mask=causal_mask) # ❌ 慢