10.2 SFT + LoRA:监督微调与参数高效训练

监督微调的形式化定义

SFT(Supervised Fine-Tuning,监督微调) 是 Agentic-RL 训练的第一阶段,其目标是将通用基座模型 调整为具备特定 Agent 行为格式的初始策略

形式化地,SFT 的训练目标是在专家示范数据集 上最小化负对数似然损失:

逐项解读:

  • :对 条训练样本求平均,使损失不受数据集规模影响
  • :对第 条样本的输出序列中每个 token 求和——这是语言模型的自回归分解,即序列的联合概率等于每个 token 条件概率的乘积(对数域为求和)
  • :模型在已知输入 和前 个 token 的条件下,预测第 个 token 为 的对数概率。这个值越大(越接近 0),说明模型对这个 token 的预测越确信
  • 加负号后,最小化损失 = 最大化对数似然 = 让模型对专家示范的每个 token 都尽可能高的概率生成

自回归的直觉:训练时,模型看到的是“正确答案”(teacher forcing)——即使模型在第 步预测错了,第 步仍然使用真实的 作为条件。这使得 SFT 训练非常高效,但也带来了训练-推理分布偏移的问题(推理时模型看到的是自己生成的 token,而非专家示范)。

其中 是输入上下文(系统提示 + 用户问题), 是专家示范的完整输出序列(包含推理过程和工具调用)。

SFT 在 Agentic-RL 中的角色

SFT 阶段解决的核心问题是策略初始化:从随机初始化或通用预训练权重出发,直接进行 RL 训练会面临极高的探索成本——模型需要从海量随机输出中偶然发现有效的 Agent 行为格式。SFT 通过模仿学习将策略分布约束在合理的行为空间内,为后续 RL 优化提供稳定的起点。

SFT → RL 两阶段训练流程

SFT 训练数据格式规范

对于 Agent 任务,SFT 的训练数据是多轮对话 + 工具调用的完整轨迹,遵循 ChatML 格式:

{
  "conversations": [
    {
      "role": "system",
      "content": "你是一个数学助手,可以使用 calculator 工具进行精确计算。"
    },
    {
      "role": "user",
      "content": "一个圆形水池半径 3 米,水深 1.5 米,水池中有多少立方米的水?"
    },
    {
      "role": "assistant",
      "content": "<think>\n圆柱体体积公式:V = π × r² × h\n其中 r = 3 m,h = 1.5 m\n需要精确计算 π × 9 × 1.5,使用计算器工具。\n</think>\n\n<tool_call>\ncalculator(expression=\"3.14159265 * 3**2 * 1.5\")\n</tool_call>"
    },
    {
      "role": "tool",
      "content": "{\"result\": 42.4115}"
    },
    {
      "role": "assistant",
      "content": "水池中的水量约为 **42.41 立方米**。\n\n计算过程:V = π × r² × h = π × 3² × 1.5 ≈ 42.41 m³"
    }
  ]
}

数据质量优先原则

LIMA [1] 的研究提供了重要的实证依据:1,000 条精心筛选的高质量数据,其效果往往优于 10,000 条噪声数据。对于 Agent SFT,数据质量的评估维度如下:

质量维度标准验证方法
格式一致性统一的 <think>/<tool_call> 标签格式正则表达式自动检查
工具调用正确性参数类型、名称与工具定义完全匹配静态解析验证
推理连贯性<think> 内容与最终动作逻辑一致人工抽样审查
任务覆盖度覆盖所有工具的调用模式和边界情况工具调用分布统计
难度分布简单/中等/复杂任务比例均衡人工分级标注

推荐数据规模:500–2,000 条经过人工验证的高质量 Agent 交互轨迹。


全参数微调的资源困境

在理解 LoRA 之前,需要先明确全参数微调(Full Fine-Tuning)面临的根本性挑战。

以 Llama 3.1 8B 为例,训练时的显存需求分析如下:

# 显存需求精确估算(以 Llama 3.1 8B 为例)
total_params = 8_000_000_000        # 80 亿参数

# 推理阶段(float16)
inference_memory = total_params * 2 / (1024**3)   # ≈ 14.9 GB

