跳转至

Lecture 17: Alignment with RL II

约 4494 个字 166 行代码 预计阅读时间 17 分钟

Outline

上一讲概览了 RLVR 以及 PPO、GRPO 的宏观概念,这一讲深入到策略梯度/Policy Gradient 的底层数学机制与代码级实现。从语言模型 RL 的基本设定出发,推导朴素策略梯度及其方差问题,引入基线/Baseline 和优势函数/Advantage Function 进行方差缩减,最后通过 GRPO 的代码级拆解展示从理论到工程的落地。

1. RL Settings and Basics in Language Models

1.1 RL Settings in Language Models

要在语言模型中使用强化学习,我们必须对 RL 的核心要素进行精确定义:

  • 状态/State \(s\):由输入的 Prompt 加上目前为止已经生成的回复内容组成。
  • 动作/Action \(a\):生成的下一个 Token。由于我们讨论的是 Outcome Reward(只在整个回复结束后一次性给分),为了符号简洁,后续推导中用 \(a\) 指代一整个完整生成的序列。
  • 奖励/Reward \(R\):一个标量,衡量回复的质量。这一讲聚焦于结果奖励/Outcome Rewards(只看最终结果)与可验证奖励/Verifiable Rewards(由确定性规则/代码给分,无需人类标注)。所以传统强化学习里面的折扣和自举/Bootstrapping 在这里不太适用。
  • 转移动态/Transition Dynamics \(T(s' \mid s, a)\):下一个状态等于当前状态拼接上刚生成的 Token,即 \(s' = s + a\),这里面的 \(a\) 指的是自回归生成的下一个 Token。这样简单且确定的转移动态就允许我们进行规划/Planning 或者 Test-time Compute。
  • 策略/Policy \(\pi(a \mid s)\):就是语言模型本身。
  • 轨迹/Rollout/Episode/Trajectory\(s \to a \to \cdots \to a \to R\),即从 Prompt 出发,逐 Token 生成直到结束,最后获得奖励。
  • 目标/Objective:最大化期望奖励 \(\mathbb{E}[R]\),期望在 Prompt 分布和模型生成的回复上取。

LLM 上的 RL 和传统 RL 或者面向控制的 RL 有一个巨大差异:在机器人学中,拥有一个完美的世界转移模型是奢求,很多状态受物理定律限制不可达。但在语言模型中,状态是完全 Made Up 的,意思是模型可以通过随便输出几个 Token 就达到任何状态。这赋予了 LM 极大的自由度,它甚至可以自己生成一个草稿本/Scratchpad 来进行推理时计算(Test-time Compute,如 OpenAI o1)。挑战不在于“能否达到某个状态”,而在于“如何确保模型写下的字符最终导向正确答案”。

1.2 Naive Policy Gradient

寻找最优策略 \(\pi\) 的目标是最大化期望奖励:

\[ \mathbb{E}[R] = \int p(s) \pi_\theta(a \mid s) R(s, a) \, \mathrm{d}s \, \mathrm{d}a \]

对策略参数 \(\theta\) 求梯度,利用经典的对数求导技巧(\(\nabla \pi = \pi \nabla \log \pi\)),可以得到策略梯度定理/Policy Gradient Theorem

\[ \begin{aligned} \nabla_\theta \mathbb E[R] &= \nabla_\theta \int p(s) \pi_\theta(a \mid s) R(s, a) \, \mathrm{d}s \, \mathrm{d}a \\ &= \int p(s) \nabla_\theta \pi_\theta(a \mid s) R(s, a) \, \mathrm{d}s \, \mathrm{d}a \\ &= \int p(s) \pi_\theta(a \mid s) \nabla_\theta \log \pi_\theta(a \mid s) R(s, a) \, \mathrm{d}s \, \mathrm{d}a \\ &= \mathbb E_{s, a} \left[ \nabla_\theta \log \pi_\theta(a \mid s) \cdot R(s, a) \right] \end{aligned} \]

朴素策略梯度/Naive Policy Gradient 的操作流程符合直觉:采样 Prompt \(s\),从当前策略采样回复 \(a \sim \pi(a \mid s)\),然后基于 \(\nabla \log \pi(a \mid s) \cdot R(s, a)\) 更新参数。这和 SFT 极其相似——SFT 最大化人类给定的正确回复的概率,而朴素策略梯度则是自己采样生成 \(a\),用奖励 \(R\) 加权做梯度更新。

\(R(s, a) \in \{0, 1\}\) 的设定下(回答正确得 1 分,错误得 0 分),朴素策略梯度只在正确回复上更新参数,\(R = 1\) 时正常做一次梯度更新,\(R = 0\) 时梯度为零,直接忽略。这看起来像 SFT,但有一个关键区别:数据集是随时间变化的。在 SFT 中数据是固定的,而在 RL 中目标函数依赖于当前策略 \(\pi\) 的采样。当你更新了参数后,下一轮采样时模型给出的回答已经变了,相当于面对的数据集已经完全不同。只要模型能答对哪怕一点简单题,策略就会变好,下一轮就能生成更多高回报的数据。

想象让模型解一道极难的数学题,奖励极其稀疏。如果初始模型很差,生成的回答全是错的,此时 \(R = 0\),代入公式会发现梯度为零,模型完全无法进行任何参数更新。这种 得不到奖励就不知道怎么更新 的问题,是 RL 区别于监督学习最大的痛点。

另一个问题是高方差:即使有些回复能获得非零奖励,不同回复的奖励值差异巨大,导致梯度估计的方差极高,收敛缓慢。作为对比,在 RLHF 中,Reward Model 从成对偏好中学到的奖励信号更连续,缓解了稀疏性。

一个自然的想法是:如果奖励为 0 时没有梯度,那把错误回复的奖励设为 \(-1\) 岂不是能产生 向错误方向反向推开 的梯度?这个直觉是对的,但粗暴地手动设置负值是一种 hack。接下来的基线方法会在数学上自动且完美地完成这个任务。

1.3 Baseline & Advantage Function

为了解决朴素策略梯度高方差、容易卡死的问题,我们引入基线 \(b(s)\),把优化目标修改为最大化带基线的奖励:

\[ \nabla_\theta \mathbb E[R] = \mathbb E \left[ \nabla_\theta \log \pi_\theta(a \mid s) \cdot (R(s, a) - b(s)) \right] \]

只要 \(b(s)\) 仅依赖于状态 \(s\) 而不依赖于具体的动作 \(a\),这种减法在数学期望上是无偏的,因为 \(\mathbb E_a [b(s) \nabla \log \pi(a \mid s)] = b(s) \nabla \sum_a \pi(a \mid s) = b(s) \nabla 1 = 0\)。这就是说 \(\mathbb E[R]\) 只是被一个不依赖于策略 \(\pi\) 的常数 \(\mathbb E[b(s)]\) 平移了,梯度的期望不变,但方差可以大幅降低,可以通过数值方法和数学证明进行证明。

那么 \(b(s)\) 应该怎么设?理论上最优基线的公式很复杂(对于单参数模型是 \(b^*(s) = \mathbb E[(\nabla \pi(a \mid s))^2 R] / \mathbb E[(\nabla \pi(a \mid s))^2]\)),实际中采用的启发式基线是该状态的期望奖励,即价值函数:

\[ b(s) = V(s) = \mathbb E[R \mid s] \]

这和优势函数有天然的联系。定义:

  • \(V(s) = \mathbb E[R \mid s]\):从状态 \(s\) 出发的期望奖励。
  • \(Q(s, a) = \mathbb E[R \mid s, a]\):在状态 \(s\) 采取动作 \(a\) 后的期望奖励。注意,由于我们讨论的是 Outcome Reward 且 \(a\) 代表完整回复,\(Q(s, a) = R(s, a)\)
  • 优势函数/Advantage\(A(s, a) = Q(s, a) - V(s)\)

优势函数的直觉是:动作 \(a\) 比从状态 \(s\) 出发的平均表现好多少。当 \(b(s) = V(s)\) 时,带基线的奖励 \(R(s, a) - b(s)\) 恰好就是优势函数 \(A(s, a)\)

综合来看,策略梯度的一般形式是:

\[ \nabla_\theta \mathbb E[R] \approx \mathbb E \left[ \nabla_\theta \log \pi_\theta(a \mid s) \cdot \delta \right] \]

\(\delta\) 是某种优势估计,其具体计算方式正是区分不同算法的关键。

2. Training Walkthrough

这一节通过一个玩具任务——给数字序列排序(输入 [1, 0, 2],期望输出 [0, 1, 2])——将上述理论落地到 GRPO 的具体代码实现中。

回顾上一讲:GRPO 的核心思想是利用语言模型特有的组结构,也就是同一个 Prompt 可以采样 \(G\) 个不同的回答,来计算组内基线,从而摆脱 PPO 中臃肿的 Value Model。完整的 GRPO 训练流程如下:

  1. 生成回复/Generate Responses:对每个 Prompt 采样 \(G\) 个回答;
  2. 计算奖励 \(R\) 和优势估计 \(\delta\)
  3. 计算回复的 log 概率;
  4. 根据 log 概率和 \(\delta\) 计算损失并更新参数。
Code for Model
class Model(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int, prompt_length: int, response_length: int):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        self.encode_weights = nn.Parameter(torch.randn(prompt_length, embedding_dim, embedding_dim)/ math.sqrt(embedding_dim))
        self.decode_weights = nn.Parameter(torch.randn(response_length, embedding_dim, embedding_dim)/ math.sqrt(embedding_dim))

    def forward(self, prompts: torch.Tensor) -> torch.Tensor:
        """
        Args:
            prompts: int[bat, pos]
        Returns:
            logits: float[bat, pos, vocab]
        """            
        embeddings = self.embedding(prompts)  # [bat, pos, dim]

        encoded = einsum(embeddings, self.encode_weights, "bat pos dim1, pos dim1 dim2 -> bat dim2")  # [bat, dim]
        decoded = einsum(encoded, self.decode_weights, "bat dim2, pos dim2 dim1 -> bat pos dim1")  # [bat, pos, dim]
        logits  = einsum(decoded, self.embedding.weight, "bat pos dim, vocab dim -> bat pos vocab")  # [bat, pos, vocab]

        return decoded  # [bat, pos, vocab]

2.1 Sampling Responses & Calculating Probabilities

生成回答的代码如下:

def sample_responses(model: Model, prompts: torch.Tensor, num_samples: int) -> torch.Tensor:
    """
    Args:
        prompts: int[bat, pos]
    Returns:
        responses: int[bat, trail, pos]
    """
    logits = model(prompts)  # [bat, pos, vocab]
    batch_size = logits.shape[0]

    # sample _num_samples_ responses for each prompt
    flattened_logits = rearrange(logits, "bat pos vocab -> (bat pos) vocab")  # [bat*pos, vocab]
    flattened_responses = torch.multinomial(F.softmax(flattened_logits, dim=-1), num_samples, replacement=True)  # [bat*pos, num_samples]
    responses = rearrange(flattened_responses, "(bat pos) num_samples -> bat trail pos", bat=batch_size)  # [bat, num_samples, pos]

    return responses

对于一批提示词 prompts,模型输出一批固定长度的 logits [bat, pos, vocab],其中后面的 [pos, vocab] 部分里面的第 i 行第 j 列代表在第 i 个位置生成的 token 为词汇表里面的第 j 个 token 的概率。然后我们在选择哪个 token 这一个级别进行采样。

def compute_log_probs(prompts: torch.Tensor, responses: torch.Tensor, model: Model) -> torch.Tensor:
    logits = model(prompts) # [batch, pos, vocab]
    log_probs = F.log_softmax(logits, dim=-1) # [batch, pos, vocab]

    num_responses = responses.shape[1]
    log_probs = repeat(log_probs, "batch pos vocab -> batch trial pos vocab", trial=num_responses) # [batch, trial, pos, vocab]

    log_probs = log_probs.gather(
        dim=-1, index=responses.unsqueeze(-1)
    ).squeeze(-1) # [batch, trial, pos]

    return log_probs

首先计算出 logits 的 log 概率(当然是对 vocab 维度做 softmax),然后根据采样得到的 responses 中每个 token 的索引,在 log 概率矩阵中对应位置取出该 token 的 log 概率。

2.2 Reward Shaping

初始的模型基本什么也排不对,这就像 DeepSeek R1-zero 冷启动的时候一样,基本完不成任务。因此必须设计部分得分,否则奖励全是 0,策略梯度完全失效。对于排序任务,我们可以设计下面的奖励规则:

def compute_reward(prompts: torch.Tensor, responses: torch.Tensor, reward_fn: Callable[[list[int], list[int]], float]) -> torch.Tensor:
    """
    Args:
        prompts: int[batch pos]
        responses: int[batch trial pos]
    Returns:
        rewards: float[batch trial]
    """
    batch_size, num_responses, _ = responses.shape
    rewards = torch.empty(batch_size, num_responses, dtype=torch.float32)
    for i in range(batch_size):
        for j in range(num_responses):
            rewards[i, j] = reward_fn(prompts[i, :], responses[i, j, :])
    return rewards

def sort_distance_reward(prompt: list[int], response: list[int]) -> float:
    """
    Return how close response is to ground_truth = sorted(prompt).
    In particular, compute number of positions where the response matches the ground truth.
    """
    assert len(prompt) == len(response)
    ground_truth = sorted(prompt)
    return sum(1 for x, y in zip(response, ground_truth) if x == y)

def sort_inclusion_ordering_reward(prompt: list[int], response: list[int]) -> float:
    """
    Return how close response is to ground_truth = sorted(prompt).
    In particular, compute two components:
     1. Inclusion reward: number of tokens in prompt that show up in response.
     2. Ordering reward: number of adjacent pairs in response that are in sorted order.
    """
    assert len(prompt) == len(response)

    # Give one point for each token in the prompt that shows up in the response
    inclusion_reward = sum(1 for x in prompt if x in response)

    # Give one point for each adjacent pair in response that's sorted
    ordering_reward = sum(1 for x, y in zip(response, response[1:]) if x <= y)

    return inclusion_reward + ordering_reward

sort_distance_reward 是最简单的:逐位置比较回复和 ground truth,匹配一个得一分。sort_inclusion_ordering_reward 则给更多的部分得分:包含正确元素得分 + 相邻升序对得分。后者对初始模型更友好,因为即使排序不完全正确也能拿到分。

但是部分得分是一把双刃剑:上述奖励规则存在漏洞/Loophole:一个什么都不会的模型可能为了拿"升序相邻分",不管三七二十一地输出一堆递增的长串。所以奖励设计需要非常谨慎。

2.3 Computing Loss & KL Penalty

在对同一个 Prompt 采样出 \(G\) 个不同回答并计算出奖励 \(\{r_1, \ldots, r_G\}\) 后,如何计算优势估计 \(\delta\)?代码中 compute_deltas 提供了四种方式:

def compute_deltas(rewards: torch.Tensor, mode: str) -> torch.Tensor:
    """
    Args:
        rewards: float[batch trial]
    Returns:
        deltas: float[batch trial] which are advantage-like quantities for updating
    """
    if mode == "rewards":
        return rewards

    if mode == "centered_rewards":
        # Compute mean over all the responses (trial) for each prompt (batch)
        mean_rewards = rewards.mean(dim=-1, keepdim=True)

        centered_rewards = rewards - mean_rewards
        return centered_rewards

    if mode == "normalized_rewards":
        mean_rewards = rewards.mean(dim=-1, keepdim=True)
        std_rewards = rewards.std(dim=-1, keepdim=True)
        centered_rewards = rewards - mean_rewards
        normalized_rewards = centered_rewards / (std_rewards + 1e-5)
        return normalized_rewards

    if mode == "max_rewards":
        # Zero out any reward that isn't the maximum for each batch
        max_rewards = rewards.max(dim=-1, keepdim=True)[0]
        max_rewards = torch.where(rewards == max_rewards, rewards, torch.zeros_like(rewards))
        return max_rewards

    raise ValueError(f"Unknown mode: {mode}")
  • "rewards"/原始奖励\(\delta_i = r_i\),最粗暴的朴素策略梯度,只有正向加强,没有负向惩罚。
  • "centered_rewards"/中心化奖励\(\delta_i = r_i - \mathrm{mean}(\{r_1, \ldots, r_G\})\)。这就是不用 Value Model 的天然基线。注意 mean 沿 dim=-1(trial 维度)计算,keepdim=True 保持广播。如果所有回答的奖励完全相同,减去均值后所有的 \(\delta\) 都是 0,模型认为谁也不比谁更好,放弃更新,这非常合理。
  • "normalized_rewards"/标准化奖励\(\delta_i = (r_i - \mathrm{mean}) / (\mathrm{std} + \epsilon)\),GRPO 原论文的做法,让奖励尺度变得无关紧要。但正如上一讲提到的,Dr. GRPO 指出除以标准差是非线性操作,会破坏基线的无偏性,当题目太简单或太难时标准差极小,给过易/过难题赋予极高梯度权重,产生病态优化。
  • "max_rewards"/赢家通吃:用 torch.where 把同组中非最高分的奖励强制清零,只保留最优回复的梯度。防止模型沉溺于部分得分而停滞不前。
  • 根据 \(\delta\) 和 log 概率计算损失,compute_loss 提供三种模式:
def compute_loss(log_probs: torch.Tensor, deltas: torch.Tensor, mode: str, old_log_probs: torch.Tensor | None = None) -> torch.Tensor:
    if mode == "naive":
        return -einsum(log_probs, deltas, "batch trial pos, batch trial -> batch trial pos").mean()

    if mode == "unclipped":
        ratios = torch.exp(log_probs - old_log_probs)  # [batch trial]
        return -einsum(ratios, deltas, "batch trial pos, batch trial -> batch trial pos").mean()

    if mode == "clipped":
        epsilon = 0.01
        unclipped_ratios = torch.exp(log_probs - old_log_probs)  # [batch trial]
        unclipped = einsum(unclipped_ratios, deltas, "batch trial pos, batch trial -> batch trial pos")

        clipped_ratios = torch.clamp(unclipped_ratios, min=1 - epsilon, max=1 + epsilon)
        clipped = einsum(clipped_ratios, deltas, "batch trial pos, batch trial -> batch trial pos")
        return -torch.minimum(unclipped, clipped).mean()

    raise ValueError(f"Unknown mode: {mode}")
  • "naive"(朴素模式)\(L = - \mathbb E [ \log \pi_\theta(a \mid s) \cdot \delta ]\)。直接用 \(\delta\) 加权 log 概率,这就是最基本的 REINFORCE。einsum[batch, trial, pos] 的 log 概率与 [batch, trial]\(\delta\) 逐位置相乘(\(\delta\) 广播到每个 pos),然后取均值。
  • "unclipped"(未裁剪模式)\(L = - \mathbb E [ \frac{\pi_\theta}{\pi_{\theta_{\mathrm{old}}}} \cdot \delta ]\)。通过 exp(log_probs - old_log_probs) 计算新旧策略的概率比值 \(r_t(\theta)\),用比值代替 log 概率加权。这允许在旧策略采样的数据上做多步更新(off-policy correction),但不做任何限制。
  • "clipped"(裁剪模式)\(L = - \mathbb E [ \min(r_t \delta, \; \mathrm{clip}(r_t, 1-\epsilon, 1+\epsilon) \delta) ]\)。标准的 PPO/GRPO 裁剪机制。torch.clamp 将比值限制在 \([1-\epsilon, 1+\epsilon]\) 内,取 minimum 确保:当 \(\delta > 0\) 时,比值不能超过 \(1+\epsilon\)(防止过度加强);当 \(\delta < 0\) 时,比值不能低于 \(1-\epsilon\)(防止过度惩罚)。

为了不让 RL 后的模型遗忘基础能力,在损失中加入 KL 散度惩罚 \(\mathrm{KL}(\pi_\theta \| \pi_{\mathrm{ref}})\)。传统的 KL 散度 \(\mathbb E_p[\log(p/q)]\) 在样本估计时方差较大。利用恒等式 \(\mathbb E_p[q(x)/p(x)] = 1\),可以推导出等价且低方差的估计式:

\[ \mathrm{KL}(p \| q) = \mathbb E_p \left[ \frac{q(x)}{p(x)} - \log \frac{q(x)}{p(x)} - 1 \right] \]

这个公式的每一项都是非负的(因为 \(t - \log t - 1 \geq 0\) 对所有 \(t > 0\) 成立),避免了传统估计中可能出现的负值波动,从而降低方差。

def compute_kl_penalty(log_probs: torch.Tensor, ref_log_probs: torch.Tensor) -> torch.Tensor:
    """
    Compute an estimate of KL(model || ref_model), where the models are given by:
        log_probs [batch trial pos vocab]
        ref_log_probs [batch trial pos vocab]
    Use the estimate:
        KL(p || q) = E_p[q/p - log(q/p) - 1]
    """
    return (
        torch.exp(ref_log_probs - log_probs)  - (ref_log_probs - log_probs) - 1
    ).sum(dim=-1).mean()

\(t = q/p = \exp(\log q - \log p)\),则 torch.exp(ref - cur) 就是 \(t\)ref - cur 就是 \(\log t\),最后减 1。.sum(dim=-1) 将各 position 上的 KL 求和得到整个序列的 KL,.mean() 在 batch 和 trial 上取平均。

2.4 Freezing Parameters

GRPO 和 PPO 都需要计算新旧策略的比值 \(\mathrm{ratio} = \pi_\theta(a \mid s) / \pi_{\theta_{\mathrm{old}}}(a \mid s)\) 来限制大幅度更新。在 PyTorch 实现中有一个关键陷阱:如果模型参数还没更新,pp_old 来自同一个模型,那么 ratio = p / p_old 恒等于 1,对参数的梯度为 0,直接破坏训练。

1
2
3
4
5
6
7
# Wrong: ratio == 1 and the gradient is 0
w = torch.tensor(2., requires_grad=True)
p = torch.nn.Sigmoid()(w)
p_old = torch.nn.Sigmoid()(w)
ratio = p / p_old
ratio.backward()
w.grad  # => 0!

正确做法是让 p_old 脱离计算图,当作不可导的常量:

1
2
3
4
5
6
7
8
# Correct:
w = torch.tensor(2., requires_grad=True)
p = torch.nn.Sigmoid()(w)
with torch.no_grad():  # Important: treat p_old as a constant!
    p_old = torch.nn.Sigmoid()(w)
ratio = p / p_old
ratio.backward()
w.grad  # => non-zero!

2.5 Putting It All Together

综合上述各个部分,run_policy_gradient 的每个 epoch 中执行以下流程:

  1. 冻结参考模型:如果启用了 KL 惩罚,每隔 compute_ref_model_period 个 epoch 对当前模型做一次 clone(),作为固定的 \(\pi_{\mathrm{ref}}\)。注意这里参考模型不是一成不变的,而是周期性更新——如果永远冻结在初始模型,随着训练推进策略偏离越来越大,KL 惩罚会主导损失,阻碍学习。
  2. 生成回复并计算奖励和 \(\delta\):用当前策略对每个 Prompt 采样 \(G\) 个回答(generate_responses),计算奖励(compute_reward),然后根据选择的 deltas_mode 计算优势估计(compute_deltas)。这三步在每个 epoch 开始时只做一次。
  3. 计算并且冻结 log 概率:如果使用 KL 惩罚,在参考模型下用 torch.no_grad() 计算 ref_log_probs;如果使用非 naive 的损失模式,同样在当前模型下用 torch.no_grad() 计算 old_log_probs。这两者都必须脱离计算图,作为不可导的常量参与后续计算。
  4. 计算损失并且多步内循环更新:对同一批采样数据做 num_steps_per_epoch 步梯度更新。每步中,在当前模型(可导)下计算 log_probs,结合冻结的 old_log_probsdeltas 计算损失,再加上 KL 惩罚项,反向传播并更新参数。这里的关键设计是:同一批 rollout 数据被复用多步,通过策略比值的裁剪机制保证每步更新不会偏离太远。这比每步都重新采样高效得多,但也正是需要 clipping 的原因。

3. Postscript

以目前的视角(2025 年上半年)看来,强化学习仍然是超越人类水平的关键方法。其关键优势与最核心优势在于,只要我们可以测量某一个度量,我们就可以优化它/If you can measure it, you can optimize it。奖励设计 是 RL 成败的关键,如果可以设计出一个不可被 hack、且能被系统验证的奖励函数,RL 就有能力找到远超人类水平的推理策略。

3.1 Don't Trust the Loss Curve

在监督学习中,训练集上的 Loss 下降代表模型在变好。但在 RL 中,Loss 曲线基本没有意义。原因在于数据集不是静态的:每次拿新参数采样出的回复集合(即数据集本身)都在变化,而 Loss 只是相对于当前模型在当前数据集上的一个指标,没有一个合适的参考点。

在使用 RL 进行训练的时候,只有平均奖励/Mean Reward 稳步上升,才是模型真正在变强的验证。在一个验证集上定期评估奖励和 Loss,观察奖励的趋势才是关键。

从实验结果来看,在排序的玩具任务上:

  • 使用原始奖励:模型很难学会排序,仍然停留在训练集上;
  • 使用中心化奖励:有所帮助。次优回复会得到负向梯度更新,如果所有回复的奖励完全相同则不更新。整体更好,但仍容易陷入局部最优;
  • 使用标准化奖励:在这个任务上差异不大,因为所有回复长度相同,不存在长度偏差问题。

总体而言,即使在玩具任务上,强化学习也不是 trivial 的,很容易卡在次优状态,超参数的调整至关重要。

3.2 RL Infra

虽然上面的玩具代码用矩阵乘法就能跑通,但真实大模型的预训练和 RL 系统的复杂度完全不在一个量级。

预训练相对简单:喂进去 Token,前向传播、反向传播,传一传梯度就结束了。但 RL 需要进行极其繁重的生成/Inference Rollout——自回归解码和巨大的 KV Cache 显存占用。并且你需要在显存中同时维护多个模型副本:

  1. 正在更新的当前策略/Current Policy
  2. 用于计算策略比值的旧策略/Old Policy
  3. 用于计算 KL 惩罚的固定参考模型/Reference Model
  4. 如果使用 PPO,还需要一个同等规模的价值模型/Critic/Value Model

这正是 GRPO 砍掉 Value Model 的实际意义所在——在工业界动辄数百亿参数的模型上,少维护一个同等大小的模型意味着巨大的显存和计算节省。

现实中,为了搞定这些,工业界不得不搭建极其复杂的分布式系统:用一批 GPU 专门做生成推理,这部分依赖 vLLM 框架,再通过网络通信将采样的长轨迹和模型权重来回传递给做梯度的训练节点。上一讲中 Kimi k1.5 的案例展示了这种 RL Infra 的实际架构。

虽然在 2025 年版本的课上对 RL Infra 上不做深入讲解,但是 RL Infra 的设计与优化已经早就成为学术界与工业界的研究热点,下面是一些重要的参考: