Skip to content

RNN / LSTM / GRU

一句话定义:RNN (循环神经网络) 是一种能处理序列数据的神经网络,它通过隐藏状态将信息从前一时刻传递到后一时刻。LSTM 和 GRU 是解决 RNN 长期依赖问题的改进版本。

架构年份特点
RNN1986基础循环结构,有梯度消失问题
LSTM1997引入”记忆单元”和三个门,解决长期依赖
GRU2014LSTM 的简化版,参数更少,效果相近

架构类比
前馈网络 (MLP)查词典:每个词独立查询,不考虑上下文
RNN读小说:记住前面的情节,理解当前内容
LSTM有选择地记笔记:重要情节记下来,无关细节忘掉

RNN 的问题:想象一条很长的水管,水从头流到尾会损耗殆尽(梯度消失)。

LSTM 的解决方案:加装”阀门”(门控机制):

  • 遗忘门:决定丢掉哪些旧信息
  • 输入门:决定存入哪些新信息
  • 输出门:决定输出哪些信息


观察序列数据如何流经 RNN/LSTM,理解隐藏状态的传递和梯度消失问题。

🔄 RNN/LSTM 序列处理可视化

x0: The
RNN
tanh
h0
y0
x1: cat
RNN
tanh
h1
y1
x2: sat
RNN
tanh
h2
y2
x3: on
RNN
tanh
h3
y3
x4: mat
RNN
tanh
h4
y4

💡 隐藏状态 ht 携带之前的信息传递到下一时刻,实现"记忆"


隐藏状态更新
ht=tanh(Whhht1+Wxhxt+bh)h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b_h)
  • hth_t: 当前时刻的隐藏状态
  • ht1h_{t-1}: 上一时刻的隐藏状态
  • xtx_t: 当前时刻的输入
  • Whh,WxhW_{hh}, W_{xh}: 权重矩阵

问题:当序列很长时,梯度 Lh0\frac{\partial L}{\partial h_0} 需要连乘很多 WhhW_{hh},导致:

  • W<1|W| < 1梯度消失 (记不住远处的信息)
  • W>1|W| > 1梯度爆炸 (训练不稳定)

LSTM 通过细胞状态 (Cell State) CtC_t三个门解决长期依赖:

三个门 (Gates)
ft=σ(Wf[ht1,xt]+bf)it=σ(Wi[ht1,xt]+bi)ot=σ(Wo[ht1,xt]+bo)\begin{aligned} f_t &= \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \\ i_t &= \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \\ o_t &= \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \end{aligned}
  • ftf_t (Forget Gate): 决定丢弃多少旧信息,σ\sigma 输出 0~1
  • iti_t (Input Gate): 决定存入多少新信息
  • oto_t (Output Gate): 决定输出多少信息
细胞状态更新
C~t=tanh(WC[ht1,xt]+bC)Ct=ftCt1+itC~tht=ottanh(Ct)\begin{aligned} \tilde{C}_t &= \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) \\ C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \\ h_t &= o_t \odot \tanh(C_t) \end{aligned}
  • C~t\tilde{C}_t: 候选记忆(新信息)
  • CtC_t: 更新后的细胞状态 = 旧记忆 × 遗忘门 + 新记忆 × 输入门
  • hth_t: 输出 = 细胞状态 × 输出门
  • \odot: 逐元素乘法 (Hadamard product)

GRU 将遗忘门和输入门合并为更新门,参数更少:

GRU 公式
zt=σ(Wz[ht1,xt])rt=σ(Wr[ht1,xt])ht=(1zt)ht1+zttanh(W[rtht1,xt])\begin{aligned} z_t &= \sigma(W_z \cdot [h_{t-1}, x_t]) \\ r_t &= \sigma(W_r \cdot [h_{t-1}, x_t]) \\ h_t &= (1-z_t) \odot h_{t-1} + z_t \odot \tanh(W \cdot [r_t \odot h_{t-1}, x_t]) \end{aligned}
  • ztz_t (Update Gate): 决定保留多少旧状态 vs 更新多少新状态
  • rtr_t (Reset Gate): 决定如何结合旧状态计算新候选
  • 无独立的细胞状态,结构更简单

import torch
import torch.nn as nn
# 定义 LSTM 模型
class LSTMModel(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, output_size):
super().__init__()
self.lstm = nn.LSTM(
input_size=input_size, # 输入特征维度
hidden_size=hidden_size, # 隐藏状态维度
num_layers=num_layers, # 堆叠层数
batch_first=True # 输入形状: (batch, seq, feature)
)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
# LSTM 返回: output, (h_n, c_n)
out, (h_n, c_n) = self.lstm(x)
# 取最后一个时间步的输出
out = self.fc(out[:, -1, :])
return out
# 使用示例
model = LSTMModel(input_size=10, hidden_size=64, num_layers=2, output_size=1)
x = torch.randn(32, 50, 10) # batch=32, seq_len=50, features=10
output = model(x)
print(output.shape) # torch.Size([32, 1])

特性RNNLSTMGRUTransformer
参数量最多
长序列处理❌ 差✅ 较好✅ 较好✅✅ 最好
并行训练❌ 不支持❌ 不支持❌ 不支持✅ 支持
实时处理✅ 好✅ 好✅ 好⚠️ 需要完整输入
主要应用已过时语音、时序语音、时序NLP、CV

  1. Attention 机制 - 理解从 Seq2Seq Attention 到 Self-Attention 的演进
  2. Transformer - 现代 LLM 的基础架构
  3. Mamba / SSM - 线性复杂度的序列建模新范式