# 训练阶段(Adam 优化器,混合精度)
# 参数(float32):4 bytes/param
# 梯度(float32):4 bytes/param
# Adam 一阶矩(float32):4 bytes/param
# Adam 二阶矩(float32):4 bytes/param
# 合计:16 bytes/param
training_memory = total_params * 16 / (1024**3)   # ≈ 119.2 GB

# 结论:全参数微调需要至少 2-3 张 A100 80GB
# 对大多数团队而言,这一成本难以承受

这一资源困境催生了参数高效微调(Parameter-Efficient Fine-Tuning, PEFT) 方法的研究,其中 LoRA 是目前最广泛应用的方案。


LoRA:低秩适应的理论基础

核心假设与数学推导

LoRA(Low-Rank Adaptation) [2] 的理论基础来自一个关键的实证发现:

预训练模型在微调过程中,权重更新矩阵 具有显著的低秩特性(intrinsic low rank)。

为什么微调的更新是低秩的? 直觉上,预训练模型已经学会了丰富的通用表示,微调只需要在这个表示空间中做小幅度的“方向调整”。这种调整展开在一个低维子空间中,而非需要改变所有 个方向。Aghajanyan et al. [4] 通过实验证明,微调的“内在维度”(intrinsic dimensionality)远小于模型参数量。

这意味着微调时真正需要改变的“信息量”远小于模型参数量。基于此,LoRA 提出以下参数化方案:

对于原始权重矩阵 ,不直接更新 ,而是将权重更新分解为两个低秩矩阵的乘积:

逐项解读:

  • 冻结的预训练权重,训练期间完全不变,保留模型的通用知识
  • 下投影矩阵,将 维输入压缩到 维的低秩空间;使用高斯分布初始化,保证训练开始时有非零梯度
  • 上投影矩阵,将 维低秩表示映射回 维输出空间;初始化为全零,确保训练开始时 ,即模型行为与基座模型完全相同,训练稳定性有保障
  • ,控制参数量(通常取 8, 16, 32, 64)。 越小,参数量越少,但表达能力越弱; 越大,表达能力越强,但参数量增加
  • 缩放因子,控制 LoRA 更新的幅度。设置为 而非直接用 ,是为了使实际缩放幅度与 的选择解耦合:当 时,无论 取何值,实际缩放因子始终为 2,方便跨不同 值的实验对比

前向传播

训练时冻结 ,仅更新

参数效率分析

LoRA 的参数量压缩比可以精确计算:

时,压缩比 ,即仅需训练不到 1% 的参数。

# 参数量对比(以 Llama 3.1 8B 的单个注意力投影层为例)
d, k = 4096, 4096

# 原始层参数量
original_params = d * k                    # = 16,777,216 (16M)

# LoRA 参数量(r=16)
r = 16
lora_params = r * k + d * r               # A: 65,536 + B: 65,536 = 131,072 (128K)

# 单层压缩比
compression_ratio = lora_params / original_params   # ≈ 0.78%

# 全模型 LoRA 参数量(应用于所有注意力层)
# Llama 3.1 8B: 32 层,每层 4 个投影(q, k, v, o)
num_lora_layers = 32 * 4
total_lora_params = num_lora_layers * lora_params   # ≈ 16.8M

# 全模型压缩比
total_compression = total_lora_params / total_params  # ≈ 0.21%
# 仅需训练不到 0.3% 的参数!

关键超参数的选择指南

超参数含义推荐范围选择依据
(秩)低秩矩阵的秩,控制表达能力8–64任务复杂度:简单格式学习用 8–16,复杂推理用 32–64
(缩放)LoRA 更新的缩放因子通常 实际缩放为 ,设为 使有效缩放为 2
target_modules应用 LoRA 的层至少 q_proj, v_proj注意力层效果最显著;加入 FFN 层可提升表达能力
lora_dropoutLoRA 层的 Dropout 率0.05–0.1数据量少时适当增大以防过拟合

💡 秩 的选择经验法则

  • 格式学习、风格迁移(任务简单):
  • Agent 工具调用学习(任务中等):
  • 数学推理、代码生成(任务复杂):
  • 不确定时:从 开始,观察验证集损失是否充分收敛

