TRPO信赖域策略优化

一句话概述

TRPO(Trust Region Policy Optimization,信赖域策略优化)是强化学习中第一个提供严格单调改进保证的策略优化算法。它通过KL散度约束限制新旧策略的差异,用自然策略梯度和共轭梯度法求解带约束的优化问题,为PPO等后续算法奠定了理论基础。

教学与演示

什么是TRPO

TRPO由John Schulman等人在2015年提出。它的核心问题是:在策略梯度方法中,多大的更新步长才是安全的?

想象你在山顶上闭着眼睛下山——步子太大可能一脚踩空滚下去,步子太小又太慢。TRPO告诉你:只要每一步的「方向」和「步长」满足某个条件,就能保证你每走一步都在往下坡走,绝对不会往上走(即策略性能单调提升)。

这个条件背后的数学理论叫做「信赖域」(Trust Region)——在一个局部的「可信区域」内,用简单的局部近似来代替复杂的全局函数是可靠的。TRPO把这个思想用到了策略优化中。

大白话 TRPO就像一个非常谨慎的登山者——每走一步前,先用数学证明这一步不会让你摔倒,然后才迈出脚。PPO则是后来被简化了的版本——不再严格证明,而是用经验规则限制步子大小。

为什么需要TRPO

在TRPO之前,策略梯度方法面临两个核心挑战:

挑战一:步长敏感。策略梯度中学习率的选择极其关键——太小收敛慢,太大策略崩溃。而且好的学习率往往因任务而异,很难泛化。

挑战二:数据浪费。REINFORCE等传统方法更新一次就得扔掉旧数据重新采样,因为旧数据来自旧策略,不能无偏估计当前策略的期望。

在TRPO之前,自然策略梯度(Natural Policy Gradient)使用Fisher信息矩阵对梯度做「修正」,一定程度上缓解了步长问题,但仍然需要手动调学习率。TRPO在此基础上加入了严格的KL散度约束和线搜索机制,实现了自动步长选择。

大白话 传统的策略梯度就像开车没有刹车——油门踩多少靠感觉,踩多了就翻车。TRPO加了一个自动刹车系统:如果这一步可能导致翻车,它就自动减小步长。

TRPO 怎么做

TRPO的核心是一个带约束的优化问题:

TRPO约束优化问题\(\begin{aligned} \max_\theta \quad & \mathbb{E}_{s \sim \rho_{\theta_{\text{old}}}, a \sim \pi_{\theta_{\text{old}}}} \left[\frac{\pi_\theta(a|s)}{\pi_{\theta_{\text{old}}}(a|s)} A^{\pi_{\theta_{\text{old}}}}(s,a) \right] \\ \text{s.t.} \quad & \mathbb{E}_{s \sim \rho_{\theta_{\text{old}}}} \left[ D_{KL}(\pi_{\theta_{\text{old}}}(\cdot|s) \| \pi_\theta(\cdot|s)) \right] \leq \delta \end{aligned}\)

其中:

    undefined

这个约束的意义是:只相信在当前策略附近的局部区域,代理目标函数是可靠的。出界了就不相信。

大白话 TRPO画了一个「信任圈」——半径是δ的KL散度。在这个圈内,代理目标函数是可靠的,可以放心用。优化时不能越界。

求解TRPO的步骤:

第一步,用泰勒展开近似目标和约束:

TRPO的二次近似\(\begin{aligned} \max_\theta \quad & g^T (\theta - \theta_{\text{old}}) \\ \text{s.t.} \quad & \frac{1}{2} (\theta - \theta_{\text{old}})^T F (\theta - \theta_{\text{old}}) \leq \delta \end{aligned}\)

其中 g 是目标函数的梯度,F 是Fisher信息矩阵(KL散度的二阶近似)。

第二步,解这个带约束的二次优化问题,得到自然梯度方向:

自然梯度更新方向\(\theta - \theta_{\text{old}} = \sqrt{\frac{2\delta}{g^T F^{-1} g}} \cdot F^{-1} g\)

F^{-1}g就是自然梯度——在参数空间中考虑了「距离度量」后的梯度方向。普通梯度 g 假设参数空间是平坦的欧几里得空间;自然梯度 F^{-1}g 假设参数空间是弯曲的统计流形(由KL散度定义距离)。

第三步,用共轭梯度法高效计算 F^{-1}g(避免显式求逆),然后用线搜索确定实际步长。

Fisher信息矩阵与自然梯度

Fisher信息矩阵的定义:

