Related reading: the vLLM paper, whose PagedAttention addresses KV Cache memory fragmentation.
See how the KV Cache avoids redundant computation and dramatically speeds up LLM inference.
💡 When generating 1000 tokens, the speedup is roughly 500x.
One-sentence definition: the KV Cache is a caching optimization for autoregressive LLM generation that stores the Key and Value matrices already computed for earlier tokens so they are never recomputed, which greatly accelerates inference.
Core idea:
Imagine writing an essay where, before adding every new word, you have to reread everything you have written so far. That is what decoding without a KV Cache looks like:
GPT-style models are autoregressive: generating the t-th token requires information from all t-1 previous tokens.
Compute without a KV Cache: at step t the model must recompute K and V for all t tokens seen so far, so generating n tokens costs about 1 + 2 + ... + n = n(n+1)/2 ≈ n²/2 token-forward passes.
Compute with a KV Cache: at step t only the new token's Q, K, V are computed, so generating n tokens costs about n token-forward passes.
For generating 1000 tokens, the ratio is therefore about n/2 ≈ 500x!
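A quick back-of-the-envelope check of the 500x figure, counting only how many token positions the model has to process per decode step (constant factors and the prefill are ignored):

```python
# Back-of-the-envelope speedup estimate for n generated tokens
n = 1000

# Without KV Cache: step t re-feeds all t tokens through the model
work_no_cache = sum(range(1, n + 1))   # n(n+1)/2 token-forward passes

# With KV Cache: step t processes only the single new token
work_with_cache = n

print(work_no_cache / work_with_cache)  # ≈ 500.5
```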
Standard Self-Attention:
Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V
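As a minimal sketch (single head, random tensors, no cache), the formula above can be written in a few lines of PyTorch:

```python
import torch
import torch.nn.functional as F

d_k = 64
Q = torch.randn(1, 10, d_k)  # (batch, seq_len, d_k)
K = torch.randn(1, 10, d_k)
V = torch.randn(1, 10, d_k)

scores = Q @ K.transpose(-2, -1) / d_k ** 0.5  # (1, 10, 10)
attn = F.softmax(scores, dim=-1)               # attention weights
out = attn @ V                                 # (1, 10, d_k)
```

During decoding, only the new token's query changes at each step, while K and V for earlier tokens stay the same; this is exactly what the KV Cache exploits in the two phases below.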
Prefill stage (processing the prompt):

```
Input:   [token_1, token_2, ..., token_n]
Compute: K_cache = [K_1, K_2, ..., K_n]
         V_cache = [V_1, V_2, ..., V_n]
Output:  the first generated token
```

Decode stage (generating one token at a time):

```
New token_i arrives:
    Q_new = token_i @ W_Q   # only compute Q for the new token
    K_new = token_i @ W_K   # compute the new K
    V_new = token_i @ W_V   # compute the new V

    K_cache = concat(K_cache, K_new)   # append to the cache
    V_cache = concat(V_cache, V_new)

    Attention(Q_new, K_cache, V_cache)  # Q_new attends to all cached K, V
```
A minimal PyTorch implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionWithKVCache(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None, use_cache=False):
        """
        x: (batch, seq_len, d_model) - input
        kv_cache: (k_cache, v_cache) - cached K, V
        use_cache: whether to use and update the cache
        """
        B, L, _ = x.shape

        # Compute Q, K, V for the new tokens only
        Q = self.W_q(x)  # (B, L, d_model)
        K = self.W_k(x)
        V = self.W_v(x)

        # If a cache exists, prepend the historical K, V
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            K = torch.cat([k_cache, K], dim=1)
            V = torch.cat([v_cache, V], dim=1)

        # Update the cache
        new_cache = (K, V) if use_cache else None

        # Reshape for multi-head attention
        Q = Q.view(B, -1, self.n_heads, self.head_dim).transpose(1, 2)
        K = K.view(B, -1, self.n_heads, self.head_dim).transpose(1, 2)
        V = V.view(B, -1, self.n_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Causal mask (each query only sees earlier tokens)
        seq_len_k = K.shape[2]
        seq_len_q = Q.shape[2]
        mask = torch.triu(
            torch.ones(seq_len_q, seq_len_k),
            diagonal=seq_len_k - seq_len_q + 1,
        )
        mask = mask.bool().to(x.device)
        scores = scores.masked_fill(mask, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)

        # Reshape back and project
        out = out.transpose(1, 2).contiguous().view(B, -1, self.d_model)
        out = self.W_o(out)

        return out, new_cache
```
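A quick sanity check of the class above (a hypothetical test snippet, not from the original): feeding a sequence all at once should give the same output as feeding it token by token through the cache.

```python
torch.manual_seed(0)
attn = AttentionWithKVCache(d_model=64, n_heads=4).eval()
x = torch.randn(1, 5, 64)  # (batch=1, seq_len=5, d_model=64)

# Path 1: process all 5 tokens in a single forward pass (no cache)
with torch.no_grad():
    full_out, _ = attn(x)

# Path 2: process tokens one at a time, reusing the KV Cache
cache, steps = None, []
with torch.no_grad():
    for t in range(5):
        out, cache = attn(x[:, t:t + 1, :], kv_cache=cache, use_cache=True)
        steps.append(out)
incremental_out = torch.cat(steps, dim=1)

# The two paths should agree up to floating-point error
print(torch.allclose(full_out, incremental_out, atol=1e-5))  # True
```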
```python
# Usage example: a greedy generation loop driven by the KV Cache
def generate_with_kv_cache(model, prompt_ids, eos_token_id, max_new_tokens=50):
    """
    Generation function that uses the KV Cache
    """
    kv_cache = None
    input_ids = prompt_ids

    for _ in range(max_new_tokens):
        # If a cache exists, feed only the last token
        if kv_cache is not None:
            model_input = input_ids[:, -1:]
        else:
            model_input = input_ids

        # Forward pass
        logits, kv_cache = model(model_input, kv_cache=kv_cache, use_cache=True)

        # Pick the next token (greedy decoding)
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=1)

        # Stop on EOS
        if next_token.item() == eos_token_id:
            break

    return input_ids
```
Using the KV Cache in Hugging Face Transformers:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = "The capital of France is"
inputs = tokenizer(prompt, return_tensors="pt")

# Without KV Cache (slow)
outputs_slow = model.generate(
    **inputs,
    max_new_tokens=50,
    use_cache=False,  # disable the KV Cache
)

# With KV Cache (fast, enabled by default)
outputs_fast = model.generate(
    **inputs,
    max_new_tokens=50,
    use_cache=True,  # enable the KV Cache
)

# Print the generated text
print(tokenizer.decode(outputs_fast[0]))
```
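To actually see the difference, a simple timing comparison can be added (a hypothetical snippet building on the model and inputs above; absolute numbers depend on hardware):

```python
import time

def timed_generate(use_cache: bool) -> float:
    # Time a single 50-token generation with the cache on or off
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=50, use_cache=use_cache)
    return time.perf_counter() - start

t_no_cache = timed_generate(use_cache=False)
t_cache = timed_generate(use_cache=True)
print(f"without cache: {t_no_cache:.2f}s  with cache: {t_cache:.2f}s")
```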
```python
# Inspect the KV Cache footprint
# The past_key_values returned by model.generate show the cache shapes
outputs = model.generate(
    **inputs,
    max_new_tokens=10,
    return_dict_in_generate=True,
    output_hidden_states=False,
)

if hasattr(outputs, 'past_key_values') and outputs.past_key_values:
    kv = outputs.past_key_values
    # kv[layer_idx] = (key, value)
    # key.shape = (batch, num_heads, seq_len, head_dim)
    print(f"KV Cache shape: {kv[0][0].shape}")
```
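From those shapes the total cache size can be estimated. As a rough illustration (assuming Llama-2-7B's published configuration of 32 layers, 32 KV heads, and head_dim 128, stored in fp16; these numbers are assumptions, not taken from the original text):

```python
# Rough KV Cache size estimate: 2 tensors (K and V) per layer
n_layers, n_kv_heads, head_dim = 32, 32, 128  # assumed Llama-2-7B config
bytes_per_elem = 2                            # fp16

bytes_per_token = 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem
print(bytes_per_token / 1024)                 # ≈ 512 KB of cache per token
print(bytes_per_token * 1000 / 1024 ** 2)     # ≈ 500 MB for 1000 tokens
```

This linear growth with sequence length is exactly the memory pressure that vLLM's PagedAttention, mentioned at the top, is designed to manage.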