Skip to content

KV Cache 原理


观察 KV Cache 如何避免重复计算,大幅加速 LLM 推理。

💾 KV Cache 可视化

Token 序列 (生成到第 0 个)
The
quick
brown
fox
jumps
over
the
lazy
dog
Key Cache
Value Cache
无 KV Cache
0
累计 Attention 计算
有 KV Cache
0
累计 Attention 计算
1.0x
📖 原理
无 KV Cache
  • 生成第 n 个 token 需计算 n 次 Attention
  • 总计算量: 1+2+...+n = O(n²)
  • 重复计算大量 K, V
有 KV Cache
  • 缓存已计算的 K, V
  • 新 token 只需计算自己的 Q
  • 总计算量: n 次 = O(n)

💡 生成 1000 个 token 时,加速比约 500x


一句话定义:KV Cache 是一种缓存优化技术,在 LLM 自回归生成时,将已计算过的 Key 和 Value 矩阵存储起来,避免重复计算,从而大幅加速推理。

核心思想:

  • 问题:每生成一个新 token,都要对所有已生成的 token 重新计算 Attention
  • 解决:把之前 token 的 K、V 缓存起来,新 token 只需计算自己的 Q,然后与缓存的 K、V 做 Attention

想象你在写一篇文章,每写一个字都要重新读一遍前面所有内容:

  • 没有 KV Cache:写第 100 个字时,要从第 1 个字重新读到第 99 个字。写第 101 个字时,又要从头读到第 100 个字。极其低效!
  • 有 KV Cache:把前面读过的内容”记住”(缓存),写新字时只需看一眼缓存,不用重新阅读。

GPT 类模型是自回归的:生成第 nn 个 token 需要前 n1n-1 个 token 的信息。

没有 KV Cache 的计算量

  • 生成第 1 个 token:计算 1 次 Attention
  • 生成第 2 个 token:计算 2 次 Attention
  • 生成第 n 个 token:计算 n 次 Attention
  • 总计1+2+...+n=O(n2)1 + 2 + ... + n = O(n^2)

有 KV Cache 的计算量

  • 每个 token 只需计算 1 次,复用之前的 K、V
  • 总计O(n)O(n)

对于生成 1000 个 token,加速比约为 500 倍


标准 Self-Attention:

Self-Attention
Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V
  • Q,K,VQ, K, V: 由输入 XX 线性变换得到
  • Q=XWQ,K=XWK,V=XWVQ = XW_Q, \quad K = XW_K, \quad V = XW_V

Prefill 阶段(处理 prompt):

输入: [token_1, token_2, ..., token_n]
计算: K_cache = [K_1, K_2, ..., K_n]
V_cache = [V_1, V_2, ..., V_n]
输出: 第一个生成的 token

Decode 阶段(逐个生成):

新 token_i 进入:
Q_new = token_i @ W_Q # 只计算新 token 的 Q
K_new = token_i @ W_K # 计算新 K
V_new = token_i @ W_V # 计算新 V
K_cache = concat(K_cache, K_new) # 追加到缓存
V_cache = concat(V_cache, V_new)
Attention(Q_new, K_cache, V_cache) # Q_new 与所有 K、V 计算
KV Cache 内存
KV Cache Size=2×L×n×d×b×dtype\text{KV Cache Size} = 2 \times L \times n \times d \times b \times \text{dtype}
  • LL: 层数(如 32)
  • nn: 序列长度(如 4096)
  • dd: 隐藏维度(如 4096)
  • bb: batch size
  • dtype: float16 = 2 bytes
  • 例:32层 × 4096长度 × 4096维 × 2 × 2B ≈ 4GB/样本

import torch
import torch.nn as nn
import torch.nn.functional as F
class AttentionWithKVCache(nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x, kv_cache=None, use_cache=False):
"""
x: (batch, seq_len, d_model) - 输入
kv_cache: (k_cache, v_cache) - 缓存的 K, V
use_cache: 是否使用和更新缓存
"""
B, L, _ = x.shape
# 计算 Q, K, V
Q = self.W_q(x) # (B, L, d_model)
K = self.W_k(x)
V = self.W_v(x)
# 如果有缓存,拼接历史 K, V
if kv_cache is not None:
k_cache, v_cache = kv_cache
K = torch.cat([k_cache, K], dim=1)
V = torch.cat([v_cache, V], dim=1)
# 更新缓存
new_cache = (K, V) if use_cache else None
# Reshape for multi-head attention
Q = Q.view(B, -1, self.n_heads, self.head_dim).transpose(1, 2)
K = K.view(B, -1, self.n_heads, self.head_dim).transpose(1, 2)
V = V.view(B, -1, self.n_heads, self.head_dim).transpose(1, 2)
# Scaled dot-product attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
# Causal mask (只看前面的 token)
seq_len_k = K.shape[2]
seq_len_q = Q.shape[2]
mask = torch.triu(torch.ones(seq_len_q, seq_len_k), diagonal=seq_len_k - seq_len_q + 1)
mask = mask.bool().to(x.device)
scores = scores.masked_fill(mask, float('-inf'))
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, V)
# Reshape back
out = out.transpose(1, 2).contiguous().view(B, -1, self.d_model)
out = self.W_o(out)
return out, new_cache
# 使用示例
def generate_with_kv_cache(model, prompt_ids, max_new_tokens=50):
"""
使用 KV Cache 的生成函数
"""
kv_cache = None
input_ids = prompt_ids
for _ in range(max_new_tokens):
# 如果有缓存,只输入最后一个 token
if kv_cache is not None:
model_input = input_ids[:, -1:]
else:
model_input = input_ids
# 前向传播
logits, kv_cache = model(model_input, kv_cache=kv_cache, use_cache=True)
# 采样下一个 token
next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
input_ids = torch.cat([input_ids, next_token], dim=1)
# 遇到 EOS 停止
if next_token.item() == eos_token_id:
break
return input_ids



vLLM 论文

PagedAttention 解决 KV Cache 碎片

阅读

GQA 论文

Grouped-Query Attention

阅读

LLM 推理优化指南

HuggingFace 文档

阅读