Fisher信息矩阵\(F = \mathbb{E}_{s \sim \rho, a \sim \pi_\theta} \left[ \nabla_\theta \log \pi_\theta(a|s) \cdot \nabla_\theta \log \pi_\theta(a|s)^T \right]\)

Fisher矩阵度量了「对数概率对参数的敏感度」。在概率分布的参数空间中,欧几里得距离没有意义——比如高斯分布N(μ=0, σ=1)和N(μ=0.1, σ=1)的μ相差0.1,但N(μ=0, σ=100)和N(μ=0.1, σ=100)的μ也相差0.1,后者的「实际差别」小得多。Fisher信息矩阵正是考虑到这种「统计距离」,对梯度做修正。

什么用

TRPO虽然实现复杂(需要共轭梯度、线搜索等),但它的贡献不在于算法本身,而在于:

    undefined
大白话 TRPO虽然实现复杂,但它证明了「用KL散度约束策略更新」这条路是对的。PPO就是在TRPO的肩膀上,用更简单的方法达到了类似的效果。

哪些坑

坑点原因解决方案
实现复杂度极高需要共轭梯度、Fisher向量积、线搜索用PPO替代(95%场景);教学场景下学习TRPO原理
共轭梯度不收敛阻尼系数太小、Fisher矩阵条件数大增大阻尼系数(CG damping),增加CG迭代步数
线搜索失效策略在约束边界附近振荡调小δ,或改为指数衰减的步长
计算开销大Fisher向量积需要遍历所有样本mini-batch近似,但会引入噪声
与dropout不兼容Fisher矩阵计算需要确定性前向传播关闭dropout或使用PPO
δ值难调太小学不动,太大约束失效从0.01开始,根据KL散度监控调整

核心代码演示

下面一步步实现TRPO的关键组件:Fisher向量积、共轭梯度求解和线搜索。

"""
TRPO核心组件 - Fisher向量积、共轭梯度法
使用numpy实现,帮助理解TRPO的数学原理
"""
import numpy as np

# ===== 1. Fisher向量积 =====
def fisher_vector_product(grad_log_prob, params, vector):
    """
    计算Fisher信息矩阵与向量的乘积:F @ v
    不需要显式构建F矩阵,只需计算Hessian-vector乘积
    
    核心技巧:F = E[∇logπ * ∇logπ^T],所以
    F*v = E[∇logπ * (∇logπ^T * v)] = E[∇logπ * (∇logπ·v)]
    其中 (∇logπ·v) 是标量
    
    参数:
    - grad_log_prob: 每个样本的对数概率梯度 [N, dim]
    - params: 占位参数(仅用于接口一致)
    - vector: 待乘的向量 [dim]
    
    返回:
    - F @ v [dim]
    """
    N, dim = grad_log_prob.shape
    # (∇logπ·v): 每个样本的梯度在v方向上的投影 [N]
    grad_v_prod = grad_log_prob @ vector  # [N]
    
    # F*v = (1/N) * Σ (∇logπ * (∇logπ·v))
    result = np.zeros(dim)
    for i in range(N):
        result += grad_log_prob[i] * grad_v_prod[i]
    
    return result / N

# ===== 模拟梯度测试Fisher向量积 =====
np.random.seed(42)
N, dim = 64, 10  # 64个样本,10维参数
# 模拟每个样本的策略对数概率梯度
grad_log_probs = np.random.randn(N, dim)
# 随机向量v
v = np.random.randn(dim)

# 显式计算Fisher矩阵(仅用于验证,实际不这么做)
F_explicit = grad_log_probs.T @ grad_log_probs / N
Fv_explicit = F_explicit @ v

# 用fisher_vector_product计算(无需显式构建F)
Fv_fvp = fisher_vector_product(grad_log_probs, None, v)

# 验证两种方法的结果一致
error = np.max(np.abs(Fv_explicit - Fv_fvp))
print(f"Fisher向量积验证 - 最大误差: {error:.2e}")
print("误差极小,说明fisher_vector_product实现正确!")

print(f"\nF矩阵形状: {F_explicit.shape}")
print(f"F矩阵秩: {np.linalg.matrix_rank(F_explicit)} (样本数<N时F不满秩)")
"""
共轭梯度法(Conjugate Gradient)求解 F @ x = g
TRPO用CG法求解自然梯度方向,避免显式求逆F^{-1}
"""
import numpy as np

