长短期记忆网络(LSTM):遗忘门、输入门、输出门
一句话概述
长短期记忆网络(LSTM)由Hochreiter和Schmidhuber于1997年提出,是解决标准RNN梯度消失问题的最经典方案。LSTM通过引入细胞状态(Cell State)和三个门控机制——遗忘门(Forget Gate)决定丢弃哪些旧信息、输入门(Input Gate)决定存储哪些新信息、输出门(Output Gate)决定输出哪些信息——实现了对长期依赖的有效建模。这三个门的精妙配合使得LSTM可以记住数百个时间步之前的信息,成为NLP领域长达20年的主导架构。
💡 核心要点:①LSTM引入细胞状态C_t,它在时间步间"直线"传递,梯度可以无损通过 ②遗忘门控制从细胞状态中丢弃多少旧信息 ③输入门控制多少新信息写入细胞状态 ④输出门控制从细胞状态中输出多少信息到隐藏状态
教学与演示
一、LSTM的核心思想:门控与细胞状态
是什么(定义):LSTM的核心创新有两个:①细胞状态(Cell State)C_t——一条在时间步间"直线"传递的信息通道,只经过逐元素乘法和加法,梯度几乎可以无损传播。②门控机制——三个Sigmoid门(取值0到1)分别控制信息的遗忘、写入和输出,实现了"选择性记忆"。
大白话 LSTM就是一个"聪明的记事本"——细胞状态C_t是当前页的内容,三个门就像三个管理员:遗忘门决定"擦掉哪些旧内容"(比如前一句的话题已经变了),输入门决定"写下哪些新内容"(比如当前句子的主语),输出门决定"哪些内容可以给别人看"(生成当前时刻的输出)。这个记事本可以保留几十页之前的重要信息。
为什么(原理):标准RNN的梯度消失根本原因在于隐藏状态h_t完全被tanh重新计算,梯度经过W_hh和tanh导数反复衰减。LSTM的细胞状态C_t的更新公式中,C_{t-1}到C_t的路径只经过逐元素乘法(遗忘门f_t)和加法(输入门i_t),没有矩阵乘法和tanh压缩。这意味着梯度可以直接通过C_t传递,几乎不衰减。
怎么做(实现):
import numpy as np
# ========================================
# LSTM 完整实现 —— 三个门控 + 细胞状态
# 核心:「选择性记忆」解决梯度消失
# ========================================
class LSTMCell:
"""
LSTM 单元
输入: x_t (input_size维)
隐藏状态: h_t (hidden_size维)
细胞状态: C_t (hidden_size维)
四个计算步骤(每个都使用Sigmoid或Tanh激活):
1. 遗忘门: f_t = σ(W_f·[h_{t-1}, x_t] + b_f)
2. 输入门: i_t = σ(W_i·[h_{t-1}, x_t] + b_i)
3. 候选值: C̃_t = tanh(W_c·[h_{t-1}, x_t] + b_c)
4. 输出门: o_t = σ(W_o·[h_{t-1}, x_t] + b_o)
"""
def __init__(self, input_size, hidden_size):
"""
初始化LSTM的参数
每个门都有独立的权重矩阵(实际实现中合并为一个大矩阵)
"""
# 遗忘门参数: 决定丢弃多少旧信息
self.W_f = np.random.randn(hidden_size, hidden_size + input_size) * 0.01
self.b_f = np.zeros(hidden_size)
# 输入门参数: 决定写入多少新信息
self.W_i = np.random.randn(hidden_size, hidden_size + input_size) * 0.01
self.b_i = np.zeros(hidden_size)
# 候选细胞状态参数: 生成新的候选信息
self.W_c = np.random.randn(hidden_size, hidden_size + input_size) * 0.01
self.b_c = np.zeros(hidden_size)
# 输出门参数: 决定输出多少信息
self.W_o = np.random.randn(hidden_size, hidden_size + input_size) * 0.01
self.b_o = np.zeros(hidden_size)
def _sigmoid(self, x):
"""Sigmoid: 将值压缩到 (0, 1),用于门控"""
return 1.0 / (1.0 + np.exp(-np.clip(x, -50, 50)))
def forward(self, x_t, h_prev, C_prev):
"""
单个时间步的前向传播
参数:
x_t: 当前输入向量,shape (input_size,)
h_prev: 上一时间步的隐藏状态,shape (hidden_size,)
C_prev: 上一时间步的细胞状态,shape (hidden_size,)
返回:
h_t: 当前隐藏状态
C_t: 当前细胞状态
cache: 中间值(用于反向传播)
"""
# 拼接 [h_{t-1}, x_t] 作为各门的输入
concat = np.concatenate([h_prev, x_t]) # shape: (hidden_size + input_size,)
# ---- 遗忘门: 决定遗忘多少旧信息 ----
# f_t = σ(W_f·concat + b_f),取值 ∈ (0, 1)
# 0=完全遗忘, 1=完全保留
f_t = self._sigmoid(np.dot(self.W_f, concat) + self.b_f)
# ---- 输入门: 决定写入多少新信息 ----
# i_t = σ(W_i·concat + b_i),取值 ∈ (0, 1)
i_t = self._sigmoid(np.dot(self.W_i, concat) + self.b_i)
# ---- 候选细胞状态: 生成新的候选信息 ----
# C̃_t = tanh(W_c·concat + b_c),取值 ∈ (-1, 1)
C_tilde = np.tanh(np.dot(self.W_c, concat) + self.b_c)
# ---- 更新细胞状态 ----
# C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t
# ⊙ 表示逐元素乘法(Hadamard积)
# 遗忘门擦除旧信息,输入门写入新信息
C_t = f_t * C_prev + i_t * C_tilde
# ---- 输出门: 决定输出多少信息 ----
# o_t = σ(W_o·concat + b_o),取值 ∈ (0, 1)
o_t = self._sigmoid(np.dot(self.W_o, concat) + self.b_o)
# ---- 隐藏状态 ----
# h_t = o_t ⊙ tanh(C_t)
# 输出门控制从细胞状态中"暴露"多少信息
h_t = o_t * np.tanh(C_t)
cache = (f_t, i_t, C_tilde, o_t, C_t, h_prev, x_t, concat)
return h_t, C_t, cache
# --- 演示 LSTM 的核心机制 ---
print("LSTM 门控机制演示:")
print("=" * 60)
np.random.seed(42)
lstm = LSTMCell(input_size=3, hidden_size=4)
# 模拟一个序列
h = np.zeros(4) # 初始隐藏状态
C = np.zeros(4) # 初始细胞状态
# 序列: 三个词
words = np.array([[1.0, 0.0, 0.0], # 词1
[0.0, 1.0, 0.0], # 词2
[0.0, 0.0, 1.0]]) # 词3
for t, x_t in enumerate(words):
h, C, cache = lstm.forward(x_t, h, C)
f_t, i_t, C_tilde, o_t, _, _, _, _ = cache
print(f"\n t={t}:")
print(f" 遗忘门 f_t: {f_t} ← 哪些旧信息被遗忘?")
print(f" 输入门 i_t: {i_t} ← 哪些新信息被写入?")
print(f" 候选值 C̃_t: {C_tilde} ← 新候选信息是什么?")
print(f" 细胞状态 C_t: {C} ← 更新后的"记忆"")
print(f" 输出门 o_t: {o_t} ← 哪些信息要输出?")
print(f" 隐藏状态 h_t: {h} ← 对外输出的信息")
print(f"\n 关键: 细胞状态 C_t 通过 f_t*C_{t-1} + i_t*C̃_t 更新")
print(f" → 梯度可以通过 C_t 无损传播(无矩阵乘法,无tanh压缩)")
大白话 LSTM的三个门像三个管理员:①遗忘门——"这篇旧笔记有没有用?没用就擦掉"(f_t接近0就擦掉,接近1就保留)。②输入门——"这条新信息重不重要?重要就写进去"(i_t和C̃_t配合决定写什么、写多少)。③输出门——"当前记忆里哪些可以对外说?"(o_t控制从C_t中暴露多少)。细胞状态C_t是核心记忆,只在"遗忘+写入"时线性更新,梯度可以畅通无阻地传播。
什么用(AI关联):LSTM是NLP领域长达20年的主导架构,广泛应用于机器翻译、语音识别、文本生成、情感分析、命名实体识别等。即使今天Transformer在很多任务上超越了LSTM,LSTM在时间序列预测、异常检测、小数据场景中仍然非常有效。
哪些坑(缺点):①参数量大——每个门都有独立的权重矩阵,总参数是标准RNN的4倍。②计算复杂——每个时间步需要4次矩阵乘法。③序列处理仍然是串行的,无法像Transformer一样并行。④在极长序列(>1000步)中仍可能出现梯度问题。
二、遗忘门:选择性遗忘
是什么(定义):遗忘门f_t通过Sigmoid函数输出0到1之间的值,逐元素乘以细胞状态C_{t-1}。当f_t接近0时,对应的信息被"遗忘"(从细胞状态中擦除);当f_t接近1时,对应的信息被"保留"。遗忘门让LSTM能够主动丢弃不再需要的信息。
大白话 遗忘门就是"大脑的清理功能"——读到一段话的第10句时,第1句提到的"张三"已经不重要了,遗忘门就把他相关的记忆擦掉,腾出空间记新信息。没有遗忘门,LSTM的记忆会越来越"满",最后什么都记不住。
为什么(原理):遗忘门的设计是LSTM成功的关键之一。在标准RNN中,旧信息只能被新的计算"覆盖"(通过tanh非线性变换),这不可控。LSTM的遗忘门提供了显式的"删除"机制——网络可以主动选择将某些维度的细胞状态置零。遗忘门初始化为接近1的值(如b_f初始化为1),使网络在训练初期倾向于"先记住、再学会遗忘"。
三、输入门与输出门:写入与暴露
是什么(定义):输入门i_t控制新信息写入细胞状态的程度。它由两部分组成:i_t(Sigmoid门,决定"写多少")和C̃_t(tanh候选值,决定"写什么")。输出门o_t控制细胞状态中多少信息暴露给隐藏状态h_t(进而影响输出和下一时间步的计算)。
大白话 输入门是"要不要把这个新信息记下来"——如果当前词是"今天天气很好",输入门可能把"天气"的信息写入记忆。输出门是"要不要把记忆里的东西说出来"——如果当前需要判断情感,输出门会从记忆中提取"天气""好"等关键信息。
概念关系图谱
| 概念 | 核心含义 | 与AI的关系 | 关联概念 |
|---|---|---|---|
| 细胞状态 | 贯穿时间步的"主线记忆" | 梯度无损传播的关键 | 梯度消失、记忆单元 |
| 遗忘门 | 控制丢弃多少旧记忆 | 选择性遗忘,防止记忆饱和 | Sigmoid、门控机制 |
| 输入门 | 控制写入多少新信息 | 选择性记忆,更新细胞状态 | 候选值、信息筛选 |
| 输出门 | 控制暴露多少记忆 | 决定对外输出什么信息 | 隐藏状态、信息过滤 |
| 门控机制 | Sigmoid门控制信息流 | LSTM/GRU的核心创新 | 选择性通过、0-1权重 |
| 逐元素乘法 | ⊙ (Hadamard积) | 门控的实现方式 | 向量逐元素运算 |
重点答疑
Q1: 为什么LSTM能解决梯度消失,而RNN不能?
关键区别在于梯度传播的路径。RNN中,∂h_t/∂h_{t-1} = diag(tanh')·W_hh^T,经过矩阵乘法和tanh导数压缩,梯度指数级衰减。LSTM中,C_t = f_t⊙C_{t-1} + i_t⊙C̃t,∂C_t/∂C{t-1} = f_t(逐元素),没有矩阵乘法!这意味着梯度可以通过C_t的路径无损传播(当f_t=1时),即使经过100个时间步也不会衰减。同时,遗忘门f_t是可学习的——网络可以学会在需要保持长期记忆的维度上让f_t接近1。
Q2: LSTM和GRU有什么区别?
主要区别:①LSTM有三个门(遗忘、输入、输出)+ 细胞状态,GRU有两个门(重置、更新)+ 无独立细胞状态。②GRU将遗忘门和输入门合并为更新门,用"1-更新门"控制遗忘。③GRU参数量约为LSTM的3/4。④实践上两者性能相近,GRU在小数据上可能更好(参数少),LSTM在大数据上可能略优。选择哪个通常取决于具体任务和调参。
Q3: 遗忘门初始化为1有什么好处?
在LSTM的原始论文中,遗忘门的偏置b_f初始化为0(Sigmoid(0)=0.5)。后来研究发现,将b_f初始化为较大的正值(如1.0,使Sigmoid(1)≈0.73,甚至更大值)可以显著改善长序列上的训练效果。原因:训练初期,网络倾向于"先学会记住",再逐渐学会"选择性遗忘"。如果初始遗忘门就倾向于遗忘,早期训练中梯度难以传播,长期依赖的学习变得困难。
章节单词汇总
| 英文 | 音标 | 术语/释义 |
|---|---|---|
| LSTM | /ɛl ɛs tiː ɛm/ | 长短期记忆网络 |
| Cell State | /sel steɪt/ | 细胞状态,LSTM的长期记忆通道 |
| Forget Gate | /fərˈɡet ɡeɪt/ | 遗忘门,控制丢弃旧信息 |
| Input Gate | /ˈɪnpʊt ɡeɪt/ | 输入门,控制写入新信息 |
| Output Gate | /ˈaʊtpʊt ɡeɪt/ | 输出门,控制暴露信息 |
| Hadamard Product | /hædəˈmɑːrd ˈprɒdʌkt/ | 哈达玛积,逐元素乘法 |
| Gating Mechanism | /ˈɡeɪtɪŋ ˈmekənɪzəm/ | 门控机制,Sigmoid门控制信息流 |
| Candidate Value | /ˈkændɪdət ˈvæljuː/ | 候选值,C̃_t,新信息的候选内容 |
面试练习
Q1 [单选] LSTM中负责"选择性遗忘"的是哪个门?
- A. 遗忘门
- B. 输入门
- C. 输出门
- D. 更新门
解答:遗忘门f_t控制从细胞状态C_{t-1}中丢弃多少旧信息。f_t接近0时遗忘,接近1时保留。
Q2 [单选] LSTM的细胞状态更新公式是什么?
- A. C_t = tanh(W·C_{t-1} + U·x_t)
- B. C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t
- C. C_t = C_{t-1} + x_t
- D. C_t = o_t ⊙ C_{t-1}
解答:C_t = f_t⊙C_{t-1} + i_t⊙C̃_t。遗忘门擦除旧信息,输入门写入新信息。
Q3 [单选] LSTM中,梯度可以通过哪个路径几乎无损传播?
- A. 细胞状态 C_t 的路径
- B. 隐藏状态 h_t 的路径
- C. 输入门 i_t 的路径
- D. 输出门 o_t 的路径
解答:C_t = f_t⊙C_{t-1} + ...,∂C_t/∂C_{t-1} = f_t(逐元素),没有矩阵乘法,梯度可无损传播。
Q4 [多选] 以下关于LSTM门控的描述,哪些正确?
- A. 遗忘门使用Sigmoid激活,输出在(0,1)之间
- B. 输入门控制新信息写入细胞状态的程度
- C. 输出门控制从细胞状态中暴露多少信息
- D. 三个门共享同一组权重参数
解答:每个门有独立的权重矩阵(W_f, W_i, W_c, W_o),不共享。A、B、C正确。
Q5 [单选] LSTM中的"⊙"符号表示什么运算?
- A. 矩阵乘法
- B. 逐元素乘法(Hadamard积)
- C. 向量加法
- D. 向量拼接
解答:⊙表示逐元素乘法(Hadamard Product),两个等长向量对应位置相乘。
Q6 [单选] 相比标准RNN,LSTM的参数量大约是多少?
- A. 相同
- B. 2倍
- C. 4倍
- D. 10倍
解答:LSTM有4组权重矩阵(遗忘门、输入门、候选值、输出门),标准RNN只有1组,所以参数量约为4倍。
Q7 [单选] LSTM的遗忘门初始偏置b_f通常设为?
- A. 0
- B. 较大的正值(如1.0)
- C. 较大的负值
- D. 随机值
解答:将b_f初始化为较大的正值(如1.0),使Sigmoid输出接近1,训练初期倾向于"记住",有利于学习长期依赖。
Q8 [多选] LSTM比标准RNN的优势包括哪些?
- A. 更好地处理长期依赖
- B. 通过门控机制控制信息流
- C. 细胞状态提供梯度传播的"高速公路"
- D. 计算速度更快
解答:LSTM计算更慢(4倍参数),但长期依赖建模能力远超RNN。A、B、C正确。
Q9 [单选] 在LSTM中,候选值C̃_t使用什么激活函数?
- A. tanh
- B. Sigmoid
- C. ReLU
- D. Softmax
解答:C̃t = tanh(W_c·[h{t-1}, x_t] + b_c),使用tanh生成(-1,1)的候选值,与输入门i_t(Sigmoid)配合控制写入。
Q10 [单选] LSTM中隐藏状态h_t和细胞状态C_t的关系是什么?
- A. h_t = o_t ⊙ tanh(C_t)
- B. h_t = C_t
- C. h_t = o_t + C_t
- D. h_t = tanh(C_t)(无输出门)
解答:h_t = o_t ⊙ tanh(C_t)。输出门o_t控制从细胞状态中"暴露"多少信息到隐藏状态。h_t是对外输出,C_t是内部记忆。