实战:基于 LoRA 的 Agent SFT 训练

环境依赖

pip install torch>=2.1.0 transformers>=4.40.0 peft>=0.10.0
pip install trl>=0.12.0 datasets accelerate bitsandbytes

步骤一:构建训练数据集

"""
Agent SFT 训练数据构建
将 Agent 交互轨迹转换为 ChatML 格式的训练样本
"""

import json
from datasets import Dataset

def build_agent_sft_dataset() -> Dataset:
    """
    构建 Agent SFT 训练数据集
    
    每条样本包含完整的多轮对话轨迹:
    系统提示 → 用户问题 → 助手推理+工具调用 → 工具结果 → 最终回答
    """
    examples = [
        {
            "instruction": "帮我查看北京今天的天气,然后告诉我应该穿什么衣服。",
            "output": (
                "<think>\n"
                "用户需要两步操作:① 获取北京实时天气数据;② 根据温度给出穿衣建议。\n"
                "先调用天气工具获取数据,再基于结果给出建议。\n"
                "</think>\n\n"
                "<tool_call>\n"
                "get_weather(city=\"北京\")\n"
                "</tool_call>"
            )
        },
        {
            "instruction": "计算从上海到北京的高铁票价,每人 553 元,3 个人来回的总费用。",
            "output": (
                "<think>\n"
                "计算公式:总费用 = 单价 × 人数 × 2(来回)= 553 × 3 × 2\n"
                "使用计算器确保精确性。\n"
                "</think>\n\n"
                "<tool_call>\n"
                "calculator(expression=\"553 * 3 * 2\")\n"
                "\</tool_call\>"
            )
        },
        # 实际训练需要 500–2,000 条覆盖各类工具调用场景的数据
    ]
    
    return Dataset.from_list(examples)


def format_to_chatml(example: dict) -> dict:
    """将单条样本格式化为 ChatML 训练格式"""
    text = (
        "<|im_start|>system\n"
        "你是一个智能助手,可以使用工具完成任务。"
        "解题时请先在 <think> 标签中写出推理过程,需要工具时使用 <tool_call> 标签。\n"
        "<|im_end|>\n"
        f"<|im_start|>user\n{example['instruction']}\n<|im_end|>\n"
        f"<|im_start|>assistant\n{example['output']}\n<|im_end|>"
    )
    return {"text": text}


dataset = build_agent_sft_dataset()
train_dataset = dataset.map(format_to_chatml)

步骤二:模型加载与 LoRA 配置

"""
模型加载与 LoRA 配置
采用 QLoRA(4-bit 量化 + LoRA)进一步降低显存需求
"""

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, TaskType

# ── 模型选择 ──────────────────────────────────────────────────────────────
model_name = "Qwen/Qwen2.5-7B-Instruct"

# ── QLoRA:4-bit 量化配置 ─────────────────────────────────────────────────
# QLoRA [3] 的核心思想:将模型权重量化为 4-bit 存储,但计算时反量化为 bfloat16
# 
# NF4(NormalFloat4)量化原理:
#   预训练权重近似服从正态分布 N(0, σ²)
#   NF4 将正态分布的值域划分为 16 个等概率区间(而非等间距区间)
#   这使得量化误差在统计意义上最小(信息论最优)
#   相比均匀量化(INT4),NF4 在相同 bit 数下精度损失更小
#
# 双重量化(Double Quantization):
#   量化过程本身需要存储量化常数(scale factor),每 64 个参数共享一个 float32 常数
#   双重量化将这些量化常数再次量化为 8-bit,额外节省约 0.4 bits/param
#   对 7B 模型,双重量化额外节省约 0.4 × 7B / 8 ≈ 350 MB 显存
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",           # NormalFloat4:信息论最优的 4-bit 量化格式
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,       # 对量化常数再次量化,额外节省 ~0.4 bits/param
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# ── LoRA 配置 ─────────────────────────────────────────────────────────────
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=32,                                # Agent 任务建议 r=32,平衡表达能力与参数效率
    lora_alpha=64,                       # alpha = 2r,有效缩放因子为 2
    lora_dropout=0.05,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",   # 注意力投影层(必选)
        "gate_proj", "up_proj", "down_proj",        # FFN 层(可选,提升表达能力)
    ],
    bias="none",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# 示例输出:trainable params: 83,886,080 || all params: 7,615,684,608 || trainable%: 1.10%