def conjugate_gradient(fvp_func, grad, grad_log_probs, 
                       n_iterations=10, damping=0.1, tol=1e-10):
    """
    共轭梯度法求解 F @ x = g
    
    参数:
    - fvp_func: Fisher向量积函数 f(v) = F @ v
    - grad: 右端向量 g [dim]
    - grad_log_probs: 对数概率梯度 [N, dim](传递给fvp_func)
    - n_iterations: 最大迭代次数
    - damping: 阻尼系数(改善F的条件数)
    - tol: 残差容限
    
    返回:
    - x: 近似解 F^{-1} @ g [dim]
    """
    dim = grad.shape[0]
    x = np.zeros(dim)  # 初始化解x=0
    # 残差 r = g - Fx,由于x=0,r初始=g
    r = grad.copy()    # (F @ 0 = 0,所以r = g)
    p = r.copy()       # 初始搜索方向 = 残差方向
    r_dot_old = np.dot(r, r)  # ||r||^2
    
    for i in range(n_iterations):
        # 计算 F @ p + damping * p(阻尼保证正定性)
        Fp = fvp_func(grad_log_probs, None, p)
        Fp_damped = Fp + damping * p
        
        # 步长 α = (r^T r) / (p^T Fp)
        alpha = r_dot_old / (np.dot(p, Fp_damped) + 1e-8)
        
        # 更新 x 和 r
        x = x + alpha * p
        r = r - alpha * Fp_damped
        
        # 检查收敛
        r_dot_new = np.dot(r, r)
        if r_dot_new < tol:
            print(f"CG收敛于第{i+1}次迭代")
            break
        
        # 更新搜索方向(β = r_new^T r_new / r_old^T r_old)
        beta = r_dot_new / (r_dot_old + 1e-8)
        p = r + beta * p
        r_dot_old = r_dot_new
    
    return x

# ===== 模拟测试共轭梯度法 =====
np.random.seed(42)
dim = 10
N = 100

# 构建一个正定的F矩阵(A^T A 总是半正定)
A = np.random.randn(N, dim)
F = A.T @ A / N  # Fisher信息矩阵的近似
g = np.random.randn(dim)  # 策略梯度

# 方法1:直接求逆(小维度可行,大维度不可行)
x_direct = np.linalg.solve(F + 0.1 * np.eye(dim), g)

# 方法2:共轭梯度法(适用于大维度)
# 先定义fvp函数
def fvp(grad_log_probs_vec, _, vector):
    """Fisher向量积 F @ v"""
    return grad_log_probs_vec.T @ (grad_log_probs_vec @ vector) / N

x_cg = conjugate_gradient(fvp, g, A, n_iterations=20)

# 比较两种方法
error = np.max(np.abs(x_direct - x_cg))
cos_sim = np.dot(x_direct, x_cg) / (np.linalg.norm(x_direct) * np.linalg.norm(x_cg) + 1e-8)
print(f"\n=== 共轭梯度法验证 ===")
print(f"直接求逆 vs CG: 最大误差 = {error:.2e}")
print(f"余弦相似度: {cos_sim:.6f}")
print(f"CG完全恢复了正确方向!注意CG比直接求逆快O(n^2)倍")
"""
TRPO线搜索(Line Search)
确保更新后的策略:(1)满足KL约束 (2)性能不下降
"""
import numpy as np

def compute_kl_divergence(log_probs_old, log_probs_new):
    """
    计算两个策略之间的KL散度
    KL(π_old || π_new) = E_{a~π_old}[log(π_old/π_new)]
    = Σ π_old(a) * (log π_old(a) - log π_new(a))
    """
    probs_old = np.exp(log_probs_old)  # [N, n_actions]
    # KL = Σ p * (log p - log q)
    kl = np.sum(probs_old * (log_probs_old - log_probs_new), axis=1)  # [N]
    return np.mean(kl)

