10.3 PPO:近端策略优化

10.1 节 中,我们介绍了 Agentic-RL 的两阶段训练范式(SFT → RL)。RL 阶段的核心问题是:如何根据奖励信号来更新模型参数? 这正是策略优化算法要解决的问题。

本节将从零开始,系统性地讲解 PPO(Proximal Policy Optimization)算法——它是 InstructGPT 和 ChatGPT 的核心训练算法,也是理解后续 DPO、GRPO 算法的基础。我们将从最基本的直觉出发,逐步推导数学公式,并通过大量图示帮助理解。

三大策略优化算法架构对比

预备知识:策略梯度的基本思想

在深入三种算法之前,我们需要理解一个共同的起点——策略梯度(Policy Gradient) [1]。

核心直觉

想象你在练习投篮。每次投篮后,你会得到一个反馈:进了(奖励 +1)或没进(奖励 0)。策略梯度的思想极其朴素:

如果某个动作获得了高奖励,就增加该动作的概率;如果获得了低奖励,就降低该动作的概率。

形式化地,策略梯度定理给出的梯度方向为:

下面我们逐项拆解这个公式,并用一个具体的语言模型例子来帮助理解。


—— "我该往哪个方向调参数?"

  • 是我们的总目标:模型在所有可能输入上的期望累积奖励。 越大,模型整体表现越好
  • 是对模型参数 (即模型中数十亿个权重值)求梯度
  • 就是一个与 同维度的向量,它告诉我们:如果把每个参数往哪个方向微调一点点, 会增大最快
  • 训练时我们做的就是: 是学习率),即沿梯度方向"上坡"

类比:你蒙着眼站在山坡上,梯度就是"脚下最陡的上坡方向"。每走一步(更新一次参数),你就往山顶(最大奖励)靠近一点。


—— "怎样调参数才能让这个动作更可能发生?"

这是公式中最核心也最难理解的部分,我们分层解释:

第一层: 是什么?

就是我们的语言模型。给定当前状态 (对话历史 + 已生成的 token),它输出一个概率分布,表示"下一个 token 是什么"的概率。例如:

下一个 token ()概率
"搜索"0.35
"回答"0.25
"计算"0.20
"我"0.10
......

表示:在当前上下文下,模型认为下一步输出"搜索"的概率是 35%。

第二层: 为什么取对数?

取对数有两个好处:

  1. 数值稳定:概率值在 0 到 1 之间,连乘多个 token 的概率会变得极小(如 ),取对数后变为加法(),避免数值下溢
  2. 梯度形式简洁,这个比率形式恰好是我们需要的

第三层: 到底是什么?

这被称为得分函数(score function)。它是一个与模型参数 同维度的向量,指出:

"如果想让动作 在状态 下的概率增大,模型的每个参数应该分别往哪个方向调?"

  • 它不直接改变概率,而是给出一个方向
  • 沿这个方向调整参数 → 增大(这个动作变得更可能)
  • 沿相反方向调整参数 → 减小(这个动作变得更不可能)

类比:得分函数就像动作 的"方向盘"——转动它可以增加或减少这个动作被选中的概率。但光有方向盘不够,你还需要知道该转多少——这就是下面 的作用。


—— "这个方向盘该转多少?"

是整条轨迹(从开始到结束的完整交互过程)的累积奖励,它充当权重

  • (正奖励):说明这条轨迹整体表现不错

    • 梯度 = 正权重 × 得分函数 → 沿得分函数方向更新 → 增加轨迹中每个动作的概率
    • 直觉:这次表现好,下次要更多地做类似的事情
  • (负奖励):说明这条轨迹整体表现很差

    • 梯度 = 负权重 × 得分函数 → 沿得分函数反方向更新 → 减少轨迹中每个动作的概率
    • 直觉:这次表现差,下次要避免做类似的事情
  • (零奖励):这条轨迹对梯度无贡献

  • 越大:权重越大,这条轨迹对参数更新的影响越大。奖励/惩罚越极端,模型"记忆"越深刻

