长短期记忆网络(LSTM):遗忘门、输入门、输出门

一句话概述

长短期记忆网络(LSTM)由Hochreiter和Schmidhuber于1997年提出,是解决标准RNN梯度消失问题的最经典方案。LSTM通过引入细胞状态(Cell State)和三个门控机制——遗忘门(Forget Gate)决定丢弃哪些旧信息、输入门(Input Gate)决定存储哪些新信息、输出门(Output Gate)决定输出哪些信息——实现了对长期依赖的有效建模。这三个门的精妙配合使得LSTM可以记住数百个时间步之前的信息,成为NLP领域长达20年的主导架构。

💡 核心要点:①LSTM引入细胞状态C_t,它在时间步间"直线"传递,梯度可以无损通过 ②遗忘门控制从细胞状态中丢弃多少旧信息 ③输入门控制多少新信息写入细胞状态 ④输出门控制从细胞状态中输出多少信息到隐藏状态

教学与演示

一、LSTM的核心思想:门控与细胞状态

是什么(定义):LSTM的核心创新有两个:①细胞状态(Cell State)C_t——一条在时间步间"直线"传递的信息通道,只经过逐元素乘法和加法,梯度几乎可以无损传播。②门控机制——三个Sigmoid门(取值0到1)分别控制信息的遗忘、写入和输出,实现了"选择性记忆"。

大白话 LSTM就是一个"聪明的记事本"——细胞状态C_t是当前页的内容,三个门就像三个管理员:遗忘门决定"擦掉哪些旧内容"(比如前一句的话题已经变了),输入门决定"写下哪些新内容"(比如当前句子的主语),输出门决定"哪些内容可以给别人看"(生成当前时刻的输出)。这个记事本可以保留几十页之前的重要信息。

为什么(原理):标准RNN的梯度消失根本原因在于隐藏状态h_t完全被tanh重新计算,梯度经过W_hh和tanh导数反复衰减。LSTM的细胞状态C_t的更新公式中,C_{t-1}到C_t的路径只经过逐元素乘法(遗忘门f_t)和加法(输入门i_t),没有矩阵乘法和tanh压缩。这意味着梯度可以直接通过C_t传递,几乎不衰减。

怎么做(实现)

import numpy as np

# ========================================
# LSTM 完整实现 —— 三个门控 + 细胞状态
# 核心:「选择性记忆」解决梯度消失
# ========================================

class LSTMCell:
    """
    LSTM 单元
    输入: x_t (input_size维)
    隐藏状态: h_t (hidden_size维)
    细胞状态: C_t (hidden_size维)
    
    四个计算步骤(每个都使用Sigmoid或Tanh激活):
    1. 遗忘门: f_t = σ(W_f·[h_{t-1}, x_t] + b_f)
    2. 输入门: i_t = σ(W_i·[h_{t-1}, x_t] + b_i)
    3. 候选值: C̃_t = tanh(W_c·[h_{t-1}, x_t] + b_c)
    4. 输出门: o_t = σ(W_o·[h_{t-1}, x_t] + b_o)
    """
    def __init__(self, input_size, hidden_size):
        """
        初始化LSTM的参数
        每个门都有独立的权重矩阵(实际实现中合并为一个大矩阵)
        """
        # 遗忘门参数: 决定丢弃多少旧信息
        self.W_f = np.random.randn(hidden_size, hidden_size + input_size) * 0.01
        self.b_f = np.zeros(hidden_size)
        
        # 输入门参数: 决定写入多少新信息
        self.W_i = np.random.randn(hidden_size, hidden_size + input_size) * 0.01
        self.b_i = np.zeros(hidden_size)
        
        # 候选细胞状态参数: 生成新的候选信息
        self.W_c = np.random.randn(hidden_size, hidden_size + input_size) * 0.01
        self.b_c = np.zeros(hidden_size)
        
        # 输出门参数: 决定输出多少信息
        self.W_o = np.random.randn(hidden_size, hidden_size + input_size) * 0.01
        self.b_o = np.zeros(hidden_size)
    
    def _sigmoid(self, x):
        """Sigmoid: 将值压缩到 (0, 1),用于门控"""
        return 1.0 / (1.0 + np.exp(-np.clip(x, -50, 50)))
    
    def forward(self, x_t, h_prev, C_prev):
        """
        单个时间步的前向传播
        参数:
            x_t: 当前输入向量,shape (input_size,)
            h_prev: 上一时间步的隐藏状态,shape (hidden_size,)
            C_prev: 上一时间步的细胞状态,shape (hidden_size,)
        返回:
            h_t: 当前隐藏状态
            C_t: 当前细胞状态
            cache: 中间值(用于反向传播)
        """
        # 拼接 [h_{t-1}, x_t] 作为各门的输入
        concat = np.concatenate([h_prev, x_t])  # shape: (hidden_size + input_size,)
        
        # ---- 遗忘门: 决定遗忘多少旧信息 ----
        # f_t = σ(W_f·concat + b_f),取值 ∈ (0, 1)
        # 0=完全遗忘, 1=完全保留
        f_t = self._sigmoid(np.dot(self.W_f, concat) + self.b_f)
        
        # ---- 输入门: 决定写入多少新信息 ----
        # i_t = σ(W_i·concat + b_i),取值 ∈ (0, 1)
        i_t = self._sigmoid(np.dot(self.W_i, concat) + self.b_i)
        
        # ---- 候选细胞状态: 生成新的候选信息 ----
        # C̃_t = tanh(W_c·concat + b_c),取值 ∈ (-1, 1)
        C_tilde = np.tanh(np.dot(self.W_c, concat) + self.b_c)
        
        # ---- 更新细胞状态 ----
        # C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t
        # ⊙ 表示逐元素乘法(Hadamard积)
        # 遗忘门擦除旧信息,输入门写入新信息
        C_t = f_t * C_prev + i_t * C_tilde
        
        # ---- 输出门: 决定输出多少信息 ----
        # o_t = σ(W_o·concat + b_o),取值 ∈ (0, 1)
        o_t = self._sigmoid(np.dot(self.W_o, concat) + self.b_o)
        
        # ---- 隐藏状态 ----
        # h_t = o_t ⊙ tanh(C_t)
        # 输出门控制从细胞状态中"暴露"多少信息
        h_t = o_t * np.tanh(C_t)
        
        cache = (f_t, i_t, C_tilde, o_t, C_t, h_prev, x_t, concat)
        return h_t, C_t, cache