步骤三:训练配置与执行

"""
SFT 训练执行
关键超参数的选择依据均在注释中说明
"""

from transformers import TrainingArguments
from trl import SFTTrainer

training_args = TrainingArguments(
    output_dir="./checkpoints/sft",

    # ── 训练超参数 ──────────────────────────────────────────────────────────
    num_train_epochs=3,                  # Agent 格式学习通常 2–3 个 epoch 即可收敛
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,       # 有效 batch size = 4 × 4 = 16
    learning_rate=2e-4,                  # LoRA 推荐学习率范围:1e-4 ~ 5e-4
    warmup_ratio=0.1,                    # 前 10% 步数线性 warmup,防止训练初期不稳定
    weight_decay=0.01,                   # L2 正则化,防止过拟合

    # ── 精度与性能优化 ──────────────────────────────────────────────────────
    bf16=True,                           # bfloat16 混合精度:比 fp16 数值更稳定
    gradient_checkpointing=True,         # 以重计算换显存:显存减少 ~60%,速度降低 ~20%

    # ── 学习率调度 ──────────────────────────────────────────────────────────
    lr_scheduler_type="cosine",          # 余弦退火:比线性衰减通常效果更好

    # ── 日志与检查点 ────────────────────────────────────────────────────────
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,
    load_best_model_at_end=True,         # 训练结束后自动加载验证集最优检查点

    report_to="tensorboard",
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    max_seq_length=2048,
    dataset_text_field="text",
)

print("🚀 开始 SFT 训练...")
trainer.train()

# 保存 LoRA 适配器权重(仅保存增量参数,约 100–300 MB)
model.save_pretrained("./checkpoints/sft-lora")
tokenizer.save_pretrained("./checkpoints/sft-lora")
print("✅ LoRA 权重已保存至 ./checkpoints/sft-lora")

步骤四:权重合并与推理验证

"""
将 LoRA 适配器权重合并回基座模型
合并后的模型与原始模型结构完全相同,可直接用于推理部署
"""

from peft import PeftModel

# 加载基座模型(全精度,用于合并)
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# 加载 LoRA 适配器并合并
# merge_and_unload() 将 ΔW = BA 合并到 W₀ 中,恢复标准模型结构
model = PeftModel.from_pretrained(base_model, "./checkpoints/sft-lora")
model = model.merge_and_unload()

# 推理验证
prompt = (
    "<|im_start|>system\n你是一个数学助手,可以使用 calculator 工具进行精确计算。\n<|im_end|>\n"
    "<|im_start|>user\n计算 17 的平方根,精确到小数点后 4 位\n<|im_end|>\n"
    "<|im_start|>assistant\n"
)

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=256, temperature=0.1, do_sample=True)
response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)

训练过程中的常见问题与诊断

问题一:训练损失不下降

# 系统性诊断检查清单
diagnostics = {
    "学习率设置": {
        "症状": "损失从第一步就几乎不变",
        "原因": "学习率过小(< 1e-5)或过大(> 1e-3)",
        "解决": "尝试 2e-4,使用学习率查找器(LR Finder)确定最优值",
    },
    "数据格式": {
        "症状": "损失下降但模型输出格式混乱",
        "原因": "tokenizer 的 special tokens 未正确设置",
        "解决": "检查 pad_token、eos_token 是否正确,验证 ChatML 模板",
    },
    "LoRA 目标层": {
        "症状": "损失下降极慢",
        "原因": "target_modules 与模型架构不匹配",
        "解决": "打印 model.named_modules() 确认层名称",
    },
    "梯度消失": {
        "症状": "损失先下降后停滞",
        "原因": "gradient_checkpointing 与某些层不兼容",
        "解决": "临时关闭 gradient_checkpointing 验证",
    },
}

问题二:过拟合