def linesearch_trpo(old_params, step_direction, max_kl,
                    get_log_probs_fn, compute_surrogate_fn,
                    n_backtracks=10, backtrack_coeff=0.8):
    """
    TRPO的线搜索:指数回退找到满足约束的步长
    
    参数:
    - old_params: 旧策略参数 [dim]
    - step_direction: 自然梯度方向(已包含尺度)[dim]
    - max_kl: 允许的最大KL散度
    - get_log_probs_fn: 给定参数返回新旧对数概率的函数
    - compute_surrogate_fn: 计算代理目标函数值
    - n_backtracks: 最大回退次数
    - backtrack_coeff: 回退系数(每次乘以该系数)
    
    返回:
    - new_params: 满足约束的参数
    - accepted: 是否找到满足条件的步长
    """
    log_probs_old, _ = get_log_probs_fn(old_params)
    surrogate_old = compute_surrogate_fn(old_params)
    
    for i in range(n_backtracks):
        # 指数衰减步长: α = α_0 * coeff^i
        step_size = backtrack_coeff ** i
        new_params = old_params + step_size * step_direction
        
        # 检查KL约束
        _, log_probs_new = get_log_probs_fn(new_params)
        kl = compute_kl_divergence(log_probs_old, log_probs_new)
        
        if kl > max_kl:
            continue  # KL太大,回退
        
        # 检查性能提升
        surrogate_new = compute_surrogate_fn(new_params)
        if surrogate_new >= surrogate_old:
            print(f"线搜索成功!步长系数: {step_size:.4f}, KL: {kl:.6f}")
            return new_params, True
    
    print("线搜索未找到满足条件的步长,返回原参数")
    return old_params.copy(), False

# ===== 模拟线搜索测试 =====
np.random.seed(42)
dim = 10

# 模拟参数和方向
old_params = np.random.randn(dim)
direction = np.random.randn(dim)
direction = direction / np.linalg.norm(direction) * 0.1  # 归一化

# 模拟策略函数
def get_log_probs(params):
    """模拟:参数越小,对数概率越接近(KL越小)"""
    log_probs = np.random.randn(32, 4) * 0.1
    return log_probs, log_probs + params[0] * 0.01  # 旧和新

def compute_surrogate(params):
    """模拟:参数越大,目标函数值越大(但过大会违反KL)"""
    return -np.linalg.norm(params - old_params) + 0.5

success = linesearch_trpo(
    old_params, direction, max_kl=0.01,
    get_log_probs_fn=get_log_probs,
    compute_surrogate_fn=compute_surrogate
)
print("TRPO线搜索确保每一步都在「信任圈」内且性能不下降!")

概念关系图谱

概念与TRPO的关系说明
自然策略梯度核心方法TRPO用自然梯度代替普通梯度,考虑参数空间的统计距离
KL散度约束度量TRPO用KL散度衡量新旧策略的差异,作为信赖域半径
Fisher信息矩阵数学工具Fisher矩阵是KL散度的二阶近似,用于定义自然梯度
共轭梯度法求解方法用于高效求解F^{-1}g,避免显式矩阵求逆
线搜索步长选择在自然梯度方向上用指数回退找到满足约束的步长
重要性采样修正技巧用旧策略数据估计新策略期望,提高样本效率
PPO后继/简化PPO用clip机制替代TRPO的KL约束,实现更简单
单调改进定理理论保证TRPO的数学基础——保证策略性能单调提升

重点答疑

大白话 TRPO本质上是用「数学严格性」换「工程复杂性」——它有漂亮的单调改进定理做保证,但为了实现这个保证,代码复杂了十倍。PPO反其道而行之,用简单的clip换来了几乎同样的效果。
💡 核心要点:TRPO的精髓在于三个数学组件的协同——Fisher矩阵定义「统计距离」、共轭梯度高效求解「自然梯度」、线搜索确保「安全步长」。三者缺一不可。
    undefined

章节单词汇总

英文音标术语释义
Trust Region/trʌst ˈriːdʒən/信赖域目标函数的局部近似可靠的范围
Natural Policy Gradient/ˈnætʃərəl ˈpɒləsi ˈɡreɪdiənt/自然策略梯度考虑统计距离修正的策略梯度方向
Fisher Information Matrix/ˈfɪʃər ˌɪnfəˈmeɪʃən ˈmeɪtrɪks/Fisher信息矩阵对数似然梯度的协方差矩阵,度量统计敏感性
Conjugate Gradient/ˈkɒndʒʊɡət ˈɡreɪdiənt/共轭梯度法高效求解正定线性系统的迭代优化算法
Line Search/laɪn sɜːtʃ/线搜索沿给定方向寻找最优步长的算法
KL Divergence/keɪ el daɪˈvɜːdʒəns/KL散度衡量两个概率分布差异的度量
Damping/ˈdæmpɪŋ/阻尼加到对角线的正则化项,改善矩阵条件数
Monotonic Improvement/ˌmɒnəˈtɒnɪk ɪmˈpruːvmənt/单调改进每一步更新后策略性能不下降的性质
Backtracking/ˈbæktrækɪŋ/回退从大步长逐步减小直到满足条件的搜索策略
Hessian-vector Product/ˈhesiən ˈvektə ˈprɒdʌkt/Hessian向量积无需构建Hessian矩阵即可计算H·v的技术