# --- 演示 LSTM 的核心机制 ---
print("LSTM 门控机制演示:")
print("=" * 60)

np.random.seed(42)
lstm = LSTMCell(input_size=3, hidden_size=4)

# 模拟一个序列
h = np.zeros(4)  # 初始隐藏状态
C = np.zeros(4)  # 初始细胞状态

# 序列: 三个词
words = np.array([[1.0, 0.0, 0.0],   # 词1
                  [0.0, 1.0, 0.0],   # 词2
                  [0.0, 0.0, 1.0]])   # 词3

for t, x_t in enumerate(words):
    h, C, cache = lstm.forward(x_t, h, C)
    f_t, i_t, C_tilde, o_t, _, _, _, _ = cache
    
    print(f"\n  t={t}:")
    print(f"    遗忘门 f_t: {f_t} ← 哪些旧信息被遗忘?")
    print(f"    输入门 i_t: {i_t} ← 哪些新信息被写入?")
    print(f"    候选值 C̃_t: {C_tilde} ← 新候选信息是什么?")
    print(f"    细胞状态 C_t: {C} ← 更新后的"记忆"")
    print(f"    输出门 o_t: {o_t} ← 哪些信息要输出?")
    print(f"    隐藏状态 h_t: {h} ← 对外输出的信息")

