RNN的梯度问题与解决方案

一句话概述

RNN虽然在理论上能建模任意长度的序列依赖,但BPTT(时间反向传播)中Jacobian矩阵的连乘会导致梯度在时间维度上呈指数衰减(梯度消失)或指数增长(梯度爆炸),使标准RNN的实际有效记忆长度不足20-30步。解决这一核心问题需要从三个层面入手:①工程层面——梯度裁剪直接限制梯度范数,防止爆炸;②架构层面——LSTM/GRU通过门控机制创建梯度传播的"高速公路";③训练层面——权重初始化、层归一化、截断BPTT和残差连接等技巧从旁辅助。

💡 核心要点:①BPTT中梯度 = Jacobian矩阵的连乘,当特征值<1时指数衰减→梯度消失,特征值>1时指数增长→梯度爆炸 ②梯度裁剪是按比例缩放梯度范数,不改变方向,是防爆炸的基础工程手段 ③LSTM的细胞状态通过加法连接(而非乘法连接)实现了梯度的跨时间步直通 ④GRU通过更新门z_t的(1-z_t)项同样提供梯度高速公路 ⑤Xavier初始化、层归一化、截断BPTT从训练技巧层面辅助稳定训练

教学与演示

一、BPTT与梯度问题的数学根源

是什么(定义):BPTT(Backpropagation Through Time)是RNN的反向传播算法。它将RNN在时间维度上展开为T层的"虚拟前馈网络"(每个时间步对应一层,这些层共享同一组权重参数),然后应用链式法则从最后一个时间步逐层反向传播梯度。由于参数在所有时间步共享,最终每个参数的梯度等于所有时间步上该参数梯度的累加和。BPTT是理解RNN梯度问题的理论起点。

为什么(原理):在标准的BPTT中,损失L对早期时间步隐藏状态h₁的梯度需要经过T-1次链式法则连乘。每一次连乘都涉及Jacobian矩阵∂hₜ/∂h_{t-1},其谱半径(最大特征值的绝对值)决定了梯度的命运。如果谱半径<1,连乘T次后梯度以指数速度(λᵀ)趋近于零——这就是梯度消失;如果谱半径>1,梯度以指数速度发散——这就是梯度爆炸。实际中,激活函数tanh的导数在饱和区接近0,加上初始化不当时权重矩阵Wₕₕ的谱半径通常<1,梯度消失远比梯度爆炸更常见。

怎么做(实现)

import numpy as np

# ========================================
# BPTT 梯度传播模拟 —— 直观演示梯度消失/爆炸
# 梯度 = Jacobian矩阵在时间维度上的连乘
# ========================================

def demo_bptt_gradient(T=50, eigenvalue=0.9):
    """
    模拟BPTT中梯度在T个时间步中的变化
    参数:
        T: 序列长度/时间步数
        eigenvalue: Jacobian矩阵 ∂h_t/∂h_{t-1} 的主特征值
    返回:
        gradient_magnitude: T步后的梯度规模
    """
    # BPTT核心: 梯度 ≈ (Jacobian特征值)^T
    # ∂L/∂h_1 = ∂L/∂h_T · ∏_{t=2}^{T} ∂h_t/∂h_{t-1}
    # 连乘结果近似于 λ^T
    gradient_magnitude = np.power(eigenvalue, T)
    return gradient_magnitude


def analyze_gradient_behavior():
    """
    分析不同特征值和序列长度下的梯度行为
    """
    print("BPTT 梯度传播分析:")
    print("=" * 60)

    # 案例1: 特征值0.9 —— 梯度消失
    cases = [
        (0.9, "梯度消失: 特征值<1,连乘后指数衰减"),
        (1.0, "梯度稳定: 特征值=1,梯度无损传递"),
        (1.1, "梯度爆炸: 特征值>1,连乘后指数增长"),
    ]

    seq_lengths = [10, 20, 50, 100]
    for eigenvalue, desc in cases:
        print(f"\n{desc}")
        print(f"  特征值 λ = {eigenvalue}")
        for T in seq_lengths:
            mag = demo_bptt_gradient(T, eigenvalue)
            print(f"    序列长度 T={T}: 梯度规模 = {mag:.2e}")
            if mag < 1e-10:
                print(f"      → 梯度几乎为零,早期输入对参数更新无影响")
            elif mag > 1e6:
                print(f"      → 梯度爆炸,参数更新剧烈震荡甚至NaN")


def compute_jacobian_trace(W_hh, hidden_size=8):
    """
    计算∂h_t/∂h_{t-1} Jacobian矩阵的谱半径
    这是BPTT梯度行为的决定性因素
    参数:
        W_hh: 隐藏到隐藏的循环权重矩阵
        hidden_size: 隐藏状态维度
    返回:
        spectral_radius: 谱半径(最大特征值的绝对值)
    """
    np.random.seed(42)
    # 模拟W_hh
    if W_hh is None:
        W_hh = np.random.randn(hidden_size, hidden_size) * 0.1

    # 计算特征值
    eigenvalues = np.linalg.eigvals(W_hh)
    spectral_radius = np.max(np.abs(eigenvalues))

    print(f"\n\nJacobian谱半径分析:")
    print("=" * 60)
    print(f"  W_hh 谱半径: {spectral_radius:.4f}")
    print(f"  所有特征值: {[f'{ev:.4f}' for ev in eigenvalues]}")
    print(f"\n  如果谱半径 < 1: 梯度沿时间步指数衰减 → 梯度消失")
    print(f"  如果谱半径 > 1: 梯度沿时间步指数增长 → 梯度爆炸")
    print(f"  如果谱半径 = 1: 梯度稳定传递(理想情况)")

    return spectral_radius


# 运行演示
analyze_gradient_behavior()
print("\n")
compute_jacobian_trace(None)
时间步连乘导致梯度问题\(\frac{\partial L}{\partial \mathbf{h}_1} = \frac{\partial L}{\partial \mathbf{h}_T} \cdot \prod_{t=2}^{T} \frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_{t-1}}\)
Jacobian矩阵的具体形式\(\frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_{t-1}} = \text{diag}\left(1 - \tanh^2(\mathbf{z}_t)\right) \cdot \mathbf{W}_{hh}^\top\)
大白话 BPTT就像是传话游戏——每经过一个时间步,信息就被乘以一个因子。如果因子<1,传到开头信息几乎没了(梯度消失);如果因子>1,传到开头信息炸了(梯度爆炸)。更糟糕的是,tanh激活函数在输入很大或很小时导数接近0,这个"信息打折扣"是双重的——W_hh的特征值和tanh的导数都在削弱梯度。