# 过拟合信号:训练损失持续下降,验证损失在某点后开始上升
solutions = {
    "增加正则化": "lora_dropout 从 0.05 提升至 0.1",
    "减少训练轮数": "使用 early stopping,通常 2–3 个 epoch 足够",
    "增加数据多样性": "用 GPT-4 改写现有数据,增加表达多样性",
    "降低秩 r": "r 过大容易过拟合,从 32 降至 16 尝试",
}

问题三:显存不足(OOM)

# 显存优化策略(按效果从强到弱排列)
memory_optimizations = [
    ("4-bit 量化 QLoRA",        "显存减少 ~75%,精度损失极小"),
    ("gradient_checkpointing",  "显存减少 ~60%,速度降低 ~20%"),
    ("降低 max_seq_length",     "显存与序列长度平方成正比"),
    ("减小 batch_size",         "最基本的方法,配合 gradient_accumulation"),
    ("降低 LoRA r 值",          "r 减半,LoRA 参数量减半"),
    ("CPU offload",             "将优化器状态卸载到 CPU,速度大幅降低"),
]

📌 工程实践要点

  • 数据准备时间 >> 训练时间:实际项目中,80% 的时间花在数据收集、清洗和验证上
  • 数据可执行性验证:每条训练数据的工具调用格式应通过静态解析验证,确保参数合法
  • A/B 测试:SFT 模型上线前,务必与 prompt-only 基线进行对照实验
  • 显存估算公式:QLoRA 4-bit 下,7B 模型训练约需 10–12 GB 显存(含梯度和优化器状态)
  • 版本管理:使用 wandb 或 TensorBoard 记录每次训练的超参数、损失曲线和评估指标

SFT 毕业标准:何时从 SFT 转入 RL?

10.1 节 中我们了解到 Agentic-RL 遵循 SFT → RL 的两阶段范式。但一个关键的实操问题是:SFT 训练到什么程度,才可以"毕业",进入 RL 阶段?

这不是一个可以靠直觉回答的问题——过早进入 RL,模型连基本格式都不稳定,RL 阶段的探索效率极低;过晚进入 RL,模型过拟合 SFT 数据,多样性丧失,RL 同样无法有效探索。两种情况都会导致最终性能不佳。

SFT 毕业标准决策流程

核心原则:SFT 的目标不是"做到最好",而是"做到足够好"

这是最重要的直觉。SFT 在整个训练流程中的定位是策略初始化——它不需要(也不应该)追求最高准确率。原因很简单:

SFT 的能力上界 = 训练数据的质量上界。超越这个上界是 RL 的工作。

如果 SFT 阶段训练过度(太多 epoch、学习率不够小),模型会紧紧"贴合"训练数据的分布,导致:

  1. 多样性坍缩:模型对所有类似问题都生成几乎相同的回答
  2. 探索空间被压缩:RL 阶段需要模型在 temperature > 0 时产生有差异的多个候选回答(这是 GRPO 组内比较的前提),但过拟合的模型输出方差极小
  3. KL 惩罚的锚点过紧:RL 训练中的 (即 SFT 模型)如果本身过于确信,策略稍有偏离就会产生巨大的 KL 惩罚,限制了 RL 的优化空间

四大必检指标

实践中,我们通过以下四个指标来判断 SFT 是否"毕业":

指标一:训练/验证 Loss 收敛

这是最基础的指标。观察 TensorBoard 中的 loss 曲线:

✅ 毕业信号:验证 Loss 连续 2–3 个 evaluation step 不再下降(或下降幅度 < 0.01)
❌ 警告信号:训练 Loss 持续下降但验证 Loss 开始上升 → 过拟合,应立即停止

典型数值参考(因任务和数据而异,仅供参考):

任务类型SFT Loss 企稳区间说明
数学推理(GSM8K)0.3 – 0.6格式简单,收敛较快
工具调用(多工具)0.4 – 0.8JSON 格式增加了序列复杂度
代码生成0.5 – 1.0代码的 token 分布更分散
多轮对话 Agent0.4 – 0.9轨迹长度差异大

⚠️ 绝对数值没有统一标准。Loss 的具体值取决于 tokenizer、数据格式和序列长度。关键是观察趋势(是否企稳)和训练/验证的差距(是否过拟合),而非追求某个绝对数值。