类比 就像教练的评分。得分函数指明了方向盘, 决定了转多大的角度。教练打高分()→ 大力转向"增加该动作概率";教练打低分()→ 大力转向"减少该动作概率"。


—— "轨迹中每一步都要算"

一条轨迹包含 个时间步(从 ),每一步都有一个 对。求和意味着:轨迹中每一步的得分函数都被同一个 加权

在语言模型中,一步 = 生成一个 token。如果模型生成了一个 50 token 的回答,,那么这 50 个 token 中每一个的生成概率都会被同一个总奖励加权更新。

注意:这其实是一个粗糙的做法——用整条轨迹的总奖励来加权每一步。如果轨迹中前 30 个 token 是正确推理,后 20 个 token 是错误结论,它们都会被总奖励同等对待。这就是"信用分配问题(credit assignment)"——PPO 的优势函数 正是为了解决这个问题(见 §1.3)。


—— "对很多次尝试取平均"

期望运算符 表示轨迹 是按策略 随机采样的。

  • 因为语言模型的生成是随机的(通过 temperature 采样),同一个输入可能产生不同的输出
  • 每次采样得到一条轨迹 ,对应一个
  • 期望就是对所有可能轨迹加权平均——概率越高的轨迹权重越大

实际操作中:我们无法枚举所有可能轨迹(语言模型的输出空间是天文数字级的),因此用蒙特卡洛近似——采样 条轨迹,取平均作为期望的估计:

越大,估计越准确,但计算成本也越高。这就是训练中 batch size 的本质。


完整例子:语言模型 Agent 的一次策略梯度更新

假设我们正在训练一个能调用搜索工具的 Agent,用户提问"北京今天天气如何?"

采样两条轨迹:

轨迹 A(好的回答)轨迹 B(差的回答)
用户:"北京今天天气如何?"用户:"北京今天天气如何?"
<think><think>
需要查询实时天气我直接回答吧
</think></think>
<tool_call>search("北京天气")</tool_call>北京今天晴天,25°C
...(获取结果后给出准确回答)(瞎编的,可能完全错误)
+0.8(调用了工具,回答准确)-0.2(没调用工具,回答错误)

梯度更新效果:

  • 轨迹 A):模型会增加"遇到实时信息问题 → 调用搜索工具"这一系列动作的概率
  • 轨迹 B):模型会减少"遇到实时信息问题 → 直接瞎编回答"这一系列动作的概率

经过成千上万次这样的更新,模型逐渐学会:遇到需要实时信息的问题,应该先调用工具,而不是直接编造答案。


原始策略梯度的缺陷

虽然直觉清晰,但原始策略梯度有两个严重问题:

问题具体表现后果
高方差 可能在不同轨迹间差异极大梯度估计不稳定,训练收敛极慢
步长不可控没有约束单步更新的大小一次"大跳步"就可能毁掉整个策略

PPO、DPO、GRPO 各自用不同方式解决了这两个问题。 本节详细讲解 PPO,DPO 和 GRPO 将分别在 10.410.5 中介绍。


1.1 PPO 解决了什么问题?

PPO [2] 是 OpenAI 于 2017 年提出的策略优化算法,是 InstructGPT [3] 和 ChatGPT 的核心训练算法。PPO 的设计目标是:

在保证训练稳定性的前提下,尽可能高效地利用已采样数据来更新策略。

PPO 通过两个关键机制实现这一目标:

  1. 重要性采样:允许用"旧策略"采集的数据来训练"当前策略"(数据复用)
  2. Clip 裁剪:限制策略更新的步长,防止策略崩溃

1.2 重要性采样比率 :离策训练的核心

在策略梯度中,我们需要从当前策略 中采样轨迹来计算梯度。但如果每次更新参数后都重新采样,效率极低。重要性采样 允许我们用旧策略 的样本来估计新策略 的梯度。

核心是引入重要性采样比率

逐项解读:

  • 分子 :当前策略在状态 下选择动作 的概率
  • 分母 :采样时的旧策略选择该动作的概率
  • :当前策略与旧策略对这个动作的偏好完全一致
  • :当前策略比旧策略更倾向于这个动作(即"当前策略认为这个动作更好了")
  • :当前策略比旧策略更不倾向于这个动作("当前策略认为这个动作变差了")
  • :当前策略选择该动作的概率是旧策略的 2 倍

