Skip to content

Mamba 与状态空间模型 (SSM)

观察 SSM 状态向量如何随每个 token 更新,理解线性复杂度的来源。

🔄 SSM 状态更新可视化

输入序列
The
cat
sat
on
the
mat
ht = A ·ht-1 + B ·xt
状态 = 衰减系数 × 旧状态 + 输入系数 × 当前输入
ht-1
× A
+ B·xt
"The"
=
ht
输出 yt = C · ht
→ 预测下一个 token
4
💡 Mamba 的核心创新
  • 选择性机制: A, B, C 是输入依赖的,让模型能"选择"记住或遗忘
  • 硬件感知算法: 利用 GPU SRAM,避免 HBM 瓶颈
  • 线性复杂度: 处理 100k tokens 和 1k tokens 速度相近

💡 状态向量像"压缩的记忆",随每个 token 更新


一句话定义:Mamba 是基于状态空间模型 (SSM) 的序列建模架构,通过选择性扫描机制实现线性时间复杂度,是 Transformer 在长序列处理上的主要竞争者。

发展历程

  • Mamba-1 (2023-12-01):arXiv 2312.00752,首次展示 SSM 可匹敌 Transformer
  • Mamba-2 (2024-05-31):arXiv 2405.21060,SSD 框架,性能提升
  • Jamba (2024-03-28):AI21 Labs,Mamba + Attention 混合架构

模型类比工作方式
Transformer录像带回放每次回答都要”倒带”看完所有历史,越长越慢
Mamba/SSM实时电视只记住”当前状态”,看多长都一样快
Transformer 处理 100k tokens:
需要计算 100k × 100k = 100 亿次注意力
Mamba 处理 100k tokens:
只需线性扫描,复杂度 O(n),快 1000 倍!

// Transformer 的 Attention:每个词都要看所有其他词
function transformerAttention(tokens) {
const n = tokens.length;
// O(n²) 复杂度!
for (let i = 0; i < n; i++) {
for (let j = 0; j < n; j++) {
computeAttention(tokens[i], tokens[j]);
}
}
}
// Mamba 的 SSM:只维护一个"状态",线性扫描
function mambaSSM(tokens) {
let state = initState(); // 隐藏状态
// O(n) 复杂度!
for (const token of tokens) {
// 状态更新:只依赖当前状态和输入
state = updateState(state, token);
}
return state;
}
// 类似 Array.reduce()
const result = tokens.reduce((state, token) => updateState(state, token), initState());

问题原因影响
二次复杂度Attention 是 O(n²)长文本处理极慢
内存爆炸KV Cache 线性增长128k context 需要 >100GB
推理慢每个 token 都要重算 Attention生成速度受限
优势机制效果
线性复杂度状态空间递推处理 1M tokens 无压力
固定内存状态大小固定不随序列长度增长
快速推理无需 KV Cache生成速度提升 5x

连续 SSM 方程
h(t)=Ah(t)+Bx(t)y(t)=Ch(t)+Dx(t)\begin{aligned} h'(t) &= Ah(t) + Bx(t) \\ y(t) &= Ch(t) + Dx(t) \end{aligned}
  • h(t)h(t): 隐藏状态(类似 RNN 的记忆)
  • x(t)x(t): 输入信号
  • y(t)y(t): 输出信号
  • A,B,C,DA, B, C, D: 可学习的系统矩阵
离散 SSM 方程
hk=Aˉhk1+Bˉxkyk=Chk\begin{aligned} h_k &= \bar{A}h_{k-1} + \bar{B}x_k \\ y_k &= Ch_k \end{aligned}
  • Aˉ=exp(ΔA)\bar{A} = \exp(\Delta A): 离散化的状态转移矩阵
  • Bˉ=(ΔA)1(exp(ΔA)I)ΔB\bar{B} = (\Delta A)^{-1}(\exp(\Delta A) - I) \cdot \Delta B
  • Δ\Delta: 时间步长(可学习)

关键创新:让 BB, CC, Δ\Delta 依赖于输入,实现内容感知

# 传统 SSM:参数固定
B, C, delta = fixed_params
# Mamba:参数随输入变化(选择性)
B = linear_B(x) # [batch, seq, d_state]
C = linear_C(x) # [batch, seq, d_state]
delta = softplus(linear_delta(x)) # [batch, seq, d_inner]

这使得 Mamba 能像 Attention 一样选择性关注重要信息。


import torch
import torch.nn as nn
class SimpleSSM(nn.Module):
"""简化的 SSM 实现(用于理解原理)"""
def __init__(self, d_model, d_state=16):
super().__init__()
self.d_model = d_model
self.d_state = d_state
# SSM 参数
self.A = nn.Parameter(torch.randn(d_model, d_state))
self.B = nn.Linear(d_model, d_state)
self.C = nn.Linear(d_state, d_model)
self.D = nn.Parameter(torch.ones(d_model))
def forward(self, x):
"""
x: [batch, seq_len, d_model]
"""
batch, seq_len, _ = x.shape
# 初始化隐藏状态
h = torch.zeros(batch, self.d_state, device=x.device)
outputs = []
for t in range(seq_len):
# 状态更新: h_t = A * h_{t-1} + B * x_t
x_t = x[:, t, :] # [batch, d_model]
h = torch.tanh(h @ self.A.T + self.B(x_t))
# 输出: y_t = C * h_t + D * x_t
y_t = self.C(h) + self.D * x_t
outputs.append(y_t)
return torch.stack(outputs, dim=1) # [batch, seq_len, d_model]

指标TransformerMamba说明
时间复杂度O(n²)O(n)Mamba 在长序列上快 100x+
空间复杂度O(n) KV CacheO(1) 状态Mamba 内存恒定
推理吞吐1x5x无需 KV Cache
训练效率1x2-3x并行扫描算法
长序列建模受限于 context理论无限SSM 天然支持
模型架构参数量特点
GPT-5Transformer未公开最强通用能力
Claude 4Transformer未公开超长上下文 (1M+)
Mamba-2Pure SSM130M-2.8B高效长序列
JambaMamba + Attention52B (12B 活跃)HuggingFace
ZambaSSM + Attention7B开源混合

⚠️ GPT-5/Claude 4 参数量未官方公开,市场估计差异大。


┌─────────────────────────────────────────┐
│ Jamba Block │
├─────────────────────────────────────────┤
│ ┌─────────┐ ┌─────────────────────┐ │
│ │ Mamba │ → │ Attention (稀疏) │ │
│ │ Layer │ │ 每 8 层一次 │ │
│ └─────────┘ └─────────────────────┘ │
│ ↓ ↓ │
│ └──────┬───────┘ │
│ ↓ │
│ MoE FFN │
└─────────────────────────────────────────┘

设计理念

  • Mamba 层:处理长距离依赖,高效
  • Attention 层:处理复杂推理,精准
  • MoE:增加容量,控制计算量