RNN / LSTM / GRU
📌 核心定义 (What)
Section titled “📌 核心定义 (What)”一句话定义:RNN (循环神经网络) 是一种能处理序列数据的神经网络,它通过隐藏状态将信息从前一时刻传递到后一时刻。LSTM 和 GRU 是解决 RNN 长期依赖问题的改进版本。
| 架构 | 年份 | 特点 |
|---|---|---|
| RNN | 1986 | 基础循环结构,有梯度消失问题 |
| LSTM | 1997 | 引入”记忆单元”和三个门,解决长期依赖 |
| GRU | 2014 | LSTM 的简化版,参数更少,效果相近 |
🏠 生活类比 (Analogy)
Section titled “🏠 生活类比 (Analogy)”📖 “读小说 vs 看词典”
Section titled “📖 “读小说 vs 看词典””| 架构 | 类比 |
|---|---|
| 前馈网络 (MLP) | 查词典:每个词独立查询,不考虑上下文 |
| RNN | 读小说:记住前面的情节,理解当前内容 |
| LSTM | 有选择地记笔记:重要情节记下来,无关细节忘掉 |
🚰 “水管传递信息”
Section titled “🚰 “水管传递信息””RNN 的问题:想象一条很长的水管,水从头流到尾会损耗殆尽(梯度消失)。
LSTM 的解决方案:加装”阀门”(门控机制):
- 遗忘门:决定丢掉哪些旧信息
- 输入门:决定存入哪些新信息
- 输出门:决定输出哪些信息
🎬 视频详解 (Video)
Section titled “🎬 视频详解 (Video)”🎨 交互演示 (Interactive)
Section titled “🎨 交互演示 (Interactive)”观察序列数据如何流经 RNN/LSTM,理解隐藏状态的传递和梯度消失问题。
🔄 RNN/LSTM 序列处理可视化
x0: The
RNN
tanh
h0
y0
x1: cat
RNN
tanh
h1
y1
x2: sat
RNN
tanh
h2
y2
x3: on
RNN
tanh
h3
y3
x4: mat
RNN
tanh
h4
y4
💡 隐藏状态 ht 携带之前的信息传递到下一时刻,实现"记忆"
📊 数学原理 (Math)
Section titled “📊 数学原理 (Math)”RNN 基础公式
Section titled “RNN 基础公式”隐藏状态更新
- : 当前时刻的隐藏状态
- : 上一时刻的隐藏状态
- : 当前时刻的输入
- : 权重矩阵
问题:当序列很长时,梯度 需要连乘很多 ,导致:
- → 梯度消失 (记不住远处的信息)
- → 梯度爆炸 (训练不稳定)
LSTM 核心公式
Section titled “LSTM 核心公式”LSTM 通过细胞状态 (Cell State) 和三个门解决长期依赖:
三个门 (Gates)
- (Forget Gate): 决定丢弃多少旧信息, 输出 0~1
- (Input Gate): 决定存入多少新信息
- (Output Gate): 决定输出多少信息
细胞状态更新
- : 候选记忆(新信息)
- : 更新后的细胞状态 = 旧记忆 × 遗忘门 + 新记忆 × 输入门
- : 输出 = 细胞状态 × 输出门
- : 逐元素乘法 (Hadamard product)
GRU 简化公式
Section titled “GRU 简化公式”GRU 将遗忘门和输入门合并为更新门,参数更少:
GRU 公式
- (Update Gate): 决定保留多少旧状态 vs 更新多少新状态
- (Reset Gate): 决定如何结合旧状态计算新候选
- 无独立的细胞状态,结构更简单
💻 代码示例 (PyTorch)
Section titled “💻 代码示例 (PyTorch)”import torchimport torch.nn as nn
# 定义 LSTM 模型class LSTMModel(nn.Module): def __init__(self, input_size, hidden_size, num_layers, output_size): super().__init__() self.lstm = nn.LSTM( input_size=input_size, # 输入特征维度 hidden_size=hidden_size, # 隐藏状态维度 num_layers=num_layers, # 堆叠层数 batch_first=True # 输入形状: (batch, seq, feature) ) self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x): # LSTM 返回: output, (h_n, c_n) out, (h_n, c_n) = self.lstm(x) # 取最后一个时间步的输出 out = self.fc(out[:, -1, :]) return out
# 使用示例model = LSTMModel(input_size=10, hidden_size=64, num_layers=2, output_size=1)x = torch.randn(32, 50, 10) # batch=32, seq_len=50, features=10output = model(x)print(output.shape) # torch.Size([32, 1])📈 LSTM vs GRU vs Transformer
Section titled “📈 LSTM vs GRU vs Transformer”| 特性 | RNN | LSTM | GRU | Transformer |
|---|---|---|---|---|
| 参数量 | 少 | 多 | 中 | 最多 |
| 长序列处理 | ❌ 差 | ✅ 较好 | ✅ 较好 | ✅✅ 最好 |
| 并行训练 | ❌ 不支持 | ❌ 不支持 | ❌ 不支持 | ✅ 支持 |
| 实时处理 | ✅ 好 | ✅ 好 | ✅ 好 | ⚠️ 需要完整输入 |
| 主要应用 | 已过时 | 语音、时序 | 语音、时序 | NLP、CV |
- Attention 机制 - 理解从 Seq2Seq Attention 到 Self-Attention 的演进
- Transformer - 现代 LLM 的基础架构
- Mamba / SSM - 线性复杂度的序列建模新范式