重要性采样的直觉:假设旧策略采样到某个动作的概率是 10%(),而当前策略认为该动作概率应该是 30%(),则 。这意味着如果用旧策略的数据来估计新策略的期望,每条这样的数据应该被赋予 3 倍权重——因为新策略"本应"更频繁地采到它。

1.3 优势函数 :判断动作的"好坏"

策略梯度中,用累积奖励 作为权重会导致高方差。优势函数(Advantage Function) 通过引入一个基准线来解决这个问题:

  • :在状态 下执行动作 后,能获得的期望累积奖励(动作价值
  • :在状态 下,按当前策略执行所能获得的期望累积奖励(状态价值,即"基准线")
  • :动作 比"平均水平"好 → 应当强化
  • :动作 比"平均水平"差 → 应当抑制
  • :动作 与平均水平持平 → 无需调整

为什么减去基准线能降低方差? 一个形象的比喻:假设你考试得了 85 分。如果全班平均 60 分,你会觉得"考得不错"();如果全班平均 90 分,你会觉得"发挥失常"()。把绝对分数转为相对分数,消除了分数尺度的干扰,让信号更加稳定。

1.4 GAE:广义优势估计

实际训练中, 都不是精确已知的,需要用一个 Critic 模型 来估计。GAE(Generalized Advantage Estimation) [4] 是一种融合多步估计的方法,在偏差和方差之间取得平衡:

其中 TD 误差(Temporal Difference Error) 定义为:

逐项解读 TD 误差:

  • :在时间步 实际获得的即时奖励
  • :Critic 对下一状态的价值估计,乘以折扣因子
  • :Critic 对当前状态的价值估计
  • 直觉 衡量的是 "实际发生的"()与 "Critic 预期的"()之间的差异。如果 ,说明实际结果超出预期(惊喜!); 则说明实际结果不如预期(失望!)

逐项解读 GAE:

  • 指数衰减权重——越远的时间步,对当前优势的贡献越小
  • GAE 折衷参数,控制偏差-方差权衡:
GAE 退化为偏差方差直觉
单步 TD:高(完全依赖 Critic 精度)只看一步的"惊喜"
蒙特卡洛:低(用完整轨迹)看完整轨迹的表现
推荐值适中适中兼顾近期和远期信息

📌 关键问题:GAE 需要 Critic 模型 来估计状态价值。对于大语言模型(如 7B 参数),这意味着需要额外加载一个同等规模的 Critic 模型——这是 PPO 在大模型训练中最大的资源瓶颈。

1.5 PPO Clip 机制:策略更新的"安全绳"

有了优势 和比率 ,PPO 的损失函数为:

这个公式看起来复杂,但核心思想很简单。让我们分两种情况理解:

情况 ①:(好动作,应强化)

  • 无 Clip 时: 越大(即越增加该动作概率),损失越小 → 梯度会推动策略不断增加该动作概率
  • 有 Clip 时: 超过 后, 不再增大 → 取到裁剪值 → 梯度变为零
  • 效果:即使动作很好,也不允许概率增加太多(防止"过度自信")

情况 ②:(差动作,应抑制)

  • 无 Clip 时: 越小(即越降低该动作概率),损失越小 → 梯度会推动策略大幅降低该动作概率
  • 有 Clip 时: 低于 后,梯度变为零
  • 效果:即使动作很差,也不允许概率降低太多(防止"矫枉过正")

PPO Clip 机制图解

Clip 参数 的含义(通常取 0.1–0.3)定义了"信任域"的大小——每次策略更新,每个动作的概率变化不得超过旧策略的 倍。 越小越保守, 越大越激进。

1.6 KL 散度惩罚:防止策略"跑偏"的另一道保险

Clip 机制限制的是单个动作的概率变化幅度,但它无法约束策略整体分布的漂移。在 RLHF 场景中,如果模型为了追求高奖励而"忘记"了 SFT 阶段学到的通用语言能力(如语法、连贯性),就会出现语言退化奖励黑客(reward hacking)——模型找到某种"讨巧"的输出模式来骗取高奖励,但人类读起来一塌糊涂。

为此,PPO 在 RLHF 中通常会额外加入一个 KL 散度惩罚项,完整的优化目标变为:

其中 KL 散度定义为:

逐项解读:

  • 参考策略(Reference Policy)——详见下方专门说明
  • :衡量当前策略 相对于参考策略 分布偏离程度 表示两者完全一致; 越大,偏离越严重
  • KL 惩罚系数——控制"探索新策略"与"保持原有能力"之间的平衡

关于 KL 散度的数学定义和直觉解释,请参阅 附录 E:KL 散度详解

什么是 Reference 模型()?

Reference 模型是 RLHF 和 GRPO 中一个非常重要的概念,初学者容易将它与"旧策略"()混淆。我们用一个表格来彻底厘清:

Reference 模型 旧策略
是什么RL 训练开始前的 SFT 模型快照当前 RL 迭代开始时的策略模型快照
何时创建RL 训练启动时,复制一份 SFT 模型每次采样新数据前,复制当前 Policy 模型
是否更新永远不更新(冻结参数)✅ 每轮迭代更新一次(同步为当前策略)
存在目的充当"锚点",防止策略偏离 SFT 太远提供重要性采样的分母(数据复用)
作用范围整个 RL 训练过程仅在当前迭代内有效
用在哪里KL 散度惩罚 重要性采样比率

形象比喻

想象你在学开车(RL 训练)。驾校教练教会了你基本操作(SFT),现在你上路练习。

  • Reference 模型 = 教练手册上的标准操作规范。无论你练了多久,手册永远不变。它的作用是:如果你的驾驶习惯偏离标准太远(比如开始"飙车"),就把你拉回来。
  • 旧策略 = 你上一次练车结束时的驾驶水平。每次练完车,你的水平都会提升一点。它的作用是:评估你这次练车相比上次有哪些变化。

为什么需要 Reference 模型?

在 RL 训练过程中,模型会不断迭代更新。如果没有 Reference 模型作为锚点,可能出现以下问题:

  1. 奖励黑客(Reward Hacking):模型发现某种"讨巧"的输出模式可以骗取高奖励(如不断重复某个高奖励短语),但实际输出质量极差
  2. 语言退化(Language Degeneration):模型为追求奖励而丧失了 SFT 阶段学到的语法能力、连贯性和通用知识
  3. 模式坍缩(Mode Collapse):模型对所有输入都生成类似的"安全"回答,丧失多样性

KL 散度 就像一根"弹力绳"——当 偏离 越远,惩罚越大,把策略拉回来。

训练过程中三个模型的时间线示意

时间 →
              RL 迭代 1        RL 迭代 2        RL 迭代 3
              ─────────       ─────────       ─────────
π_ref:     [SFT 模型] ════════════════════════════════════  (永远不变)

π_θ_old:   [SFT 模型]──→[θ₁]──→[θ₁]──→[θ₂]──→[θ₂]──→[θ₃]  (每轮迭代开始时同步)
                采样↓      更新↑    采样↓      更新↑    采样↓
π_θ:       [SFT 模型]→→→[θ₁]  [θ₁]→→→[θ₂]  [θ₂]→→→[θ₃]    (持续训练更新)
  • 第 1 行: 始终是 SFT 模型,从头到尾不变
  • 第 2 行: 在每轮迭代开始时,从 复制一份快照
  • 第 3 行: 是持续训练更新的 Policy 模型

📌 实现细节:Reference 模型需要单独占用显存。对于 7B 参数模型(bf16),Reference 模型约占 14GB 显存。为了节省显存,有些实现会使用 LoRA 适配器——此时 Reference 模型不需要单独加载,只需在推理时关闭 LoRA 适配器即可得到 的输出。

的作用与调节

效果适用场景
过小(如 0.001)KL 约束几乎无效,策略可大幅偏离探索性强,但容易奖励黑客、语言退化
适中(如 0.01–0.1)平衡探索与约束,推荐起始值大多数 RLHF 场景
过大(如 1.0)策略几乎无法偏离 SFT 模型RL 训练形同虚设

自适应 KL 控制:InstructGPT [3] 提出了一种动态调节 的方法——设定一个目标 KL 值 ,如果实际 超过目标值,则增大 (收紧约束);反之则减小 (放松约束):

其中 是调节步长(通常取 0.1–0.2)。这种自适应机制让训练更加稳健——不需要手动精调

Clip + KL 的协同作用

约束机制约束对象约束粒度直觉
Clip单个动作的概率比 局部(逐 token 级别)"每一步不能走太远"
KL整体输出分布 vs 全局(策略级别)"总体路线不能偏离太远"

两者互补:Clip 防止单步更新过大,KL 防止累积漂移过大。在实践中,两者通常同时使用。

1.7 PPO 完整训练流程

PPO 完整训练迭代流程

PPO 的训练需要同时维护以下模型:

PPO 训练架构

1.8 PPO 核心代码实现

下面用 PyTorch 代码完整实现 PPO 的核心组件,帮助从代码层面理解每个公式的含义。

1.8.1 GAE 优势估计

import torch
import torch.nn.functional as F

def compute_gae(
    rewards: torch.Tensor,       # [T] 每步即时奖励
    values: torch.Tensor,        # [T+1] Critic 估计的状态价值(含终止状态)
    gamma: float = 1.0,          # 折扣因子(语言模型任务通常取 1.0)
    lam: float = 0.95,           # GAE λ 参数
) -> torch.Tensor:
    """
    计算 GAE(广义优势估计)

    公式:A_t = Σ (γλ)^l · δ_{t+l}
    其中 δ_t = r_t + γ·V(s_{t+1}) - V(s_t)

    Args:
        rewards:  每步获得的即时奖励,形状 [T]
        values:   Critic 对每个状态的价值估计,形状 [T+1]
                  (最后一个是终止状态的价值,通常为 0)
        gamma:    折扣因子,控制未来奖励的衰减
        lam:      GAE λ 参数,控制偏差-方差权衡
                  λ=0 → 单步 TD(低方差高偏差)
                  λ=1 → 蒙特卡洛(高方差低偏差)

    Returns:
        advantages: GAE 优势估计,形状 [T]
    """
    T = len(rewards)
    advantages = torch.zeros(T)
    gae = 0.0  # 从最后一步向前累积

    for t in reversed(range(T)):
        # TD 误差:δ_t = r_t + γ·V(s_{t+1}) - V(s_t)
        # "实际发生的" - "Critic 预期的"
        delta = rewards[t] + gamma * values[t + 1] - values[t]

        # GAE 递推:A_t = δ_t + γλ·A_{t+1}
        # 等价于 A_t = Σ (γλ)^l · δ_{t+l}
        gae = delta + gamma * lam * gae
        advantages[t] = gae

    return advantages

1.8.2 PPO Clip 损失

def ppo_clip_loss(
    log_probs: torch.Tensor,         # [B, T] 当前策略的 log π_θ(a_t|s_t)
    old_log_probs: torch.Tensor,     # [B, T] 旧策略的 log π_θ_old(a_t|s_t)
    advantages: torch.Tensor,         # [B, T] GAE 优势估计
    clip_epsilon: float = 0.2,        # Clip 范围 ε
) -> tuple[torch.Tensor, dict]:
    """
    计算 PPO Clip 策略损失

    公式:L = -E[min(ρ_t·A_t, clip(ρ_t, 1-ε, 1+ε)·A_t)]

    Args:
        log_probs:     当前策略对每个 token 的对数概率 [batch, seq_len]
        old_log_probs: 旧策略对每个 token 的对数概率 [batch, seq_len]
        advantages:    每个 token 的优势值 [batch, seq_len]
        clip_epsilon:  裁剪范围,通常 0.1-0.3

    Returns:
        loss: 策略损失标量
        metrics: 监控指标
    """
    # ── 计算重要性采样比率 ────────────────────────────────────────
    # ρ_t = π_θ(a_t|s_t) / π_θ_old(a_t|s_t)
    # 在 log 空间做除法 = 做减法,再 exp 回去
    ratio = torch.exp(log_probs - old_log_probs)  # [B, T]

    # ── 未裁剪的目标 ──────────────────────────────────────────────
    # ρ_t · A_t
    surr1 = ratio * advantages  # [B, T]

    # ── 裁剪后的目标 ──────────────────────────────────────────────
    # clip(ρ_t, 1-ε, 1+ε) · A_t
    clipped_ratio = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon)
    surr2 = clipped_ratio * advantages  # [B, T]

    # ── PPO 损失 = -min(surr1, surr2) ────────────────────────────
    # 取 min 确保:
    #   A>0 时:不让 ρ 超过 1+ε(防止过度强化)
    #   A<0 时:不让 ρ 低于 1-ε(防止过度抑制)
    loss = -torch.min(surr1, surr2).mean()

    # ── 监控指标 ──────────────────────────────────────────────────
    with torch.no_grad():
        # 裁剪比例:被裁剪的 token 占比(健康范围 0.1-0.3)
        clip_fraction = ((ratio - 1.0).abs() > clip_epsilon).float().mean().item()
        # 近似 KL 散度(用于监控策略偏移程度)
        approx_kl = (old_log_probs - log_probs).mean().item()

    metrics = {
        "policy_loss": loss.item(),
        "clip_fraction": clip_fraction,       # > 0.5 说明更新步长太大
        "approx_kl": approx_kl,               # > 0.02 可能需要减小学习率
        "mean_ratio": ratio.mean().item(),     # 应接近 1.0
    }

    return loss, metrics

