Flash Attention
🎬 视频详解 (Video)
Section titled “🎬 视频详解 (Video)”🎨 交互演示 (Interactive)
Section titled “🎨 交互演示 (Interactive)”对比标准 Attention 和 Flash Attention 的内存使用和计算模式。
⚡ Flash Attention 可视化
Attention 矩阵 (N×N = 8×8)
4×4 = 16 块
标准 Attention
O(N²) = 64
存储完整 8×8 矩阵
Flash Attention
O(B²) = 8
只存储 2×2 块
内存减少
87.5%
进度
0%
当前块
(0, 0)
📖 Flash Attention 核心思想
🐢 标准方法
- 计算完整 N×N 矩阵
- 存储在 HBM (慢)
- 内存: O(N²)
⚡ Flash Attention
- 分块计算,在 SRAM 中
- Online Softmax 算法
- 内存: O(N) 线性!
💡 Flash Attention 2 在 GPT-2 上加速 7.6x
📌 核心定义 (What)
Section titled “📌 核心定义 (What)”一句话定义:Flash Attention 是一种IO 感知的精确注意力算法,通过分块计算和内核融合,将 Attention 的内存复杂度从 O(n²) 降到 O(n),同时速度提升 2-4 倍。
关键点:Flash Attention 不是近似算法,输出与标准 Attention 完全相同,只是计算方式更高效。
版本演进:
- Flash Attention 1 (2022-05-27):arXiv 2205.14135,首次提出,2-4x 加速
- Flash Attention 2 (2023-07-17):arXiv 2307.08691,进一步优化,2x 加速
- Flash Attention 3 (2024-07-11):arXiv 2407.08608,H100 优化,1.5-2x 加速
🏠 生活类比 (Analogy)
Section titled “🏠 生活类比 (Analogy)”📦 “搬家公司”
Section titled “📦 “搬家公司””想象你要整理一个巨大的仓库(GPU 内存):
| 方式 | 类比 | 问题 |
|---|---|---|
| 标准 Attention | 把所有箱子搬到大厅(HBM),一件件整理 | 大厅放不下(内存爆炸) |
| Flash Attention | 一次只搬一小批到工作台(SRAM),整理完再搬下一批 | 工作台小但快,分批处理 |
标准 Attention: GPU HBM (慢但大) ←→ 存储整个 n×n 注意力矩阵
Flash Attention: GPU SRAM (快但小) ←→ 只存储 block_size × block_size 的小块 分块计算,逐块累加,永不存储完整矩阵💻 JS 开发者类比
Section titled “💻 JS 开发者类比”// 标准 Attention:一次性计算整个矩阵(内存爆炸)function standardAttention(Q, K, V) { // 问题:n=100k 时,scores 需要 100k × 100k = 10B 元素! const scores = matmul(Q, transpose(K)); // O(n²) 内存 const weights = softmax(scores); // O(n²) 内存 return matmul(weights, V);}
// Flash Attention:分块计算,流式处理function flashAttention(Q, K, V, blockSize = 256) { const n = Q.length; const output = zeros(n);
// 关键:永远不存储完整的 n×n 矩阵 for (let i = 0; i < n; i += blockSize) { const Q_block = Q.slice(i, i + blockSize);
// 对每个 Q block,遍历所有 K, V blocks let blockOutput = zeros(blockSize); let blockMax = -Infinity; let blockSum = 0;
for (let j = 0; j < n; j += blockSize) { const K_block = K.slice(j, j + blockSize); const V_block = V.slice(j, j + blockSize);
// 只计算一个小块的 attention const localScores = matmul(Q_block, transpose(K_block));
// 在线 softmax:边算边更新(不需要完整矩阵) [blockOutput, blockMax, blockSum] = onlineSoftmaxUpdate( blockOutput, blockMax, blockSum, localScores, V_block ); }
output.set(i, blockOutput); }
return output;}
// 类似 Stream 处理:数据流过,不全部加载到内存// Node.js: fs.createReadStream() 而不是 fs.readFileSync()🎯 为什么需要它 (Why)
Section titled “🎯 为什么需要它 (Why)”GPU 内存层级
Section titled “GPU 内存层级”┌─────────────────────────────────────────────────────┐│ GPU 内存层级 │├─────────────────────────────────────────────────────┤│ SRAM (片上) │ 快 (19 TB/s) │ 小 (20 MB) ││ HBM (显存) │ 慢 (3 TB/s) │ 大 (40-80 GB) ││ CPU RAM │ 很慢 │ 很大 │└─────────────────────────────────────────────────────┘
标准 Attention 的问题: Q, K, V 在 HBM → 计算 scores → 写回 HBM → 读取做 softmax → 写回 HBM 太多 HBM 读写!(IO 成为瓶颈)
Flash Attention 的解决: 分块加载到 SRAM → 计算 + softmax 一次完成 → 只写最终结果到 HBM 大幅减少 IO!| 序列长度 | 标准 Attention 内存 | Flash Attention 内存 |
|---|---|---|
| 2k | 16 MB | 可忽略 |
| 8k | 256 MB | 可忽略 |
| 32k | 4 GB | ~数 MB |
| 128k | 64 GB ❌ (OOM) | ~数 MB ✅ |
📊 数学原理 (Math)
Section titled “📊 数学原理 (Math)”标准 Attention(回顾)
Section titled “标准 Attention(回顾)”标准 Attention
问题:需要存储 ( 矩阵)
Flash Attention 的核心:在线 Softmax
Section titled “Flash Attention 的核心:在线 Softmax”数值稳定的 Softmax
关键洞察:可以分块计算,边算边更新最大值 和累加和
分块更新公式
Section titled “分块更新公式”# 在线 Softmax 更新(Flash Attention 核心)def online_softmax_update(old_output, old_max, old_sum, new_scores, new_V): """ old_output: 之前块的加权和 old_max: 之前块的最大值 old_sum: 之前块的 exp 和 new_scores: 当前块的 QK^T 分数 new_V: 当前块的 V """ new_max = max(new_scores)
# 更新全局最大值 global_max = max(old_max, new_max)
# 调整之前的结果(因为最大值变了) correction = exp(old_max - global_max) old_output = old_output * correction old_sum = old_sum * correction
# 计算当前块的贡献 new_exp = exp(new_scores - global_max) new_sum = sum(new_exp) new_output = matmul(new_exp, new_V)
# 合并 total_sum = old_sum + new_sum total_output = old_output + new_output
return total_output, global_max, total_sum💻 代码实现 (Code)
Section titled “💻 代码实现 (Code)”# PyTorch 2.0+ 原生支持import torchimport torch.nn.functional as F
# 方式 1:使用 scaled_dot_product_attention(自动选择最优实现)q = torch.randn(1, 8, 4096, 64, device='cuda') # [batch, heads, seq, dim]k = torch.randn(1, 8, 4096, 64, device='cuda')v = torch.randn(1, 8, 4096, 64, device='cuda')
# PyTorch 会自动使用 Flash Attention(如果可用)with torch.backends.cuda.sdp_kernel( enable_flash=True, # 启用 Flash Attention enable_math=False, # 禁用标准实现 enable_mem_efficient=False): output = F.scaled_dot_product_attention(q, k, v)
print(output.shape) # [1, 8, 4096, 64]# pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_func
# 输入形状: [batch, seq_len, num_heads, head_dim]q = torch.randn(1, 4096, 8, 64, device='cuda', dtype=torch.float16)k = torch.randn(1, 4096, 8, 64, device='cuda', dtype=torch.float16)v = torch.randn(1, 4096, 8, 64, device='cuda', dtype=torch.float16)
# Flash Attention 2output = flash_attn_func(q, k, v, causal=True) # causal=True 用于解码器
print(output.shape) # [1, 4096, 8, 64]
# 支持变长序列(更高效)from flash_attn import flash_attn_varlen_func
# cu_seqlens: 累积序列长度,用于批处理不同长度的序列from transformers import AutoModelForCausalLM
# 方式 1:加载时启用 Flash Attention 2model = AutoModelForCausalLM.from_pretrained( "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, attn_implementation="flash_attention_2", # 关键参数 device_map="auto")
# 方式 2:配置文件中指定# config.json: {"attn_implementation": "flash_attention_2"}
# 检查是否启用print(model.config._attn_implementation) # "flash_attention_2"📈 性能对比
Section titled “📈 性能对比”速度提升(A100 GPU)
Section titled “速度提升(A100 GPU)”| 序列长度 | 标准 Attention | Flash Attention 2 | 加速比 |
|---|---|---|---|
| 512 | 1.0x | 1.5x | 1.5x |
| 2048 | 1.0x | 2.5x | 2.5x |
| 8192 | 1.0x | 4.0x | 4.0x |
| 16384 | OOM ❌ | 可运行 ✅ | ∞ |
| 模型 | 批大小 | 标准 Attention | Flash Attention |
|---|---|---|---|
| LLaMA-7B | 1 | 14 GB | 10 GB |
| LLaMA-13B | 1 | 26 GB | 18 GB |
| LLaMA-70B | 1 | OOM | 48 GB ✅ |
🔧 最佳实践
Section titled “🔧 最佳实践”何时使用 Flash Attention
Section titled “何时使用 Flash Attention”# ✅ 推荐使用的场景- 长序列训练/推理 (>2k tokens)- 显存受限的环境- 需要最大化吞吐量
# ⚠️ 不适用的场景- CPU 推理(Flash Attention 仅支持 CUDA)- 需要返回注意力权重矩阵(Flash Attention 不存储)- 非常短的序列 (<512),收益不明显# 1. 数据类型要求# Flash Attention 要求 FP16 或 BF16q = q.half() # 或 q.bfloat16()
# 2. 头维度要求# head_dim 必须是 8 的倍数,推荐 64 或 128
# 3. Causal Mask# 使用 causal=True 比传递 mask 矩阵更高效output = flash_attn_func(q, k, v, causal=True) # ✅output = flash_attn_func(q, k, v, attn_mask=causal_mask) # ❌ 慢⚠️ 常见误区 (Pitfalls)
Section titled “⚠️ 常见误区 (Pitfalls)”🔗 相关概念
Section titled “🔗 相关概念”- Attention 机制 - 理解基础 Attention
- Transformer 架构 - Flash Attention 的应用场景
- Mamba/SSM - 另一种高效序列建模方法
📚 延伸资源
Section titled “📚 延伸资源”- Flash Attention 论文 - FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (Dao et al., 2022)
- Flash Attention 2 论文 - FlashAttention-2: Faster Attention with Better Parallelism (Dao, 2023)
- Flash Attention 3 论文 - FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision (Dao et al., 2024)
- flash-attn GitHub - 官方实现
- PyTorch 官方博客 - FlashAttention-3 介绍
- Tri Dao 主页 - 作者论文和博客