本章定位 :本系列最关键的一章。TRPO 用 trust region 思想第一次稳定了 policy gradient 训练;PPO 把 trust region 工程化为简单的 clip 操作,成为 RLHF 时代的”瑞士军刀”。理解 PPO 就理解了《学习笔记-大模型》Ch6 的 RLHF 核心 。
承上 :Ch5 PG 定理 + Ch6 GAE。启下 :Ch9 SAC 给连续控制最强基线;Ch11 引出 offline RL 与 DPO 的关系。
§A 数学原理 1. 为什么 PG 训练这么不稳? 回忆 Ch5/Ch6 的痛点:策略梯度方向是对的,但步长无法控制 。
步长太小 → 学得慢
步长太大 → 策略一步走偏,后续轨迹质量崩坏
崩了之后无法恢复 ——因为坏策略采样到的全是坏数据
核心问题 :PG 是 non-stationary optimization——当前梯度只在当前策略附近”近似有效”,远离当前策略后,梯度估计全部失效。
解法思路 :限制每次更新的”策略变化幅度”,确保新策略 $\pi_{\theta_{\text{new}}}$ 与旧策略 $\pi_{\theta_{\text{old}}}$ 不要差太远。这就是 Trust Region (信赖域)思想。
2. TRPO:单调改进定理 2.1 期望回报的差分表达 关键引理 (Kakade & Langford 2002):
其中 $d^{\pi’}$ 是新策略下的状态访问分布。
直观 :新策略的提升 = 在新策略访问的状态分布下,新策略选的 action 在旧策略 advantage 上的期望。
2.2 局部近似(关键近似) 上式中 $d^{\pi’}$ 难以计算(依赖未知的 $\pi’$)。TRPO 用旧策略的状态分布 $d^\pi$ 近似:
接着用重要性采样把 $a \sim \pi’$ 转回 $a \sim \pi$:
记 $r(s, a) = \pi’(a \mid s) / \pi(a \mid s)$ 为 importance ratio 。
2.3 单调改进定理 定理 (Schulman et al. 2015):
其中 $C$ 是与 $\gamma$ 和 reward 范围相关的常数,$D_{KL}^{\max}$ 是 $\pi, \pi’$ 在所有状态下 KL 散度的最大值。
含义 :只要新策略 $\pi’$ 在 KL 意义下离 $\pi$ 不远,$\tilde{J}(\pi’)$ 是 $J(\pi’)$ 的下界 。优化下界 → 保证真实 $J$ 也在变好(单调改进 )。
2.4 TRPO 的优化形式 把上面变成约束优化:
实际中用 mean KL(而非 max KL)作为约束,$\delta \approx 0.01$。
2.5 TRPO 的工程难点 求解上面的约束优化需要:
计算 KL 约束的 Hessian(Fisher 信息矩阵 $F$)
用共轭梯度求解 $F^{-1} g$(natural gradient)
用 line search 确保 KL 约束满足
→ 二阶优化 ,实现复杂、训练慢。
3. PPO:把 Trust Region 一阶化 PPO(Schulman 2017)的目标:保留 trust region 思想,但只用一阶优化器(Adam)。
3.1 PPO-Penalty(早期版本,已少用) 把 KL 约束变成软惩罚:
$\beta$ 自适应调整(KL 大就增大 β)。但效果不稳定。
3.2 PPO-Clip(主流版本) 核心思路 :直接限制 importance ratio 的幅度。
其中:
$r_t(\theta) = \pi_\theta(a_t \mid s_t) / \pi_{\theta_{\text{old}}}(a_t \mid s_t)$
$\epsilon = 0.1$ 或 $0.2$
$\hat{A}_t$ 用 GAE 估计(Ch6 §A.3.4)
3.3 PPO-Clip 的几何理解 考虑两种情况:
情况 1:$\hat{A}_t > 0$(动作好,想增大概率)
$r_t < 1 + \epsilon$:自由更新
$r_t \geq 1 + \epsilon$:clip 截断梯度,不再奖励”大幅推高”
情况 2:$\hat{A}_t < 0$(动作差,想降低概率)
$r_t > 1 - \epsilon$:自由更新
$r_t \leq 1 - \epsilon$:clip 截断梯度,不再惩罚”大幅压低”
核心保护 :每次更新只允许 $\pi_\theta$ 在 ratio $\in [1-\epsilon, 1+\epsilon]$ 内变化。
3.4 完整 PPO 损失 其中:
$\mathcal{L}^{\text{VF}} = (V_\phi(s_t) - V_t^{\text{target}})^2$(Critic 损失)
$\mathbb{E}[H(\pi_\theta)]$(熵正则,鼓励探索)
$c_v = 0.5$,$c_h = 0.01$
3.5 PPO 的核心 trick:minibatch 多轮更新 每次收集一批 rollout 数据后,用同一批数据训练 K 轮 (通常 K=4),每轮内部分成多个 minibatch。
为什么可以?因为:
$\theta_{\text{old}}$ 固定(rollout 时的策略)
ratio $r_t$ 自然把”采样分布的差异”修正回来
clip 保证 $\theta$ 不会跑太远
这让 PPO 数据效率显著优于 vanilla PG 。
4. PPO 在 LLM 中的应用(RLHF) 把 PPO 用在 LLM 上有几个特殊性:
特性
LLM 与经典 RL 的差异
状态
完整 prompt + 已生成 token
动作
词表中的下一个 token(动作空间几万-几十万)
轨迹长度
几十到几千 token
奖励
仅在 EOS 时由 RM 给一个标量
per-token KL
KL penalty 加在每个 token 上,作为 reward shaping
详见《学习笔记-大模型》Ch6(PPO RLHF)。
§B 模型架构 B.1 PPO 数据流(重要!) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 ┌──────────────────────────────────────────────────────────┐ │ Phase 1: Rollout(采样数据,无梯度) │ │ for t = 1...T: │ │ a_t ~ π_θ_old(·|s_t) │ │ log_prob_old, value_old = policy_old.act(s_t) │ │ s_{t+1}, r_t = env.step(a_t) │ └──────────────────────────────────────────────────────────┘ ↓ ┌──────────────────────────────────────────────────────────┐ │ Phase 2: 计算 GAE │ │ advantages, returns = compute_gae(rewards, values, ...) │ └──────────────────────────────────────────────────────────┘ ↓ ┌──────────────────────────────────────────────────────────┐ │ Phase 3: K 轮 minibatch 更新 │ │ for epoch = 1...K: │ │ for batch in shuffle(dataset): │ │ log_prob_new = π_θ.log_prob(a_t) │ │ ratio = exp(log_prob_new - log_prob_old) │ │ loss = -min(ratio·A, clip(ratio, 1-ε, 1+ε)·A) │ │ + c_v · (V(s_t) - returns)² │ │ - c_h · entropy │ │ loss.backward(); optim.step() │ └──────────────────────────────────────────────────────────┘
B.2 PPO 完整 PyTorch 实现 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 import torchimport torch.nn as nnimport torch.nn.functional as Fimport numpy as npimport gymnasium as gymfrom torch.optim import Adamclass ActorCritic (nn.Module): """Categorical Actor + V Critic""" def __init__ (self, obs_dim, n_actions, hidden=64 ): super ().__init__() self .shared = nn.Sequential( nn.Linear(obs_dim, hidden), nn.Tanh(), nn.Linear(hidden, hidden), nn.Tanh(), ) self .actor = nn.Linear(hidden, n_actions) self .critic = nn.Linear(hidden, 1 ) def forward (self, obs ): h = self .shared(obs) return self .actor(h), self .critic(h).squeeze(-1 ) def act (self, obs ): logits, value = self .forward(obs) dist = torch.distributions.Categorical(logits=logits) a = dist.sample() log_prob = dist.log_prob(a) return a, log_prob, value def evaluate (self, obs, action ): """给定 obs 和 action,返回 (log_prob, value, entropy)""" logits, value = self .forward(obs) dist = torch.distributions.Categorical(logits=logits) log_prob = dist.log_prob(action) entropy = dist.entropy() return log_prob, value, entropy def ppo_train (env_name="CartPole-v1" , n_steps=int (2e5 ), n_envs=8 , rollout_len=128 , n_epochs=4 , n_minibatches=4 , gamma=0.99 , lam=0.95 , eps_clip=0.2 , lr=3e-4 , c_v=0.5 , c_h=0.01 , max_grad_norm=0.5 ): envs = gym.vector.SyncVectorEnv([lambda : gym.make(env_name) for _ in range (n_envs)]) obs_dim = envs.single_observation_space.shape[0 ] n_actions = envs.single_action_space.n model = ActorCritic(obs_dim, n_actions) optim = Adam(model.parameters(), lr=lr) obs, _ = envs.reset() total_steps = 0 while total_steps < n_steps: obs_buf = torch.zeros(rollout_len, n_envs, obs_dim) act_buf = torch.zeros(rollout_len, n_envs, dtype=torch.long) logp_buf = torch.zeros(rollout_len, n_envs) rew_buf = torch.zeros(rollout_len, n_envs) val_buf = torch.zeros(rollout_len, n_envs) done_buf = torch.zeros(rollout_len, n_envs) for t in range (rollout_len): obs_t = torch.tensor(obs, dtype=torch.float32) with torch.no_grad(): a, log_prob, value = model.act(obs_t) obs_new, r, terminated, truncated, _ = envs.step(a.numpy()) done = np.logical_or(terminated, truncated).astype(np.float32) obs_buf[t] = obs_t act_buf[t] = a logp_buf[t] = log_prob val_buf[t] = value rew_buf[t] = torch.tensor(r, dtype=torch.float32) done_buf[t] = torch.tensor(done) obs = obs_new total_steps += n_envs with torch.no_grad(): obs_t = torch.tensor(obs, dtype=torch.float32) _, last_val = model.forward(obs_t) last_val = last_val.detach() advantages = torch.zeros(rollout_len, n_envs) last_gae = torch.zeros(n_envs) for t in reversed (range (rollout_len)): next_val = last_val if t == rollout_len - 1 else val_buf[t + 1 ] delta = rew_buf[t] + gamma * next_val * (1 - done_buf[t]) - val_buf[t] last_gae = delta + gamma * lam * (1 - done_buf[t]) * last_gae advantages[t] = last_gae returns = advantages + val_buf b_obs = obs_buf.reshape(-1 , obs_dim) b_act = act_buf.reshape(-1 ) b_logp = logp_buf.reshape(-1 ) b_adv = advantages.reshape(-1 ) b_ret = returns.reshape(-1 ) batch_size = rollout_len * n_envs mb_size = batch_size // n_minibatches idx = np.arange(batch_size) for epoch in range (n_epochs): np.random.shuffle(idx) for start in range (0 , batch_size, mb_size): mb_idx = idx[start: start + mb_size] new_logp, new_val, entropy = model.evaluate(b_obs[mb_idx], b_act[mb_idx]) ratio = torch.exp(new_logp - b_logp[mb_idx]) adv = b_adv[mb_idx] adv = (adv - adv.mean()) / (adv.std() + 1e-8 ) surr1 = ratio * adv surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * adv actor_loss = -torch.min (surr1, surr2).mean() v_loss = (new_val - b_ret[mb_idx]).pow (2 ).mean() entropy_loss = -entropy.mean() loss = actor_loss + c_v * v_loss + c_h * entropy_loss optim.zero_grad() loss.backward() nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) optim.step() return model
这段代码是 PPO 的”生产级最简”实现,约 100 行。CleanRL 的实现与之高度一致。
B.3 PPO 在 LLM 上的对应(呼应 LLM 笔记 Ch6) LLM RLHF 中 PPO 的差别(参见之前的《学习笔记-大模型》Ch6):
1 2 3 4 5 6 7 rew_buf[t] = env.step(a) kl = log_prob_new - log_prob_ref reward = -beta * kl reward[-1 ] += rm_score
其余(GAE、clip、minibatch 多轮更新)完全一致。
§C 训练与推理 C.1 PPO 调参经验
参数
推荐值
备注
eps_clip
0.1-0.2
0.2 标配;连续控制可降到 0.1
n_epochs
4-10
多了会过拟合”旧分布”
lam
0.95
GAE 的 sweet spot
gamma
0.99
视任务远视程度调
lr
3e-4
Adam 默认
c_v
0.5
太大 critic 主导
c_h
0.01
复杂任务调到 0.001-0.05
max_grad_norm
0.5
必须有,否则会爆
rollout_len × n_envs
~2048 总样本
太小方差大、太大数据陈旧
C.2 实验:CartPole + PPO 1 model = ppo_train("CartPole-v1" , n_steps=int (1e5 ))
典型结果:
1万步:reward = 50
3万步:reward = 200
5万步:reward = 500(满分)
PPO 在 CartPole 上的收敛速度通常比 A2C 快 1.5-2 倍,且更稳定。
C.3 PPO 推理 PPO 推理时只用 Actor(与 A2C 一致),完全去掉 Critic:
1 2 3 4 5 6 def inference (model, obs ): obs_t = torch.tensor(obs, dtype=torch.float32).unsqueeze(0 ) with torch.no_grad(): logits, _ = model.forward(obs_t) a = logits.argmax(dim=-1 ).item() return a
C.4 工程经验 1. Reward normalization :复杂任务(MuJoCo)reward 尺度差异大,必须用 running mean/std 归一化。
2. Observation normalization :连续状态空间也用类似归一化。
3. Clip 不只是 actor :PPO2 还可以 clip critic 的 value 更新(防止 critic 学过头)。
4. Hyper-parameter 互相影响 :eps_clip 大 → n_epochs 要小;反之亦然。
5. KL 监控 :训练时打印 mean KL,若 KL 突然飙升(> 0.05),说明 clip 失效,需调小 lr。
§D 章末速查 D.1 关键公式
#
公式
含义
1
$r_t(\theta) = \pi_\theta(a_t\
s_t) / \pi_{\theta_{\text{old}}}(a_t\
s_t)$
importance ratio
2
$\mathcal{L}^{\text{CLIP}} = \min(r_t \hat{A}_t, \text{clip}(r_t, 1\pm\epsilon)\hat{A}_t)$
PPO 目标
3
$J(\pi’) - J(\pi) \approx \mathbb{E}_{s, a \sim \pi}[r(s,a) A^\pi(s,a)]$
单调改进引理
4
$J(\pi’) \geq \tilde{J}(\pi’) - C \cdot D_{KL}^{\max}$
TRPO 下界
5
$\hat{A}_t = \delta_t + \gamma\lambda \hat{A}_{t+1}$
GAE 递推(Ch6)
D.2 常见面试题 Q1:PPO 的 clip 在数学上等价于 trust region 吗?
不严格等价 ,只是工程化的近似
TRPO 用 KL 硬约束 + 二阶优化
PPO 用 ratio clip + 一阶优化
但在大多数实验上 PPO ≥ TRPO,且简单很多
Q2:为什么 PPO 能用同一批数据训 K 轮?
$\theta_{\text{old}}$ 固定时,importance ratio $r_t$ 自然修正分布偏差
clip 保证 $\theta$ 不会跑远
K 轮内策略仍在”trust region”内,数据有效
K 太大(>10)会过拟合旧分布——KL 飙升
Q3:PPO 与 A2C 的关系?
PPO = A2C + importance ratio + clip + 多轮更新
数据效率:PPO ≫ A2C
稳定性:PPO ≫ A2C(clip 保护)
实现复杂度:PPO 略高
Q4:PPO 为什么在 LLM RLHF 中是首选?
离散动作空间(token),PPO 天然支持
训练稳定(KL clip 防策略崩坏)
数据效率(rollout 一次更新多轮)
工程社区成熟(TRL/DeepSpeed 等都有现成实现)
Q5:PPO 的”Proximal”在哪?
$\epsilon$-clip 限制策略变化幅度
让 $\pi_\theta$ 始终”接近 (proximal to)” $\pi_{\theta_{\text{old}}}$
这是 trust region 思想的本质
承上启下 PPO 是 RL 与 LLM 时代的桥梁:
经典 RL 中用它处理离散/连续控制
LLM RLHF 中用它把模型对齐到人类偏好
下一章 Ch9 探讨连续控制的另一条路 ——DDPG 与 SAC。在机器人、自动驾驶等连续动作场景,SAC 是当前 SOTA。
理解了 PPO + SAC,你就掌握了现代 RL 的两大主流路径。