1.8.3 Critic(价值函数)损失

def critic_loss(
    values: torch.Tensor,        # [B, T] Critic 预测的 V(s_t)
    returns: torch.Tensor,       # [B, T] 实际回报 = advantages + old_values
) -> torch.Tensor:
    """
    计算 Critic 损失(均方误差)

    Critic 的目标:让 V_ϕ(s_t) 尽可能接近实际回报

    Args:
        values:  Critic 对每个状态的价值预测 [batch, seq_len]
        returns: 实际回报(GAE 优势 + 旧价值估计)[batch, seq_len]

    Returns:
        loss: Critic 损失标量
    """
    return F.mse_loss(values, returns)

1.8.4 KL 散度惩罚

def kl_penalty(
    log_probs: torch.Tensor,     # [B, T] 当前策略的 log π_θ
    ref_log_probs: torch.Tensor, # [B, T] 参考策略(SFT 模型)的 log π_ref
) -> torch.Tensor:
    """
    计算当前策略与参考策略之间的 KL 散度

    公式:D_KL(π_θ ‖ π_ref) = E[log(π_θ/π_ref)]

    Args:
        log_probs:     当前策略的对数概率
        ref_log_probs: 参考策略的对数概率(SFT 模型,冻结参数)

    Returns:
        kl: KL 散度标量
    """
    # KL = E[log π_θ - log π_ref]
    kl = (log_probs - ref_log_probs).mean()
    return kl