什么用(AI关联):理解BPTT和梯度问题的数学根源是设计LSTM/GRU的理论基础。LSTM和GRU的所有门控设计归根结底都是为了打破"连乘导致指数衰减"这个因果链。①LSTM的细胞状态使用加法连接而非乘法连接,使∂cₜ/∂c_{t-1} ≈ f_t(遗忘门),而非tanh'·W_hh^T这种复杂的连乘。②GRU的更新门通过h_t = (1-z_t)⊙h_{t-1} + z_t⊙h̃t,使得∂hₜ/∂h{t-1}中也包含了可控的(1-z_t)直通项。理解BPTT的连乘本质,才能真正理解门控单元"为什么不直接乘,而要用加法和门控"。

哪些坑(缺点):①简单RNN的最大有效序列长度约为20-30步——超过这个范围,梯度消失使得早期时间步的信息几乎无法影响参数更新。②"梯度消失"不是说梯度真的变成精确的零(机器精度限制),而是说其量级衰减到对学习率来说几乎不起作用。③梯度消失和爆炸可以在同一网络中交替发生——某些参数爆炸而另一些消失。④截断BPTT(限制反向传播的时间步数)可以减轻计算负担,但同时也牺牲了长距离依赖的学习能力。

二、梯度裁剪——梯度爆炸的第一道防线

是什么(定义):梯度裁剪(Gradient Clipping)是一种简单但极其有效的工程技术:在每次参数更新前,先计算所有参数梯度的全局L2范数||g||。如果||g||超过预设的阈值C,则按比例将所有梯度缩放:g' = (C/||g||)·g。这样梯度方向不变,但范数被限制在C以内。梯度裁剪是处理梯度爆炸最直接、部署成本最低的方案,几乎所有RNN训练(包括LSTM/GRU)都会默认启用。

为什么(原理):梯度爆炸发生时,单个参数更新步长过大,可能"跳过"最优解甚至导致权重发散为NaN。梯度裁剪相当于给出一个"最大步长限制"——无论BPTT算出多大的梯度,参数每步的更新量都不会超过某个安全阈值。从优化几何的角度看,梯度裁剪在损失曲面上限定了优化器每一步的"跳跃距离",防止了在陡峭区域的大幅震荡。值得注意的是,梯度裁剪不解决梯度消失问题——它只限制上限,不管下限。

怎么做(实现)

import numpy as np

# ========================================
# 梯度裁剪 —— 限制梯度范数,防止爆炸
# 不改变梯度方向,只限制其长度
# ========================================

def gradient_clip(gradients, max_norm=5.0):
    """
    全局梯度裁剪(L2范数)
    参数:
        gradients: 梯度列表,每个元素是一个numpy数组
        max_norm: 裁剪阈值,梯度范数的上限
    返回:
        clipped_gradients: 裁剪后的梯度列表
        total_norm: 原始梯度的总L2范数
    """
    # 计算所有参数梯度的全局L2范数
    # total_norm = sqrt(Σ_i ||g_i||^2)
    total_norm = np.sqrt(
        sum(np.sum(g ** 2) for g in gradients)
    )

    # 如果梯度范数超过阈值,等比例缩放所有梯度
    if total_norm > max_norm:
        # scale = C / ||g||,确保裁剪后范数 = C
        scale = max_norm / total_norm
        clipped_gradients = [g * scale for g in gradients]
        was_clipped = True
    else:
        clipped_gradients = gradients  # 不超限则不裁剪
        was_clipped = False

    return clipped_gradients, total_norm, was_clipped


def demo_gradient_clipping():
    """
    演示梯度裁剪的效果
    对比: 无裁剪 vs 有裁剪,观察梯度范数的变化
    """
    np.random.seed(42)
    print("梯度裁剪效果演示:")
    print("=" * 60)

    # 模拟10个参数的梯度(如一层RNN的W_xh, W_hh, b_h等)
    # 模拟正常梯度和爆炸梯度两种场景
    scenarios = {
        "正常梯度": [
            np.random.randn(4, 3) * 0.5,    # W_xh: 4×3
            np.random.randn(4, 4) * 0.5,    # W_hh: 4×4
            np.random.randn(4) * 0.5,        # b_h: 4
        ],
        "梯度爆炸": [
            np.random.randn(4, 3) * 50.0,   # W_xh 梯度异常大
            np.random.randn(4, 4) * 50.0,   # W_hh 梯度异常大
            np.random.randn(4) * 50.0,       # b_h 梯度异常大
        ],
    }

    thresholds = [1.0, 5.0, 10.0]  # 不同裁剪阈值

    for name, grads in scenarios.items():
        print(f"\n场景: {name}")
        # 计算原始总范数
        raw_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
        print(f"  原始梯度总范数: {raw_norm:.2f}")

        for C in thresholds:
            clipped, new_norm, was = gradient_clip(grads, max_norm=C)
            # 裁剪后实际范数
            clipped_norm = np.sqrt(
                sum(np.sum(g ** 2) for g in clipped)
            )
            action = "已裁剪" if was else "未超限"
            print(f"    阈值C={C}: {action}, 裁剪后范数={clipped_norm:.2f}")
            if was:
                # 验证方向不变:裁剪前后每个梯度只差一个缩放因子
                scale = C / raw_norm
                print(f"      缩放因子: {scale:.4f} (所有梯度同比例缩放,方向不变)")


def per_param_clipping(gradients, max_norm=5.0):
    """
    逐参数梯度裁剪(备选方案)
    对每个参数的梯度单独裁剪,而非全局范数
    参数:
        gradients: 梯度列表
        max_norm: 每个参数梯度的范数上限
    返回:
        clipped_gradients: 裁剪后的梯度
    """
    clipped = []
    for g in gradients:
        g_norm = np.sqrt(np.sum(g ** 2))
        if g_norm > max_norm:
            g = g * (max_norm / g_norm)  # 单独缩放
        clipped.append(g)
    return clipped


demo_gradient_clipping()

