跳到主要内容

RNN 与 LSTM

问题

RNN 的工作原理是什么?LSTM 如何解决长期依赖问题?GRU 与 LSTM 有什么区别?为什么 Transformer 逐渐取代了 RNN?

答案

循环神经网络(Recurrent Neural Network, RNN) 专门处理序列数据——将前一个时间步的隐藏状态作为输入传递给下一个时间步,使网络具有"记忆"。

一、基本 RNN

ht=tanh(Whhht1+Wxhxt+b)h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b) yt=Whyht+byy_t = W_{hy} h_t + b_y

每个时间步 tt,RNN 接收当前输入 xtx_t 和上一步的隐藏状态 ht1h_{t-1},计算新的隐藏状态 hth_t。隐藏状态 hh 就是网络的"记忆"。

基本 RNN 的致命问题长期依赖——当序列很长时,早期信息在经过多步传递后会衰减殆尽(梯度消失)。

二、LSTM(长短期记忆网络)

LSTM 引入了门控机制记忆单元(Cell State) 来解决长期依赖:

三个门的计算:

ft=σ(Wf[ht1,xt]+bf)(遗忘门)f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \quad \text{(遗忘门)} it=σ(Wi[ht1,xt]+bi)(输入门)i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \quad \text{(输入门)} ot=σ(Wo[ht1,xt]+bo)(输出门)o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \quad \text{(输出门)}

记忆单元更新:

C~t=tanh(WC[ht1,xt]+bC)(候选记忆)\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) \quad \text{(候选记忆)} Ct=ftCt1+itC~t(更新记忆)C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \quad \text{(更新记忆)} ht=ottanh(Ct)(输出隐藏状态)h_t = o_t \odot \tanh(C_t) \quad \text{(输出隐藏状态)}
为什么 LSTM 能解决长期依赖?

关键在于记忆单元 CtC_t 的更新方式——它是通过加法(而非乘法)传递的:Ct=ftCt1+itC~tC_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t。加法操作让梯度可以沿着记忆单元直接回传,不会像乘法那样指数衰减。遗忘门控制丢弃什么旧信息,输入门控制写入什么新信息——像一条"高速公路"。

三、GRU(门控循环单元)

GRU 是 LSTM 的简化版本,将三个门简化为两个:

LSTMGRU
门的数量3(遗忘、输入、输出)2(重置、更新)
参数量更多更少(约 LSTM 的 75%)
记忆单元有独立的 Cell State无,直接用 Hidden State
训练速度较慢较快
效果长序列通常略好短序列效果相当

四、RNN 的变体应用

模式输入→输出应用
一对多一个输入 → 序列输出图片描述生成
多对一序列输入 → 一个输出情感分析
多对多(同步)序列 → 同长序列词性标注
多对多(异步)序列 → 不同长序列机器翻译(Seq2Seq)
双向 RNN同时前向和反向处理BERT 式上下文理解

五、为什么 Transformer 取代了 RNN?

维度RNN/LSTMTransformer
并行性❌ 串行处理(t 依赖 t-1)✅ 所有位置并行计算
长距离依赖🟡 LSTM 有所改善但仍衰减✅ 注意力直接连接任意两个位置
训练速度快得多
可扩展性参数难以大规模增长可轻松扩展到万亿参数
GPU 利用率低(串行瓶颈)高(矩阵乘法并行)
RNN/LSTM 的现状

虽然 Transformer 在大多数 NLP 任务上已取代 RNN,但 RNN 仍在以下场景有一席之地:

  • 边缘设备:参数少、推理简单
  • 在线/流式处理:逐步处理输入,内存固定
  • 时序预测:某些时间序列任务中 LSTM 仍表现出色
  • Mamba/SSM:2024 年兴起的状态空间模型可以看作 RNN 的现代版本

常见面试问题

Q1: LSTM 的遗忘门有什么作用?

答案:遗忘门 ft=σ(Wf[ht1,xt])f_t = \sigma(W_f \cdot [h_{t-1}, x_t]) 输出 0~1 之间的值,控制上一步记忆单元 Ct1C_{t-1} 中每个维度保留多少信息。当 ft0f_t \approx 0 时,对应维度的旧信息被"遗忘";当 ft1f_t \approx 1 时,旧信息完全保留。

Q2: Seq2Seq 架构是什么?

答案:Seq2Seq(Sequence-to-Sequence)由编码器 + 解码器组成。编码器(通常是 LSTM)将输入序列压缩为一个固定长度的向量(上下文向量),解码器根据这个向量逐步生成输出序列。最初用于机器翻译,后来加入 Attention 机制,最终演变为 Transformer 的 Encoder-Decoder 架构。

Q3: 双向 RNN 和单向 RNN 的区别?适用场景?

答案

  • 单向 RNN:只从左到右处理序列,只能利用"过去"信息。适用于生成任务(如文本生成——无法看到未来)
  • 双向 RNN:同时从左到右和从右到左处理,能利用完整上下文。适用于理解任务(如情感分析、NER——可以看到完整句子)

BERT 使用双向 Transformer 而非单向,GPT 使用单向 Transformer(Decoder-only),正好对应这两种思路。


相关链接