1.8.5 PPO 完整训练步骤

def ppo_training_step(
    policy_model,          # 策略模型 π_θ(可训练)
    critic_model,          # Critic 模型 V_ϕ(可训练)
    ref_model,             # 参考模型 π_ref(冻结,不训练)
    input_ids,             # 输入 token ids
    response_ids,          # 生成的回答 token ids
    old_log_probs,         # 旧策略的 log 概率
    rewards,               # 奖励信号
    old_values,            # 旧 Critic 的价值估计
    clip_epsilon=0.2,      # PPO Clip ε
    kl_coef=0.01,          # KL 惩罚系数 β
    critic_coef=0.5,       # Critic 损失权重
    gamma=1.0,             # 折扣因子
    lam=0.95,              # GAE λ
):
    """
    PPO 单步训练的完整流程

    总损失 = 策略损失 + β·KL惩罚 + c·Critic损失
    """
    # ── Step 1: 计算当前策略的 log 概率 ──────────────────────────
    # 前向传播,获取当前策略对每个 token 的对数概率
    logits = policy_model(input_ids=input_ids, response_ids=response_ids)
    log_probs = compute_token_log_probs(logits, response_ids)  # [B, T]

    # ── Step 2: 计算 Critic 的价值估计 ───────────────────────────
    values = critic_model(input_ids=input_ids, response_ids=response_ids)  # [B, T]

    # ── Step 3: 计算 GAE 优势 ────────────────────────────────────
    # 对 batch 中的每个样本分别计算 GAE
    advantages_list = []
    for b in range(rewards.shape[0]):
        adv = compute_gae(rewards[b], old_values[b], gamma=gamma, lam=lam)
        advantages_list.append(adv)
    advantages = torch.stack(advantages_list)  # [B, T]

    # 标准化优势(减小方差,稳定训练)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # 实际回报 = 优势 + 旧价值估计(用于训练 Critic)
    returns = advantages + old_values[:, :-1]

    # ── Step 4: 计算 PPO Clip 策略损失 ───────────────────────────
    policy_loss, policy_metrics = ppo_clip_loss(
        log_probs, old_log_probs, advantages, clip_epsilon
    )

    # ── Step 5: 计算 KL 散度惩罚 ────────────────────────────────
    with torch.no_grad():
        ref_logits = ref_model(input_ids=input_ids, response_ids=response_ids)
        ref_log_probs = compute_token_log_probs(ref_logits, response_ids)
    kl = kl_penalty(log_probs, ref_log_probs)

    # ── Step 6: 计算 Critic 损失 ────────────────────────────────
    c_loss = critic_loss(values[:, :-1], returns)

    # ── Step 7: 总损失 ──────────────────────────────────────────
    # 策略损失(让模型做对的事)+ KL 惩罚(不要忘了语言能力)+ Critic 损失(学会评估好坏)
    total_loss = policy_loss + kl_coef * kl + critic_coef * c_loss

    return total_loss, {
        **policy_metrics,
        "kl_divergence": kl.item(),
        "critic_loss": c_loss.item(),
        "total_loss": total_loss.item(),
    }