# 判断 Loss 是否收敛的简单方法
def is_loss_converged(eval_losses: list[float], patience: int = 3, min_delta: float = 0.01) -> bool:
    """
    检查最近 patience 次评估的 loss 是否不再显著下降。
    
    Args:
        eval_losses: 历次验证集 loss 列表(按时间顺序)
        patience: 观察窗口大小
        min_delta: 最小改善阈值
    
    Returns:
        True 表示已收敛,可以考虑进入 RL
    """
    if len(eval_losses) < patience + 1:
        return False
    
    recent = eval_losses[-patience:]
    baseline = eval_losses[-(patience + 1)]
    
    # 最近几次评估都没有显著改善
    return all(baseline - loss < min_delta for loss in recent)


def is_overfitting(train_loss: float, val_loss: float, threshold: float = 0.15) -> bool:
    """
    检查是否出现过拟合(Train-Val Loss 差距过大)。
    
    Args:
        train_loss: 当前训练 loss
        val_loss: 当前验证 loss
        threshold: 允许的最大差距
    """
    return (val_loss - train_loss) > threshold

指标二:格式正确率 ≥ 90%

SFT 最核心的作用是让模型学会 Agent 的行为格式。用一个简单的验证脚本检查:

import re
import json

def check_format_accuracy(model, tokenizer, eval_prompts: list[str], num_samples: int = 100) -> dict:
    """
    评估 SFT 模型的格式正确率。
    
    Returns:
        包含各维度格式正确率的字典
    """
    results = {"think_tag": 0, "tool_call_format": 0, "json_parseable": 0, "total": num_samples}
    
    for prompt in eval_prompts[:num_samples]:
        output = generate(model, tokenizer, prompt, temperature=0.1)
        
        # 检查 <think> 标签配对
        if "<think>" in output and "</think>" in output:
            results["think_tag"] += 1
        
        # 检查 <tool_call> 格式
        tool_match = re.search(r"<tool_call>\s*(\w+)\((.*?)\)\s*</tool_call>", output, re.DOTALL)
        if tool_match:
            results["tool_call_format"] += 1
            # 检查参数是否可解析
            try:
                # 尝试解析为合法的函数调用参数
                params_str = tool_match.group(2)
                # 简单验证:参数格式合法性
                if "=" in params_str:  # key=value 格式
                    results["json_parseable"] += 1
            except Exception:
                pass
    
    # 计算综合格式正确率
    results["format_accuracy"] = results["think_tag"] / num_samples
    results["tool_call_accuracy"] = results["tool_call_format"] / num_samples
    
    return results

# 使用示例
# format_stats = check_format_accuracy(model, tokenizer, eval_prompts)
# print(f"格式正确率: {format_stats['format_accuracy']:.1%}")
# print(f"工具调用率: {format_stats['tool_call_accuracy']:.1%}")
# 
# 毕业标准:format_accuracy >= 0.9 且 tool_call_accuracy >= 0.85

为什么 90% 而非 100%? 因为 RL 阶段的奖励函数会进一步强化正确格式,少量格式错误在 RL 中会自然被"惩罚"掉。但如果格式正确率低于 90%,意味着模型尚未稳定掌握 Agent 行为格式,RL 阶段会浪费大量探索预算在格式修正上,而非策略优化。

指标三:任务准确率超过随机水平

SFT 模型不需要很高的准确率,但需要有意义地超过基座模型的 zero-shot 表现

def evaluate_task_accuracy(
    model, tokenizer, eval_dataset, 
    answer_extractor,     # 从模型输出中提取答案的函数
    answer_checker,       # 检查答案正确性的函数
) -> float:
    """
    评估 SFT 模型在目标任务上的准确率。
    """
    correct = 0
    total = len(eval_dataset)
    
    for sample in eval_dataset:
        output = generate(model, tokenizer, sample["prompt"], temperature=0.1)
        predicted = answer_extractor(output)
        if answer_checker(predicted, sample["answer"]):
            correct += 1
    
    return correct / total