print(f"\n  关键: 细胞状态 C_t 通过 f_t*C_{t-1} + i_t*C̃_t 更新")
print(f"  → 梯度可以通过 C_t 无损传播(无矩阵乘法,无tanh压缩)")
LSTM遗忘门\(\mathbf{f}_t = \sigma(\mathbf{W}_f \cdot [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f)\)
LSTM输入门与候选值\(\mathbf{i}_t = \sigma(\mathbf{W}_i \cdot [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i), \quad \tilde{\mathbf{C}}_t = \tanh(\mathbf{W}_c \cdot [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_c)\)
LSTM细胞状态更新与输出门\(\mathbf{C}_t = \mathbf{f}_t \odot \mathbf{C}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{C}}_t, \quad \mathbf{o}_t = \sigma(\mathbf{W}_o \cdot [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_o), \quad \mathbf{h}_t = \mathbf{o}_t \odot \tanh(\mathbf{C}_t)\)
大白话 LSTM的三个门像三个管理员:①遗忘门——"这篇旧笔记有没有用?没用就擦掉"(f_t接近0就擦掉,接近1就保留)。②输入门——"这条新信息重不重要?重要就写进去"(i_t和C̃_t配合决定写什么、写多少)。③输出门——"当前记忆里哪些可以对外说?"(o_t控制从C_t中暴露多少)。细胞状态C_t是核心记忆,只在"遗忘+写入"时线性更新,梯度可以畅通无阻地传播。

什么用(AI关联):LSTM是NLP领域长达20年的主导架构,广泛应用于机器翻译、语音识别、文本生成、情感分析、命名实体识别等。即使今天Transformer在很多任务上超越了LSTM,LSTM在时间序列预测、异常检测、小数据场景中仍然非常有效。

哪些坑(缺点):①参数量大——每个门都有独立的权重矩阵,总参数是标准RNN的4倍。②计算复杂——每个时间步需要4次矩阵乘法。③序列处理仍然是串行的,无法像Transformer一样并行。④在极长序列(>1000步)中仍可能出现梯度问题。

二、遗忘门:选择性遗忘

是什么(定义):遗忘门f_t通过Sigmoid函数输出0到1之间的值,逐元素乘以细胞状态C_{t-1}。当f_t接近0时,对应的信息被"遗忘"(从细胞状态中擦除);当f_t接近1时,对应的信息被"保留"。遗忘门让LSTM能够主动丢弃不再需要的信息。

大白话 遗忘门就是"大脑的清理功能"——读到一段话的第10句时,第1句提到的"张三"已经不重要了,遗忘门就把他相关的记忆擦掉,腾出空间记新信息。没有遗忘门,LSTM的记忆会越来越"满",最后什么都记不住。

为什么(原理):遗忘门的设计是LSTM成功的关键之一。在标准RNN中,旧信息只能被新的计算"覆盖"(通过tanh非线性变换),这不可控。LSTM的遗忘门提供了显式的"删除"机制——网络可以主动选择将某些维度的细胞状态置零。遗忘门初始化为接近1的值(如b_f初始化为1),使网络在训练初期倾向于"先记住、再学会遗忘"。

三、输入门与输出门:写入与暴露

是什么(定义):输入门i_t控制新信息写入细胞状态的程度。它由两部分组成:i_t(Sigmoid门,决定"写多少")和C̃_t(tanh候选值,决定"写什么")。输出门o_t控制细胞状态中多少信息暴露给隐藏状态h_t(进而影响输出和下一时间步的计算)。

大白话 输入门是"要不要把这个新信息记下来"——如果当前词是"今天天气很好",输入门可能把"天气"的信息写入记忆。输出门是"要不要把记忆里的东西说出来"——如果当前需要判断情感,输出门会从记忆中提取"天气""好"等关键信息。

概念关系图谱

概念核心含义与AI的关系关联概念
细胞状态贯穿时间步的"主线记忆"梯度无损传播的关键梯度消失、记忆单元
遗忘门控制丢弃多少旧记忆选择性遗忘,防止记忆饱和Sigmoid、门控机制
输入门控制写入多少新信息选择性记忆,更新细胞状态候选值、信息筛选
输出门控制暴露多少记忆决定对外输出什么信息隐藏状态、信息过滤
门控机制Sigmoid门控制信息流LSTM/GRU的核心创新选择性通过、0-1权重
逐元素乘法⊙ (Hadamard积)门控的实现方式向量逐元素运算

重点答疑

Q1: 为什么LSTM能解决梯度消失,而RNN不能?

关键区别在于梯度传播的路径。RNN中,∂h_t/∂h_{t-1} = diag(tanh')·W_hh^T,经过矩阵乘法和tanh导数压缩,梯度指数级衰减。LSTM中,C_t = f_t⊙C_{t-1} + i_t⊙C̃t,∂C_t/∂C{t-1} = f_t(逐元素),没有矩阵乘法!这意味着梯度可以通过C_t的路径无损传播(当f_t=1时),即使经过100个时间步也不会衰减。同时,遗忘门f_t是可学习的——网络可以学会在需要保持长期记忆的维度上让f_t接近1。

Q2: LSTM和GRU有什么区别?

主要区别:①LSTM有三个门(遗忘、输入、输出)+ 细胞状态,GRU有两个门(重置、更新)+ 无独立细胞状态。②GRU将遗忘门和输入门合并为更新门,用"1-更新门"控制遗忘。③GRU参数量约为LSTM的3/4。④实践上两者性能相近,GRU在小数据上可能更好(参数少),LSTM在大数据上可能略优。选择哪个通常取决于具体任务和调参。

Q3: 遗忘门初始化为1有什么好处?

在LSTM的原始论文中,遗忘门的偏置b_f初始化为0(Sigmoid(0)=0.5)。后来研究发现,将b_f初始化为较大的正值(如1.0,使Sigmoid(1)≈0.73,甚至更大值)可以显著改善长序列上的训练效果。原因:训练初期,网络倾向于"先学会记住",再逐渐学会"选择性遗忘"。如果初始遗忘门就倾向于遗忘,早期训练中梯度难以传播,长期依赖的学习变得困难。

章节单词汇总

英文音标术语/释义
LSTM/ɛl ɛs tiː ɛm/长短期记忆网络
Cell State/sel steɪt/细胞状态,LSTM的长期记忆通道
Forget Gate/fərˈɡet ɡeɪt/遗忘门,控制丢弃旧信息
Input Gate/ˈɪnpʊt ɡeɪt/输入门,控制写入新信息
Output Gate/ˈaʊtpʊt ɡeɪt/输出门,控制暴露信息
Hadamard Product/hædəˈmɑːrd ˈprɒdʌkt/哈达玛积,逐元素乘法
Gating Mechanism/ˈɡeɪtɪŋ ˈmekənɪzəm/门控机制,Sigmoid门控制信息流
Candidate Value/ˈkændɪdət ˈvæljuː/候选值,C̃_t,新信息的候选内容

面试练习

Q1 [单选] LSTM中负责"选择性遗忘"的是哪个门?

  • A. 遗忘门
  • B. 输入门
  • C. 输出门
  • D. 更新门
解答:遗忘门f_t控制从细胞状态C_{t-1}中丢弃多少旧信息。f_t接近0时遗忘,接近1时保留。

Q2 [单选] LSTM的细胞状态更新公式是什么?

  • A. C_t = tanh(W·C_{t-1} + U·x_t)
  • B. C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t
  • C. C_t = C_{t-1} + x_t
  • D. C_t = o_t ⊙ C_{t-1}
解答:C_t = f_t⊙C_{t-1} + i_t⊙C̃_t。遗忘门擦除旧信息,输入门写入新信息。

Q3 [单选] LSTM中,梯度可以通过哪个路径几乎无损传播?

  • A. 细胞状态 C_t 的路径
  • B. 隐藏状态 h_t 的路径
  • C. 输入门 i_t 的路径
  • D. 输出门 o_t 的路径
解答:C_t = f_t⊙C_{t-1} + ...,∂C_t/∂C_{t-1} = f_t(逐元素),没有矩阵乘法,梯度可无损传播。

Q4 [多选] 以下关于LSTM门控的描述,哪些正确?

  • A. 遗忘门使用Sigmoid激活,输出在(0,1)之间
  • B. 输入门控制新信息写入细胞状态的程度
  • C. 输出门控制从细胞状态中暴露多少信息
  • D. 三个门共享同一组权重参数
解答:每个门有独立的权重矩阵(W_f, W_i, W_c, W_o),不共享。A、B、C正确。

Q5 [单选] LSTM中的"⊙"符号表示什么运算?

  • A. 矩阵乘法
  • B. 逐元素乘法(Hadamard积)
  • C. 向量加法
  • D. 向量拼接
解答:⊙表示逐元素乘法(Hadamard Product),两个等长向量对应位置相乘。

Q6 [单选] 相比标准RNN,LSTM的参数量大约是多少?

  • A. 相同
  • B. 2倍
  • C. 4倍
  • D. 10倍
解答:LSTM有4组权重矩阵(遗忘门、输入门、候选值、输出门),标准RNN只有1组,所以参数量约为4倍。

Q7 [单选] LSTM的遗忘门初始偏置b_f通常设为?

  • A. 0
  • B. 较大的正值(如1.0)
  • C. 较大的负值
  • D. 随机值
解答:将b_f初始化为较大的正值(如1.0),使Sigmoid输出接近1,训练初期倾向于"记住",有利于学习长期依赖。

Q8 [多选] LSTM比标准RNN的优势包括哪些?

  • A. 更好地处理长期依赖
  • B. 通过门控机制控制信息流
  • C. 细胞状态提供梯度传播的"高速公路"
  • D. 计算速度更快
解答:LSTM计算更慢(4倍参数),但长期依赖建模能力远超RNN。A、B、C正确。

Q9 [单选] 在LSTM中,候选值C̃_t使用什么激活函数?

  • A. tanh
  • B. Sigmoid
  • C. ReLU
  • D. Softmax
解答:t = tanh(W_c·[h{t-1}, x_t] + b_c),使用tanh生成(-1,1)的候选值,与输入门i_t(Sigmoid)配合控制写入。

Q10 [单选] LSTM中隐藏状态h_t和细胞状态C_t的关系是什么?

  • A. h_t = o_t ⊙ tanh(C_t)
  • B. h_t = C_t
  • C. h_t = o_t + C_t
  • D. h_t = tanh(C_t)(无输出门)
解答:h_t = o_t ⊙ tanh(C_t)。输出门o_t控制从细胞状态中"暴露"多少信息到隐藏状态。h_t是对外输出,C_t是内部记忆。