print(f"\n\n梯度裁剪的关键属性:")
print(f"  1. 不改变梯度方向,只限制范数长度")
print(f"  2. 所有参数梯度等比例缩放")
print(f"  3. 常见阈值: C = 1.0~10.0(任务和网络深度相关)")
print(f"  4. LSTM训练中,默认阈值通常为5.0")
梯度裁剪公式\(\text{如果 } \|\mathbf{g}\| > C, \text{ 则 } \mathbf{g}' = \frac{C}{\|\mathbf{g}\|} \cdot \mathbf{g}\)
大白话 梯度裁剪就是给优化器装上"限速器"——不管BPTT算出来的梯度有多大,实际更新时最多只能走C这么远。好比开车在盘山公路上,不管油门踩多深,车速都被限制在安全范围内。方向还是那个方向(往低处走),但步子不会太大,不会冲下悬崖。

什么用(AI关联):①梯度裁剪是几乎所有RNN(包括LSTM/GRU)训练的标配——无论序列多长、谱半径多大,训练都不会因梯度爆炸而发散。②在Transformer训练中也广泛使用(通常与warmup学习率调度配合)。③在一些深度CNN的训练中也有应用,尤其是深层网络训练的早期阶段。④梯度裁剪配合适当的初始化几乎是RNN训练的最低配置——缺了它,长序列训练很容易出现NaN。

哪些坑(缺点):①梯度裁剪只解决爆炸不解决消失——对梯度消失无能为力。②阈值C设太小会过度限制有效学习信号,等价于使用过小的学习率,训练收敛变慢。③阈值C设太大则失去保护作用——需要在"安全"和"高效"之间权衡。④全局L2裁剪对所有参数一视同仁,但在深度网络中不同层的梯度规模可能差异很大(底层层梯度通常更小),可以结合层自适应裁剪策略。⑤梯度裁剪与动量优化器(如Adam)共用时需要额外注意——动量累积可能使有效更新步长仍超过裁剪阈值。

三、LSTM/GRU——从架构层面解决梯度消失

是什么(定义):LSTM(长短期记忆网络)和GRU(门控循环单元)通过引入门控机制从根本上改变了信息在时间步之间的流动方式——核心思想是将RNN中的"乘法连接"(h_t = tanh(W_hh·h_{t-1} + ...))改为"加法连接+门控选择"。以LSTM为例,细胞状态C_t通过线性加法(而非非线性变换)在时间步间传递:C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃t。这使得 ∂C_t/∂C{t-1} = diag(f_t),而不是 diag(tanh')·W_hh^T——梯度不再经历矩阵连乘的非线性衰减。

为什么(原理):LSTM解决梯度消失的核心在于"恒等映射通道"(Identity Highway)。当遗忘门f_t接近1且输入门i_t接近0时,C_t ≈ C_{t-1}——信息无损地从t-1传递到t,梯度也因此无损地反向传播。即使f_t不完全为1,只要f_t的值不太小(如>0.5),梯度衰减的速度也比RNN中的指数衰减慢得多——因为衰减因子每步最多为f_t而非tanh'·谱半径的乘积。GRU通过类似机制实现:h_t = (1-z_t)⊙h_{t-1} + z_t⊙h̃_t,其中(1-z_t)项提供梯度直通通道。两者都利用了门控网络(Sigmoid输出)来学习"哪些信息应该跨时间步无障碍传递"。

怎么做(实现)

import numpy as np

# ========================================
# LSTM/GRU 梯度高速公路 —— 从架构上解决梯度消失
# LSTM: C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t   → 加法连接
# GRU:  h_t = (1-z_t)⊙h_{t-1} + z_t⊙h̃_t   → 对偶门控
# ========================================

class LSTMCell:
    """
    LSTM单元 —— 通过细胞状态的加法连接解决梯度消失
    三个门: 遗忘门 f_t, 输入门 i_t, 输出门 o_t
    独立细胞状态 C_t 作为梯度高速公路
    """
    def __init__(self, input_size, hidden_size):
        # 四个权重矩阵(三个门 + 候选细胞状态)
        combined_size = hidden_size + input_size
        self.W_f = np.random.randn(hidden_size, combined_size) * 0.01
        self.b_f = np.zeros(hidden_size) + 1.0  # 遗忘门偏置初始化为1,促进梯度传递

        self.W_i = np.random.randn(hidden_size, combined_size) * 0.01
        self.b_i = np.zeros(hidden_size)

        self.W_o = np.random.randn(hidden_size, combined_size) * 0.01
        self.b_o = np.zeros(hidden_size)

        self.W_c = np.random.randn(hidden_size, combined_size) * 0.01
        self.b_c = np.zeros(hidden_size)

    def _sigmoid(self, x):
        return 1.0 / (1.0 + np.exp(-np.clip(x, -50, 50)))

    def forward(self, x_t, h_prev, C_prev):
        """
        LSTM单步前向传播
        返回:
            h_t: 隐藏状态
            C_t: 细胞状态(梯度高速公路的载体)
            gates: 门控值和梯度的分析信息
        """
        concat = np.concatenate([h_prev, x_t])

        # 遗忘门: 决定遗忘多少旧细胞状态
        f_t = self._sigmoid(np.dot(self.W_f, concat) + self.b_f)
        # 输入门: 决定写入多少新信息
        i_t = self._sigmoid(np.dot(self.W_i, concat) + self.b_i)
        # 输出门: 决定暴露多少细胞状态
        o_t = self._sigmoid(np.dot(self.W_o, concat) + self.b_o)
        # 候选细胞状态: 新信息的候选内容
        C_tilde = np.tanh(np.dot(self.W_c, concat) + self.b_c)

        # ---- 核心: 加法连接 ----
        # C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t
        # ∂C_t/∂C_{t-1} = diag(f_t) —— 梯度直通!不经过W_hh矩阵乘法
        C_t = f_t * C_prev + i_t * C_tilde

        # 隐藏状态: o_t ⊙ tanh(C_t)
        h_t = o_t * np.tanh(C_t)

        gates = {
            'f_t': f_t, 'i_t': i_t, 'o_t': o_t,
            'C_tilde': C_tilde, 'C_t': C_t
        }
        return h_t, C_t, gates


def demo_lstm_gradient_highway():
    """
    展示LSTM中梯度高速公路的效果
    对比RNN和LSTM在相同序列长度下的梯度衰减
    """
    print("LSTM 梯度高速公路演示:")
    print("=" * 60)

    np.random.seed(42)
    input_size, hidden_size = 3, 4
    lstm = LSTMCell(input_size, hidden_size)

    # 模拟长序列处理
    seq = np.random.randn(30, input_size)  # 30个时间步
    h = np.zeros(hidden_size)
    C = np.zeros(hidden_size)

    f_values = []

    for t in range(30):
        h, C, gates = lstm.forward(seq[t], h, C)
        # 记录每个时间步遗忘门的平均值(梯度衰减因子)
        f_avg = np.mean(gates['f_t'])
        f_values.append(f_avg)

    print(f"遗忘门 f_t 在每个时间步的平均值:")
    print(f"  前5步:  {[f'{v:.3f}' for v in f_values[:5]]}")
    print(f"  后5步:  {[f'{v:.3f}' for v in f_values[-5:]]}")
    print(f"  30步内f_t的均值: {np.mean(f_values):.3f}")

    # 对比: RNN的梯度衰减 vs LSTM的梯度衰减
    print(f"\n梯度衰减对比(30个时间步后):")
    # RNN: λ^30, 假设λ≈0.8(tanh导数 × W_hh谱半径)
    rnn_decay = 0.8 ** 30
    # LSTM: 平均f_t ≈ 0.7,衰减因子 ≈ f_t_avg^30
    lstm_decay = np.mean(f_values) ** 30
    # LSTM最佳情况: f_t ≈ 0.95(遗忘门偏置=1的初始状态)
    lstm_best = 0.95 ** 30

    print(f"  简单RNN  (λ=0.8):    衰减到 {rnn_decay:.2e}")
    print(f"  LSTM     (f_t均值):    衰减到 {lstm_decay:.2e}")
    print(f"  LSTM最佳 (f_t≈0.95):  衰减到 {lstm_best:.4f}")
    print(f"  → LSTM通过门控机制将梯度衰减速度从 λ^T 降为 f_t^T")
    print(f"  → f_t 由网络学习控制,可动态调整衰减率")

    # 梯度高速公路可视化
    print(f"\nLSTM细胞状态的梯度高速公路:")
    print(f"  C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t")
    print(f"  ∂C_t/∂C_{t-1} = diag(f_t)  ← 只有f_t的对角矩阵,不涉及W_hh")
    print(f"  如果 f_t ≈ 1, i_t ≈ 0 → C_t ≈ C_{t-1} → 梯度无损传递!")
    print(f"  如果 f_t = 0 → C_t = i_t ⊙ C̃_t → 重置记忆,接受全新信息")
    print(f"  → f_t 为1时是'高速公路',f_t 接近0时是'重置出口'")


def demo_gru_gradient_highway():
    """
    GRU中梯度高速公路的分析
    """
    print(f"\n\nGRU 梯度高速公路分析:")
    print("=" * 60)

    # GRU的核心: h_t = (1-z_t)⊙h_{t-1} + z_t⊙h̃_t
    print(f"  GRU: h_t = (1-z_t) ⊙ h_{t-1} + z_t ⊙ h̃_t")
    print(f"  ∂h_t/∂h_{t-1} ≈ (1-z_t) + z_t · ∂h̃_t/∂h_{t-1}")
    print(f"\n  当 z_t ≈ 0 时:")
    print(f"    h_t ≈ h_{t-1}  → 信息无损传递")
    print(f"    ∂h_t/∂h_{t-1} ≈ I  → 梯度无损反向传播")
    print(f"  当 z_t ≈ 1 时:")
    print(f"    h_t ≈ h̃_t  → 写入全新信息")
    print(f"    ∂h_t/∂h_{t-1} ≈ ∂h̃_t/∂h_{t-1}  → 梯度经由候选状态传递")


demo_lstm_gradient_highway()
demo_gru_gradient_highway()
LSTM细胞状态更新\(\mathbf{C}_t = \mathbf{f}_t \odot \mathbf{C}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{C}}_t\)
LSTM遗忘门初始化技巧\(\mathbf{b}_f^{\text{init}} = 1.0, \quad \mathbf{f}_t = \sigma(\mathbf{W}_f \cdot [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f)\)
大白话 普通RNN每个时间步都在传话,信息经过几十次tanh和矩阵乘法后早就面目全非了。LSTM给信息留了一条"高速路"——细胞状态C_t大部分时间是直通的(加法),只在需要时才加上新信息或擦除旧信息。这就像传话游戏中给每个玩家一本原始手稿——他们可以在上面添加笔记(输入门),划掉不需要的内容(遗忘门),但原始信息不会因为反复转述而失真。梯度也因此知道"沿着这条路能找到最早的那些信息",不会消失在半路。

什么用(AI关联):①LSTM可以将有效序列记忆从RNN的20-30步扩展到100-200步甚至更长,是NLP深度学习的基石性架构。②GRU结构更简洁,参数量约为LSTM的75%,训练速度更快,在大部分任务上与LSTM效果相当。③两者都是Seq2Seq(编码器-解码器)架构的标准组件,广泛应用于机器翻译、文本摘要、语音识别等任务。④现代实践中,Transformer已取代LSTM/GRU成为NLP的主流架构,但LSTM/GRU在序列较短、数据量较小、计算资源受限的场景下仍然是实用选择。

哪些坑(缺点):①LSTM参数多(4个门,每个门一组权重),在小数据集上容易过拟合,训练速度也比GRU慢。②即使有梯度高速公路,LSTM/GRU仍无法处理任意长度的序列——太长的序列(>500步)梯度仍会逐渐衰减,且计算成本线性增长。③LSTM/GRU仍然需要串行计算(每一步依赖前一步),无法像Transformer一样并行化训练。④遗忘门偏置初始化不当(如默认初始化为0)会导致训练极慢——因为网络一开始就倾向于"遗忘",梯度传播受阻。⑤门控机制本身也会饱和——Sigmoid输出接近0或1时,门控的梯度也接近零,导致门控参数更新困难。

四、训练技巧层面的解决方案

是什么(定义):除了梯度裁剪(工程层面)和LSTM/GRU(架构层面),还有一系列训练技巧可以从旁辅助稳定RNN的训练:①权重初始化策略——Xavier(Glorot)初始化或正交初始化,确保前向和反向信号的方差在网络各层保持稳定。②层归一化(Layer Normalization)——在特征维度上对每层的激活值进行标准化,缓解内部协变量偏移(Internal Covariate Shift),稳定训练过程中的梯度流动。③截断BPTT——将BPTT的反向传播限制在固定的K步内,既减少计算和内存开支,也避免了梯度在过长链路上的极端衰减或增长。④残差连接——在深层RNN的层间引入跳跃连接(Skip Connection),为梯度提供跨层的直通路径。

为什么(原理):这四类技巧分别针对梯度问题的不同侧面:Xavier初始化让网络从"梯度友好"的起点开始训练(信号不被初始权重放大或缩小);层归一化约束每层的激活值分布,防止隐藏状态在不同时间步发散;截断BPTT直接限制连乘的步数,是一种"截断链式衰减"的简单粗暴方法;残差连接借鉴ResNet的思想,将h^l = RNN(h^{l-1})改为h^l = RNN(h^{l-1}) + h^{l-1},使得∂h^l/∂h^{l-1}中包含了恒等映射的直通项——梯度不会在层间完全消失。

怎么做(实现)

import numpy as np

# ========================================
# 训练技巧1: Xavier/Glorot 权重初始化
# 保持前向和反向传播中信号的方差稳定
# ========================================

def xavier_init(shape, fan_in, fan_out, activation='tanh'):
    """
    Xavier/Glorot 初始化
    参数:
        shape: 权重矩阵的形状 (out_dim, in_dim)
        fan_in: 输入维度
        fan_out: 输出维度
        activation: 激活函数类型,影响缩放因子
    返回:
        初始化后的权重矩阵
    公式: W ~ Uniform(-limit, limit), limit = sqrt(6/(fan_in + fan_out))
    """
    if activation == 'tanh' or activation == 'sigmoid':
        # Xavier 归一化: 适用于tanh/sigmoid
        limit = np.sqrt(6.0 / (fan_in + fan_out))
    elif activation == 'relu':
        # He 初始化: 适用于ReLU系列激活
        limit = np.sqrt(6.0 / fan_in)
    else:
        limit = np.sqrt(6.0 / (fan_in + fan_out))

    # 在 [-limit, limit] 范围内均匀采样
    W = np.random.uniform(-limit, limit, shape)
    return W


def orthogonal_init(shape):
    """
    正交初始化 —— 适用于RNN的W_hh权重
    正交矩阵的特征值绝对值全为1,谱半径=1
    这使得梯度在时间维度上理论上不会指数衰减
    参数:
        shape: 权重矩阵的形状 (hidden_size, hidden_size)
    返回:
        初始化后的正交矩阵
    """
    # 先生成随机高斯矩阵
    W = np.random.randn(*shape)
    # QR分解 → Q是正交矩阵
    Q, R = np.linalg.qr(W)
    # 确保对角线为正(避免符号翻转的不确定性)
    d = np.diag(np.sign(np.diag(R)))
    return Q @ d


# --- 演示 ---
print("权重初始化策略对比:")
print("=" * 60)

hidden_size = 64
shape = (hidden_size, hidden_size)

# Xavier初始化
W_xavier = xavier_init(shape, hidden_size, hidden_size, 'tanh')
# 正交初始化
W_orth = orthogonal_init(shape)
# 不好的初始化: 太小
W_small = np.random.randn(*shape) * 0.001
# 不好的初始化: 太大
W_large = np.random.randn(*shape) * 100.0

# 分析谱半径(决定BPTT梯度衰减的关键)
for name, W in [("Xavier", W_xavier), ("正交", W_orth),
                ("太小(×0.001)", W_small), ("太大(×100)", W_large)]:
    ev = np.linalg.eigvals(W)
    sr = np.max(np.abs(ev))
    print(f"  {name}: 谱半径 = {sr:.4f}")
    if sr < 0.5:
        print(f"    → 梯度会沿时间步快速消失")
    elif sr > 1.5:
        print(f"    → 梯度可能爆炸")
    else:
        print(f"    → 梯度传播较为稳定")

# 正交初始化的特殊性质
print(f"\n正交初始化验证:")
print(f"  特征值全为1? {np.allclose(np.abs(np.linalg.eigvals(W_orth)), 1.0, atol=1e-6)}")
print(f"  → 正交矩阵的列是单位正交的,谱半径精确为1")
print(f"  → 对于简单RNN的W_hh,正交初始化是最理想的起点")
import numpy as np

# ========================================
# 训练技巧2: 层归一化 (Layer Normalization)
# 在特征维度标准化,稳定隐藏状态的分布
# ========================================

def layer_norm(x, gamma=None, beta=None, eps=1e-5):
    """
    层归一化实现
    对每个样本的特征维度(最后一维)进行标准化
    参数:
        x: 输入,shape (..., features)
        gamma: 缩放参数(可学习),shape (features,)
        beta: 偏移参数(可学习),shape (features,)
        eps: 防止除零的小常数
    返回:
        normalized: 标准化后的输出
    公式: y = gamma * (x - mean) / sqrt(var + eps) + beta
    """
    # 沿特征维度(最后一维)计算均值和方差
    mean = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)

    # 标准化: (x - μ) / √(σ² + ε)
    x_norm = (x - mean) / np.sqrt(var + eps)

    # 仿射变换(可学习的缩放和偏移)
    if gamma is None:
        gamma = np.ones(x.shape[-1])  # 默认不缩放
    if beta is None:
        beta = np.zeros(x.shape[-1])  # 默认不偏移

    normalized = gamma * x_norm + beta
    return normalized, mean, var


def demo_layer_norm_effect():
    """
    展示层归一化对RNN隐藏状态的影响
    模拟: 有层归一化 vs 无层归一化,隐藏状态在时间步中的变化
    """
    np.random.seed(42)
    print("层归一化效果演示:")
    print("=" * 60)

    hidden_size = 8
    seq_len = 50

    # 模拟隐藏状态在50个时间步中的演变
    h = np.random.randn(hidden_size) * 0.1  # 初始隐藏状态

    h_no_norm_history = []  # 无层归一化
    h_with_norm_history = []  # 有层归一化
    gamma = np.ones(hidden_size)
    beta = np.zeros(hidden_size)

    for t in range(seq_len):
        # -- 无层归一化 --
        # h_t = tanh(W_hh·h_{t-1}), 模拟W_hh为随机矩阵
        W_hh = np.random.randn(hidden_size, hidden_size) * 0.3
        h_no_norm = np.tanh(np.dot(W_hh, h))
        h_no_norm_history.append(np.linalg.norm(h_no_norm))

        # -- 有层归一化 --
        # 先做层归一化,再做tanh
        h_norm, mean, var = layer_norm(h, gamma, beta)
        h_with_norm = np.tanh(np.dot(W_hh, h_norm))
        h_with_norm_history.append(np.linalg.norm(h_with_norm))

        # 更新隐藏状态(两者独立演化)
        h = np.tanh(np.dot(W_hh, h))  # 基线:无层归一化

    # 统计对比
    print(f"隐藏状态L2范数在50个时间步中的统计:")
    print(f"  无层归一化:")
    print(f"    均值: {np.mean(h_no_norm_history):.4f}")
    print(f"    标准差: {np.std(h_no_norm_history):.4f}")
    print(f"    最小值: {np.min(h_no_norm_history):.6f}")
    print(f"    → 隐藏状态范数随时间剧烈波动,可能衰减至接近零")
    print(f"  有层归一化:")
    print(f"    均值: {np.mean(h_with_norm_history):.4f}")
    print(f"    标准差: {np.std(h_with_norm_history):.4f}")
    print(f"    最小值: {np.min(h_with_norm_history):.6f}")
    print(f"    → 隐藏状态范数更稳定,梯度传播条件更好")

    # 解释
    print(f"\n层归一化 = 批归一化的序列友好替代:")
    print(f"  批归一化: 沿 batch 维度标准化 → 依赖batch统计量,RNN序列长度不同时不稳定")
    print(f"  层归一化: 沿 feature 维度标准化 → 每个样本独立,不受batch和序列长度影响")
    print(f"  两者都通过标准化缓解内部协变量偏移,使梯度在不同时间步之间更均匀")


def truncated_bptt_demo(seq_len=100, truncation=20):
    """
    截断BPTT 演示
    将BPTT反向传播限制在truncation步内
    """
    np.random.seed(42)
    print(f"\n\n截断BPTT 演示:")
    print("=" * 60)
    print(f"  全BPTT: 反向传播 {seq_len} 步")
    print(f"  截断BPTT: 反向传播 {truncation} 步")
    print(f"  计算量减少: {(1 - truncation/seq_len)*100:.0f}%")
    print(f"  内存减少: 约 {(1 - truncation/seq_len)*100:.0f}%")
    print(f"\n  截断策略:")
    print(f"    将序列分成 {seq_len // truncation} 个块")
    print(f"    每块内做完整BPTT({truncation}步)")
    print(f"    块之间: 隐藏状态前向传递,但梯度不跨块反向传播")
    print(f"  → 牺牲超长距离依赖,换取计算效率和训练稳定性")

demo_layer_norm_effect()
truncated_bptt_demo()
Xavier初始化公式\(\mathbf{W} \sim \mathcal{U}\left(-\sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}},\ \sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}}\right)\)
层归一化\(\text{LayerNorm}(\mathbf{x}) = \gamma \cdot \frac{\mathbf{x} - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta, \quad \mu = \frac{1}{H}\sum_{i=1}^{H} x_i, \quad \sigma^2 = \frac{1}{H}\sum_{i=1}^{H} (x_i - \mu)^2\)
大白话 Xavier初始化 = 给网络"一个公平的起跑线"——起始时所有层的前向信号和反向梯度规模差不多,不会一开始就有的层信号巨大、有的层信号微茫。层归一化 = 给每个时间步的隐藏状态"整理仪容"——不管前面怎么折腾,每个隐藏状态进入下一时间步前都被标准化到均值0、方差1,防止某个时间步的极端值污染后续传播。正交初始化 = 让W_hh的特征值全为1——梯度在时间维度上理论上不增不减,是最理想的起点。

什么用(AI关联):①Xavier初始化是几乎所有深度学习框架的默认初始化策略(如PyTorch中Linear层的默认权重初始化就是Kaiming/Xavier的变体)。②层归一化不仅在RNN中广泛使用,还是Transformer架构的核心组件——Transformer中每个子层(自注意力、FFN)之后都有层归一化。③截断BPTT是训练长序列RNN(如语言模型在长文本上训练)的标准做法,与状态持久化(State Persistence)配合可以处理几乎无限长的序列。④残差连接 + 深层RNN + 层归一化的组合可以让RNN堆叠到更深(如4-6层),显著提升表达能力。

哪些坑(缺点):①Xavier初始化专门针对tanh/sigmoid设计——如果使用ReLU激活函数,必须切换到He(Kaiming)初始化,否则前向信号会逐层衰减。②层归一化虽然对序列长度无依赖,但在一些CNN架构中批归一化的效果更好——因为CNN中batch维度的统计信息比单样本的特征维度统计信息更稳定。③截断BPTT的截断长度K需要根据任务调整:K太小则长距离依赖完全丢失(语言模型的"远距离主语-谓语一致性"等任务会失败),K太大则计算和内存开销接近完整BPTT。④残差连接要求输入和输出维度相同——在RNN的层间维数变化时需要额外的线性投影(1×1卷积或全连接层)。

概念关系图谱

概念核心含义与AI的关系关联概念
BPTT沿时间展开的链式反向传播RNN训练的基础算法截断BPTT、Teacher Forcing
梯度消失早期时间步梯度趋近于零导致RNN无法学习长距离依赖梯度爆炸、LSTM、GRU
梯度爆炸梯度范数指数增长导致参数更新震荡甚至NaN梯度裁剪、权重正则化
梯度裁剪限制梯度范数上限RNN训练稳定性的基础保障权重裁剪、学习率调度
LSTM带三个门控和独立细胞状态的RNN变体克服梯度消失的经典架构GRU、门控机制、细胞状态
GRU带两个门控的简化LSTM参数更少、训练更快的替代方案LSTM、更新门、重置门
截断BPTT限制BPTT展开步数平衡计算效率和长距离依赖TBPTT、分块训练、状态持久化
层归一化沿特征维度标准化激活值稳定RNN/Transformer的训练批量归一化、RMSNorm、Pre-LN
Xavier初始化基于输入输出维度缩放的权重初始化防止前向/反向信号的初始衰减He初始化、正交初始化
正交初始化权重矩阵初始化为正交矩阵谱半径=1,梯度理论上不衰减谱半径、恒等矩阵、QR分解
残差连接跨层的跳跃连接(Skip Connection)缓解深层网络中的梯度消失ResNet、恒等映射、梯度高速公路
谱半径矩阵最大特征值的绝对值决定BPTT中梯度是消失还是爆炸特征值、稳定性分析

重点答疑

Q1: 为什么LSTM能缓解梯度消失而GRU也可以?

两者都通过改变信息在时间步之间的流动方式来解决梯度消失——核心是从RNN的"非线性乘法连接"改为"线性加法连接+门控选择"。

LSTM的细胞状态更新为 C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃t,这是一个加法连接(而非h_t = tanh(W·h{t-1})这样的非线性乘法)。对C_{t-1} 求偏导得到 ∂C_t/∂C_{t-1} = diag(f_t),不涉及W_hh矩阵的连乘,也不经过tanh的非线性压缩。即使f_t不完全为1,衰减也只取决于f_t的值(通常在0.5-0.95),而不是谱半径和tanh导数的乘积。

GRU通过类似的机制实现:h_t = (1-z_t) ⊙ h_{t-1} + z_t ⊙ h̃t。对h{t-1}求偏导得到 ∂h_t/∂h_{t-1} ≈ (1-z_t) + z_t·∂h̃t/∂h{t-1}。第一项(1-z_t)提供了直通通道——当z_t接近0时,梯度几乎无损传递。

两者的设计哲学一致:用门控网络(Sigmoid输出在0-1之间)学习"哪些信息可以无障碍传递",代替RNN中"所有信息都要经过非线性压缩再传递"的硬性约束。

Q2: 梯度裁剪设多少合适?设太小会怎样?

常见的梯度裁剪阈值范围是1.0到10.0(L2范数)。PyTorch中clip_grad_norm_的默认参数通常设为5.0左右。实际选择取决于任务和网络结构:

  • 深层RNN(>4层):建议更小的阈值(1.0-3.0),因为梯度在层间也会衰减/放大
  • 浅层RNN/LSTM(1-2层):5.0-10.0通常是安全的
  • Transformer训练:通常设为1.0(配合Adam优化器和warmup)
  • 初始训练阶段:可以先不设裁剪,观察梯度的L2范数在正常训练步中的范围,然后设为略高于正常水平的值

设太小的后果:①有效学习率降低——梯度被压缩后,参数更新步长变小,收敛速度减慢。②在一些需要"大步跳出"的情况(如损失面有狭窄沟壑)中可能导致收敛到次优解。③极端情况下(C<0.1),训练几乎停滞。④但临界情况是——"设太小"的危害通常远小于"不设裁剪导致的爆炸",所以实践中宁可偏小也不要偏大。

Q3: 截断BPTT和梯度裁剪有什么区别?

两者针对的是不同的问题,解决的机制也不同:

  • 梯度裁剪针对"梯度爆炸":限制的是梯度的范数上限。它在每次反向传播完成后、参数更新之前操作,只影响更新步长,不改变反向传播的过程。梯度裁剪不解决梯度消失——它只限制了上限。
  • 截断BPTT针对"计算和内存开销"以及间接缓解梯度消失/爆炸:限制的是反向传播在时间维度上的步数。它改变了反向传播本身——只反向传播最近的K个时间步,更早期的梯度直接截断为0。截断BPTT既能减少计算量,也能避免梯度在过长链路上的极端衰减,但同时牺牲了超长距离依赖的学习能力。

一个形象类比:梯度裁剪是"限速器"——不管路程多远,车速有上限。截断BPTT是"只走最后10公里"——更远的地方就不走了。两者可以(而且通常应该)同时使用。

Q4: 什么时候用LSTM而不是GRU?

经验法则如下:

  • 先试GRU:参数少、训练快、大多数任务效果相当。尤其推荐在小数据集(<10万样本)、移动端部署、需要快速迭代的实验阶段使用。
  • 换成LSTM的情况:①任务对精细记忆管理要求高(如需要精确记住序列中某个特定位置的数值)。②数据集足够大(>100万样本),LSTM更多的参数不会导致过拟合。③GRU实验后效果不满足要求——LSTM的独立细胞状态和输出门可能在特定任务中带来额外收益。④需要与已有LSTM模型兼容(很多经典的预训练模型和代码库基于LSTM)。
  • 学术传统:NLP社区中使用LSTM的论文远多于GRU,部分原因是LSTM历史更久(1997 vs 2014),研究积累更多。但在实际工程中,GRU经常是不逊于LSTM的选择。

Q5: 如何诊断RNN训练中出现了梯度消失还是爆炸?

梯度消失的典型信号:①训练损失在初期下降后迅速停滞,几乎不再变化。②早期时间步的参数梯度远小于后期时间步(可以打印梯度范数按时间步的分布)。③模型对所有输入生成几乎相同的输出(因为早期信息几乎没有被学到)。④tanh激活值的分布集中在饱和区(-1和1附近),导数接近零。

梯度爆炸的典型信号:①训练损失剧烈震荡或突然变为NaN。②参数更新步长异常大。③权重矩阵的范数在训练过程中持续增大。④模型输出在相邻迭代步之间差异极大。

实用的诊断方法:在每个训练步记录梯度的L2范数,画出其随时间步(epoch/iteration)的变化曲线。如果范数持续增长到1e3以上→梯度爆炸;如果范数持续衰减到1e-6以下→梯度消失;如果范数在合理范围内波动→训练健康。

章节单词汇总

英文音标术语/释义
Backpropagation Through Time/ˌbækˌprɑpəˈɡeɪʃən θruː taɪm/通过时间的反向传播,RNN训练算法
Gradient Vanishing/ˈɡreɪdiənt ˈvænɪʃɪŋ/梯度消失,梯度在反向传播中指数衰减
Gradient Exploding/ˈɡreɪdiənt ɪkˈsploʊdɪŋ/梯度爆炸,梯度在反向传播中指数增长
Gradient Clipping/ˈɡreɪdiənt ˈklɪpɪŋ/梯度裁剪,限制梯度范数上限的技术
Truncated BPTT/ˈtrʌŋkeɪtɪd/截断BPTT,限制反向传播时间步数
Jacobian Matrix/dʒəˈkoʊbiən ˈmeɪtrɪks/雅可比矩阵,即∂h_t/∂h_{t-1}
Layer Normalization/ˈleɪər ˌnɔrməlaɪˈzeɪʃən/层归一化,沿特征维度标准化
Xavier Initialization/ˈzeɪviər ɪˌnɪʃəlaɪˈzeɪʃən/Xavier初始化,基于维度的权重初始化
Orthogonal Initialization/ɔrˈθɑɡənəl/正交初始化,权重矩阵正交化
Spectral Radius/ˈspɛktrəl ˈreɪdiəs/谱半径,矩阵最大特征值的绝对值
State Space Model/steɪt speɪs ˈmɑdəl/状态空间模型
Residual Connection/rɪˈzɪdʒuəl kəˈnɛkʃən/残差连接,跨层的跳跃连接
Eigenvalue/ˈaɪɡənˌvælju/特征值,矩阵的标量特性
Norm Threshold/nɔrm ˈθrɛʃhoʊld/范数阈值,梯度裁剪的上限值
Forget Gate Bias/fərˈɡɛt ɡeɪt ˈbaɪəs/遗忘门偏置,初始化为正值以促进梯度传递
Internal Covariate Shift/ɪnˈtɜrnəl koʊˈvɛriət ʃɪft/内部协变量偏移,深度网络各层输入分布变化

面试练习

Q1 [单选] BPTT中梯度消失的根本原因是什么?

  • A. 神经网络层数太多
  • B. 激活函数选择不当
  • C. Jacobian矩阵的特征值小于1,时间步连乘导致指数衰减
  • D. 批大小设置不当
解答:BPTT中梯度需要在时间维度上连乘Jacobian矩阵∂h_t/∂h_{t-1}。当Jacobian的谱半径(最大特征值)<1时,连乘T次后以指数速度(λ^T)趋近于零,即梯度消失。

Q2 [单选] 梯度裁剪的直接作用是?

  • A. 加速训练收敛
  • B. 防止梯度爆炸,保持训练稳定
  • C. 消除梯度消失问题
  • D. 减少模型参数
解答:梯度裁剪通过限制梯度范数的上限,防止梯度爆炸导致的参数更新过大或NaN。不改变梯度方向,也不解决梯度消失问题。

Q3 [多选] 以下哪些方法可以缓解RNN的梯度消失问题?

  • A. 使用LSTM或GRU替代简单RNN
  • B. 使用截断BPTT
  • C. 使用层归一化(Layer Normalization)
  • D. 增加批大小
解答:LSTM/GRU通过门控机制从架构层面缓解;截断BPTT限制传播步数(间接缓解);层归一化稳定每层激活值分布。增加批大小不能直接解决梯度消失。

Q4 [单选] LSTM中哪一条是梯度传播的"高速公路"?

  • A. h_t = o_t ⊙ tanh(C_t)
  • B. C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t
  • C. f_t = σ(W_f·[h_{t-1}, x_t])
  • D. i_t = σ(W_i·[h_{t-1}, x_t])
解答:细胞状态C_t的更新是加法连接(而非乘法连接)。∂C_t/∂C_{t-1} = diag(f_t),不涉及W_hh矩阵的连乘,也不经过tanh压缩。当f_t≈1时梯度几乎无损传递。

Q5 [单选] 为什么LSTM的遗忘门偏置b_f通常初始化为正值(如1.0)?

  • A. 提高模型复杂度
  • B. 让网络在训练初期倾向于"保留信息",促进梯度传递
  • C. 节省内存
  • D. 它是LSTM的默认参数,没有特别原因
解答:b_f初始化为1.0时,f_t ≈ σ(1) ≈ 0.73,意味着训练初期网络倾向于保留过去的信息。这使得梯度在训练早期更容易传播,显著加快收敛。若b_f初始化为0,f_t ≈ σ(0) = 0.5,初期就会丢失一半信息。

Q6 [多选] 以下关于梯度裁剪的说法,哪些是正确的?

  • A. 梯度裁剪不改变梯度方向,只限制其范数长度
  • B. 常见裁剪阈值C的范围是1.0到10.0
  • C. 梯度裁剪可以同时解决梯度消失和爆炸
  • D. 阈值设太小会降低有效学习率
解答:梯度裁剪只解决爆炸不解决消失(C项错误)。A、B、D正确。

Q7 [单选] 层归一化(Layer Normalization)和批量归一化(Batch Normalization)在RNN中的主要区别是什么?

  • A. 层归一化更快
  • B. 批量归一化效果更好
  • C. 层归一化沿特征维度标准化,不受batch大小和序列长度影响
  • D. 两者完全相同
解答:层归一化对每个样本的特征维度独立标准化,不依赖batch维度的统计量。这使得它在RNN中比批归一化更适用——因为RNN的序列长度可变,不同时间步的统计量也不同,跨batch的批归一化统计量不准确。

Q8 [单选] 正交初始化对RNN训练有什么好处?

  • A. 减少参数数量
  • B. 正交矩阵的谱半径=1,梯度在时间维度上理论上不会指数衰减
  • C. 加快前向传播速度
  • D. 提高模型精度
解答:正交矩阵的所有特征值绝对值=1(谱半径=1),这意味着BPTT中Jacobian的连乘不会引起指数衰减或增长。它是RNN权重初始化在理论上的最优选择——让网络从一个"梯度稳定"的起点开始训练。

Q9 [多选] 以下哪种情况表明RNN训练中出现了梯度爆炸?

  • A. 训练损失突然变为NaN
  • B. 参数梯度的L2范数持续增长到1e3以上
  • C. 模型输出在相邻迭代步之间剧烈变化
  • D. 验证准确率缓慢提升
解答:A、B、C都是梯度爆炸的典型信号。D是正常的训练行为或欠拟合,与梯度爆炸无关。

Q10 [单选] GRU中哪个门控提供了与LSTM细胞状态类似的梯度高速公路?

  • A. 重置门 r_t
  • B. 更新门 z_t(通过(1-z_t)项)
  • C. 候选隐藏状态 h̃_t
  • D. 两个门都没有这个功能
解答:GRU中 h_t = (1-z_t)⊙h_{t-1} + z_t⊙h̃t,对h{t-1}求偏导= (1-z_t) + z_t·∂h̃t/∂h{t-1}。当z_t≈0时,∂h_t/∂h_{t-1}≈I(恒等矩阵),梯度直通。这与LSTM中C_t的加法高速公路原理相同。