# 毕业标准参考(以 GSM8K 为例)
# - 基座模型 zero-shot: ~15-20%(1.5B 模型)
# - SFT 毕业线: >= 30%(显著超过基座模型)
# - RL 目标: 50-70%(在 SFT 基础上进一步提升)
#
# 关键判断:SFT 准确率不需要很高,但必须 > 基座模型
# 这表明模型已学会任务的基本解题模式
模型规模基座 zero-shotSFT 毕业线RL 目标
1.5B10–20%≥ 25–30%45–65%
7B25–40%≥ 40–50%60–80%
14B+40–55%≥ 55–65%70–85%

这些数值以 GSM8K 数学推理为例,其他任务需根据实际情况调整。核心标准是:SFT 后准确率应当显著高于基座模型 zero-shot

指标四:输出多样性仍然存在

这是最容易被忽略但对 RL 成功至关重要的指标。GRPO 的核心机制是对同一输入采样 个回答,通过组内比较学习哪些策略更好。如果 SFT 过拟合导致模型输出几乎完全相同(多样性坍缩),GRPO 的组内标准化将失效。

def check_output_diversity(
    model, tokenizer, 
    test_prompts: list[str], 
    num_samples_per_prompt: int = 8, 
    temperature: float = 0.7,
) -> dict:
    """
    检测模型输出的多样性。
    
    核心思路:对同一 prompt 多次采样,计算输出之间的差异程度。
    """
    diversity_scores = []
    
    for prompt in test_prompts[:20]:  # 抽样 20 个 prompt
        outputs = []
        for _ in range(num_samples_per_prompt):
            output = generate(model, tokenizer, prompt, temperature=temperature)
            outputs.append(output)
        
        # 方法 1:计算不同输出的数量占比
        unique_outputs = len(set(outputs))
        diversity_ratio = unique_outputs / num_samples_per_prompt
        
        # 方法 2:计算平均 token 级别差异(更细粒度)
        # 这里简化为字符串不同比例
        pairs_different = sum(
            1 for i in range(len(outputs)) 
            for j in range(i + 1, len(outputs)) 
            if outputs[i] != outputs[j]
        )
        total_pairs = num_samples_per_prompt * (num_samples_per_prompt - 1) / 2
        pair_diversity = pairs_different / total_pairs if total_pairs > 0 else 0
        
        diversity_scores.append({
            "unique_ratio": diversity_ratio,
            "pair_diversity": pair_diversity,
        })
    
    avg_unique = sum(d["unique_ratio"] for d in diversity_scores) / len(diversity_scores)
    avg_pair = sum(d["pair_diversity"] for d in diversity_scores) / len(diversity_scores)
    
    return {
        "avg_unique_ratio": avg_unique,    # 期望 >= 0.5(8 次采样至少 4 种不同输出)
        "avg_pair_diversity": avg_pair,    # 期望 >= 0.6
        "is_healthy": avg_unique >= 0.5 and avg_pair >= 0.6,
    }

# 毕业标准:
# - avg_unique_ratio >= 0.5(至少一半的采样结果是不同的)
# - avg_pair_diversity >= 0.6(大多数输出对之间存在差异)
# 
# 如果多样性过低,说明 SFT 过拟合,需要:
# 1. 减少训练 epoch(从 3 降到 2)
# 2. 增大 lora_dropout
# 3. 使用 early stopping

完整毕业检查清单

将以上四个指标整合为一个实用的检查函数:

def sft_graduation_check(
    model, tokenizer, eval_dataset, eval_losses: list[float],
    train_loss: float, val_loss: float,
) -> dict:
    """
    SFT 毕业检查:综合评估模型是否准备好进入 RL 阶段。
    
    Returns:
        包含各项检查结果和最终毕业建议的字典
    """
    report = {}
    
    # 检查 1:Loss 收敛
    report["loss_converged"] = is_loss_converged(eval_losses, patience=3)
    report["not_overfitting"] = not is_overfitting(train_loss, val_loss)
    
    # 检查 2:格式正确率
    format_stats = check_format_accuracy(model, tokenizer, eval_dataset)
    report["format_ok"] = format_stats["format_accuracy"] >= 0.9
    
    # 检查 3:任务准确率(需要自定义 extractor 和 checker)
    # accuracy = evaluate_task_accuracy(model, tokenizer, eval_dataset, ...)
    # report["accuracy_ok"] = accuracy > baseline_accuracy
    
    # 检查 4:多样性
    diversity = check_output_diversity(model, tokenizer, eval_dataset[:20])
    report["diversity_ok"] = diversity["is_healthy"]
    
    # 综合判断
    all_checks = [report["loss_converged"], report["not_overfitting"], 
                  report["format_ok"], report["diversity_ok"]]
    report["ready_for_rl"] = all(all_checks)
    
    # 生成建议
    if report["ready_for_rl"]:
        report["recommendation"] = "✅ SFT 毕业!保存检查点作为 π_ref,可以开始 RL 训练。"
    else:
        issues = []
        if not report["loss_converged"]:
            issues.append("Loss 尚未收敛,继续训练或调整学习率")
        if not report["not_overfitting"]:
            issues.append("出现过拟合,减少 epoch 或增加正则化")
        if not report["format_ok"]:
            issues.append("格式正确率不足 90%,检查数据格式和 special tokens")
        if not report["diversity_ok"]:
            issues.append("输出多样性不足,减少训练步数或增大 dropout")
        report["recommendation"] = "❌ 暂不建议进入 RL。需解决:\n" + "\n".join(f"  - {i}" for i in issues)
    
    return report

SFT 过度训练的真实代价

为了让读者对"SFT 过度"有更直觉的感受,我们来看一组对比实验数据(基于 Qwen2.5-1.5B + GSM8K 的典型表现):

SFT 训练 EpochSFT 准确率多样性指标GRPO 后准确率说明
1 epoch25%0.8552%SFT 不足,格式不稳定,RL 需额外探索格式
2 epoch38%0.7263%最佳平衡点:格式稳定 + 足够多样性
3 epoch42%0.5858%多样性下降,RL 探索受限
5 epoch44%0.3149%严重过拟合,RL 几乎无法改善
10 epoch45%0.1243%灾难性过拟合,RL 后性能反降

⚠️ 以上数据为说明性示例,具体数值因模型、数据和超参数而异。但趋势是一致的:SFT 训练过度会严重损害 RL 阶段的效果

关键发现:

  • SFT 2 epoch 的模型虽然准确率低于 5 epoch(38% vs 44%),但 GRPO 后反而更高(63% vs 49%)——因为 2 epoch 的模型保留了足够的探索空间
  • SFT 10 epoch 后,GRPO 不仅没有提升,反而让准确率下降了(45% → 43%)——过拟合的 成为了过紧的约束

实践建议总结

建议详情
Epoch 数2–3 个 epoch 是安全区间,超过 3 个需要格外警惕过拟合
Early Stopping以验证 Loss 为基准,patience=3 eval steps
保存多个检查点保存 epoch 1/2/3 的检查点,RL 阶段可以回退到多样性更好的版本
先做 RL 预实验用 SFT 检查点做小规模 GRPO(100–200 步),观察 mean_reward 是否上升
不要追求 SFT SOTASFT 的最高准确率 ≠ RL 后的最高准确率

📌 核心记忆点

SFT 就像驾校里的科目一和科目二——它教你基本的交通规则和操作规范。你不需要在驾校里成为赛车手,只需要掌握基本功,然后到真实道路上(RL 阶段)通过实际驾驶经验来提升驾驶水平。在驾校练太久反而会形成"应试驾驶"的僵化模式,上了路不知道如何灵活应变。


SFT 阶段让模型习得了 Agent 行为的基本格式与工具调用模式。然而,SFT 的能力上界受限于训练数据的质量——模型无法超越示范数据的水平。下一节介绍的 GRPO 算法将通过强化学习信号,引导模型探索出超越示范数据的最优策略。


参考文献

[1] ZHOU C, LIU P, XU P, et al. LIMA: Less is more for alignment[C]//Advances in Neural Information Processing Systems (NeurIPS). 2023.

[2] HU E J, SHEN Y, WALLIS P, et al. LoRA: Low-rank adaptation of large language models[C]//International Conference on Learning Representations (ICLR). 2022.

[3] DETTMERS T, PAGNONI A, HOLTZMAN A, et al. QLoRA: Efficient finetuning of quantized language models[C]//Advances in Neural Information Processing Systems (NeurIPS). 2023.