📌 与 GRPO 代码的关键差异

  • PPO 需要额外维护一个 Critic 模型critic_model),它与 Policy 模型同等规模
  • 优势估计使用 GAE 递推(依赖 Critic),而不是 GRPO 的组内标准化
  • 总损失包含三部分:策略损失 + KL 惩罚 + Critic 损失,而 GRPO 没有 Critic 损失
  • 这也解释了为什么 PPO 的显存需求是 ≈ 3× 模型大小(Policy + Critic + Reference)

1.9 PPO 的优缺点总结

维度评价
通用性几乎适用于所有 RL 场景,不限于语言模型
稳定性Clip 机制提供了可靠的训练稳定性保障
理论基础有成熟的理论支撑(信任域方法的简化)
显存需求需要 Critic 模型,显存占用 ≈ 3× 模型大小
训练复杂Critic 与 Policy 互相依赖,联合训练不稳定
超参数多GAE λ、Critic 学习率、clip ε、KL β 等需精心调节


PPO 是一个强大但资源消耗较大的算法——需要 Critic 模型、奖励模型和在线采样。下一节将介绍 DPO,它通过精妙的数学推导直接跳过了奖励模型和 Critic,将 RL 问题转化为简单的监督学习。


参考文献

[1] WILLIAMS R J. Simple statistical gradient-following algorithms for connectionist reinforcement learning[J]. Machine Learning, 1992, 8(3): 229-256.

[2] SCHULMAN J, WOLSKI F, DHARIWAL P, et al. Proximal policy optimization algorithms[R]. arXiv preprint arXiv:1707.06347, 2017.

[3] OUYANG L, WU J, JIANG X, et al. Training language models to follow instructions with human feedback[C]//Advances in Neural Information Processing Systems (NeurIPS). 2022.

[4] SCHULMAN J, MORITZ P, LEVINE S, et al. High-dimensional continuous control using generalized advantage estimation[C]//International Conference on Learning Representations (ICLR). 2016.