机器学习 | 从0开发大模型之DeepSeek的GRPO

人工智能
DeepSeek-R1的发布为国产大模型争光了(太强了),不过 GRPO 算法源自 DeepSeekMath 7B 模型,该模型在 MATH 基准测试中取得了优异成绩。GRPO 是一种在线学习算法,核心思想是通过组内相对奖励来估计基线,从而避免使用额外的价值函数模型。

DeepSeek-R1的发布为国产大模型争光了(太强了),不过 GRPO 算法源自 DeepSeekMath 7B 模型,该模型在 MATH 基准测试中取得了优异成绩,论文发表于2024年2月份:https://huggingface.co/papers/2402.03300,以下是该论文的摘要原文:

Mathematical reasoning poses a significant challenge for language models due to its complex and structured nature. In this paper, we introduce DeepSeekMath 7B, which continues pre-training DeepSeek-Coder-Base-v1.5 7B with 120B math-related tokens sourced from Common Crawl, together with natural language and code data. DeepSeekMath 7B has achieved an impressive score of 51.7% on the competition-level MATH benchmark without relying on external toolkits and voting techniques, approaching the performance level of Gemini-Ultra and GPT-4. Self-consistency over 64 samples from DeepSeekMath 7B achieves 60.9% on MATH. The mathematical reasoning capability of DeepSeekMath is attributed to two key factors: First, we harness the significant potential of publicly available web data through a meticulously engineered data selection pipeline. Second, we introduce Group Relative Policy Optimization (GRPO), a variant of Proximal Policy Optimization (PPO), that enhances mathematical reasoning abilities while concurrently optimizing the memory usage of PPO.
  • 1.

翻译如下:

数学推理对语言模型构成了重大挑战,因为其复杂且结构化的特性。在本文中,我们介绍了DeepSeekMath 7B,它在DeepSeek-Coder-Base-v1.5 7B的基础上进行了继续预训练,使用了来自Common Crawl的120B与数学相关的标记,以及自然语言和代码数据。DeepSeekMath 7B在竞争级MATH基准测试中取得了51.7%的优异成绩,且未依赖外部工具包和投票技术,接近Gemini-Ultra和GPT-4的性能水平。DeepSeekMath 7B在64个样本上的自一致性达到了60.9%的MATH成绩。DeepSeekMath的数学推理能力归因于两个关键因素:首先,我们通过精心设计的数据选择流程,充分利用了公开可用的网络数据的巨大潜力。其次,我们引入了群体相对策略优化(GRPO),这是一种近端策略优化(PPO)的变体,旨在增强数学推理能力,同时优化PPO的内存使用。
  • 1.

图片

对比数据

1、什么是GRPO

GRPO 是一种在线学习算法,核心思想是通过组内相对奖励来估计基线,从而避免使用额外的价值函数模型。通过在训练期间使用受训模型自身生成的数据来迭代改进,GRPO 旨在最大化生成补全的优势,同时确保模型保持接近参考策略,下图是论文中的算法流程图:

图片

GRPO

GRPO 是 PPO (Proximal Policy Optimization,近端策略优化,是一种强化学习算法,由OpenAI于2017年提出,旨在解决策略梯度方法中的训练不稳定问题) 的变体,主要区别是:

  • GRPO 省略 value function model
  • GRPO 奖励计算,改成了一个 q 生成多个 r,然后 reward 打分

GRPO算法流程:

  • 采样一组输出并计算每个输出的奖励
  • 对组内奖励进行归一化处理
  • 使用归一化后的奖励计算优势函数
  • 通过最大化目标函数更新策略模型
  • 迭代训练,逐步优化策略模型

图片

论文中的伪代码

2、奖励设计

huggingface 库提供 GRPOTrainer 可以直接使用 GRPO 训练,参数包括定义奖励模型和函数。

2.1 奖励模型

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-3B-Instruct",
    reward_funcs="weqweasdas/RM-Gemma-2B",
    args=training_args,
    train_dataset=dataset,
    peft_cnotallow=LoraConfig(task_type="CAUSAL_LM"),
)
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.

这里的 reward_funcs 参数可以传入奖励模型。

2.2 奖励函数

GRPOTrainer 允许用户自定义奖励函数,通过定义一个返回浮点数列表的函数来实现。

  • 奖励函数的输入:completions(生成的补全)和 prompts(提示)
  • 奖励函数的输出:返回一个浮点数列表,每个浮点数代表对应于单个补全的奖励

(1)较长补全奖励函数

def completion_reward(completions, **kwargs):
    '''奖励函数,对较长的补全给予更高的分数'''
    return [float(len(completion))/100 for completion in completions]

prompts = ["The sky is", "The sun is"]
completions = [" blue.", " in the sky."]
print("completion_reward: ", completion_reward(prompts=prompts, completinotallow=completions))
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.

(2)格式正确奖励函数

def format_reward(completions, **kwargs):
    '''格式奖励'''
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, response) for response in responses]
    return [0.5if match else0.0for match in matches]

prompts = [
    [{"role": "assistant", "content": "What is the result of (1 + 2) * 4?"}],
    [{"role": "assistant", "content": "What is the result of (3 + 1) * 2?"}],
]
completions = [
    [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
    [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
]
print("format_reward: ", format_reward(prompts=prompts, completinotallow=completions))
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.

根据以上的奖励样例,可以设计对于不同数据集的奖励函数,如:

  • 判断内容中是否包含数字
  • 判断内容回答是否参考网页的知识库内容
  • ...

然后将这些函数传入 GRPOTrainer 即可,代码如下:

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        ...
        format_reward,
        completion_reward,
    ],
    args=training_args,
    train_dataset=data,
    ...
)
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.

3、使用 GRPO 训练模型

github上已经有很多复刻 DeepSeek-R1-Zero 的方案,有兴趣可以看一下这几个开源项目(成本基本都控制在500以内):

3.1 训练代码

这里为了演示如何使用 GRPO 训练模型,本文也给出了完整的训练代码,其中流程如下:

图片

  • 使用 Qwen/Qwen2.5-3B-Instruct 作为基础模型
  • 使用 swulling/gsm8k_chinese 作为训练数据集
import re
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

SYSTEM_PROMPT = """
按照如下格式生成:
<think>
...
</think>
<answer>
...
</answer>
"""

def process_data(data):
    return data.map(
        lambda x: {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": x["question_zh-cn"]},
            ],
            "answer": x["answer_only"],
        }
    )

def extract_answer(text):
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def correctness_reward(completions, answer, **kwargs):
    responses = [completion[0]["content"] for completion in completions]
    extracted_responses = [extract_answer(r) for r in responses]
    return [2.0if response == str(ans) else0.0for response, ans in zip(extracted_responses, answer)]

def completion_reward(completions, **kwargs):
    '''奖励函数,对较长的补全给予更高的分数'''
    return [float(len(completion)) / 100for completion in completions]

prompts = ["The sky is", "The sun is"]
completions = [" blue.", " in the sky."]
print("completion_reward: ", completion_reward(prompts=prompts, completinotallow=completions))

def digit_reward(completions, **kwargs):
    responses = [completion[0]["content"] for completion in completions]
    extracted_responses = [extract_answer(r) for r in responses]
    return [0.5if response.isdigit() else0.0for response in extracted_responses]

def format_reward(completions, **kwargs):
    '''格式奖励'''
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, response) for response in responses]
    return [0.5if match else0.0for match in matches]

prompts = [
    [{"role": "assistant", "content": "What is the result of (1 + 2) * 4?"}],
    [{"role": "assistant", "content": "What is the result of (3 + 1) * 2?"}],
]
completions = [
    [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
    [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
]
print("format_reward: ", format_reward(prompts=prompts, completinotallow=completions))

def mark_reward(completions, **kwargs):
    '''标记奖励(改善格式奖励稀疏问题)'''
    def mark_num(text):
        reward = 0
        if text.count("<think>\n") == 1:
            reward += 0.125

        if text.count("</think>\n") == 1:
            reward += 0.125

        if text.count("<answer>\n") == 1:
            reward += 0.125

        if text.count("</answer>\n") == 1:
            reward += 0.125 * 2

        return reward

    responses = [completion[0]["content"] for completion in completions]
    return [mark_num(response) for response in responses]

if __name__ == "__main__":
    model_name = "Qwen/Qwen2.5-3B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir="./model")
    model.cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    ds = load_dataset("swulling/gsm8k_chinese", cache_dir="./dataset")
    data = process_data(ds["train"])
    output_dir = "output"
    training_args = GRPOConfig(
        output_dir=output_dir,
        learning_rate=5e-6,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        logging_steps=1,
        bf16=True,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        num_generatinotallow=16,
        max_prompt_length=256,
        max_completion_length=200,
        num_train_epochs=1,
        save_steps=100,
        max_grad_norm=0.1,
        log_on_each_node=False,
        use_vllm=False,
        report_to="tensorboard",
    )

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            mark_reward,
            format_reward,
            digit_reward,
            completion_reward,
            correctness_reward,
        ],
        args=training_args,
        train_dataset=data,
        peft_cnotallow=LoraConfig(task_type="CAUSAL_LM"),
    )
    trainer.train()
    trainer.save_model(output_dir)
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.
  • 32.
  • 33.
  • 34.
  • 35.
  • 36.
  • 37.
  • 38.
  • 39.
  • 40.
  • 41.
  • 42.
  • 43.
  • 44.
  • 45.
  • 46.
  • 47.
  • 48.
  • 49.
  • 50.
  • 51.
  • 52.
  • 53.
  • 54.
  • 55.
  • 56.
  • 57.
  • 58.
  • 59.
  • 60.
  • 61.
  • 62.
  • 63.
  • 64.
  • 65.
  • 66.
  • 67.
  • 68.
  • 69.
  • 70.
  • 71.
  • 72.
  • 73.
  • 74.
  • 75.
  • 76.
  • 77.
  • 78.
  • 79.
  • 80.
  • 81.
  • 82.
  • 83.
  • 84.
  • 85.
  • 86.
  • 87.
  • 88.
  • 89.
  • 90.
  • 91.
  • 92.
  • 93.
  • 94.
  • 95.
  • 96.
  • 97.
  • 98.
  • 99.
  • 100.
  • 101.
  • 102.
  • 103.
  • 104.
  • 105.
  • 106.
  • 107.
  • 108.
  • 109.
  • 110.
  • 111.
  • 112.
  • 113.
  • 114.
  • 115.
  • 116.
  • 117.
  • 118.
  • 119.
  • 120.
  • 121.
  • 122.
  • 123.
  • 124.
  • 125.
  • 126.
  • 127.
  • 128.
  • 129.
  • 130.
  • 131.
  • 132.
  • 133.
  • 134.
  • 135.

3.2 输出结果

(1)Qwen/Qwen2.5-3B-Instruct 输出:

用户: 
使用给定的数字 [44, 19, 35],创建一个等于 98 的方程。你可以使用基本算术运算(+、-、*、/)一次或多次,但每个数字只能使用一次。在 <think> </think> 标签中展示你的思考过程,并在 <answer> </answer> 标签中返回最终方程,例如 <answer> (1 + 2) / 3 </answer>。在 <think> 标签中逐步思考。

机器人:  
<span class="think">首先我们需要找到如何将给定的数字组合起来得到 98。我们可以尝试不同的组合,比如加法和减法。注意到 98 可以分解为 90 和 8 的和,因此我们可以考虑用 44 和 35 来构造这两个数。</span>
<span class="think">观察到 44 和 35 的差是 9,如果我们将 35 减去这个差,就可以得到 8。即 35 - (44 - 35) = 8。</span>
<span class="think">现在我们有了 8 和 90,可以将它们相加得到 98。所以最终的方程是 35 + (44 - 35)。</span>
<answer> 35 + (44 - 35) </answer>
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.

(2)训练的模型输出:

用户: 
使用给定的数字 [44, 19, 35],创建一个等于 98 的方程。你可以使用基本算术运算(+、-、*、/)一次或多次,但每个数字只能使用一次。在 <think> </think> 标签中展示你的思考过程,并在 <answer> </answer> 标签中返回最终方程,例如 <answer> (1 + 2) / 3 </answer>。在 <think> 标签中逐步思考。

机器人:  
<think>
首先,我需要使用数字44、19和35,通过基本算术运算(+、-、*、/)来构造一个等于98的方程,且每个数字只能使用一次。

我先考虑将较大的数字进行加法或乘法运算,以达到接近98的目标。例如,44加19等于63,再加上35可以得到98。具体计算如下:

44 + 19 = 63  
63 + 35 = 98  

因此,最终的方程可以表示为44 + 19 + 35 = 98。
</think>

<answer> 44 + 19 + 35 = 98 </answer>
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.

参考

(1)https://github.com/Jiayi-Pan/TinyZero

(2)https://github.com/huggingface/open-r1

(3)DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning(https://arxiv.org/pdf/2501.12948)

(4)https://zhuanlan.zhihu.com/p/20021693569

(5)https://zhuanlan.zhihu.com/p/19949917958

(6)https://blog.csdn.net/qq_38961840/article/details/145387854

责任编辑:庞桂玉 来源: 周末程序猿
相关推荐

2025-04-03 15:46:53

2024-11-04 00:24:56

2024-12-09 00:00:10

2024-11-26 09:33:44

2024-12-26 00:46:25

机器学习LoRA训练

2025-01-10 08:38:10

2025-02-17 10:40:20

2025-02-20 09:27:46

2025-03-11 13:07:58

2025-04-07 02:25:00

DeepSeek模型训练GRPO

2025-02-17 08:00:00

DeepSeek模型AI

2025-03-11 01:00:00

GRPO算法模型

2025-02-13 11:00:30

2017-08-25 14:05:01

机器学习算法模型

2020-10-13 07:00:00

机器学习人工智能

2022-09-06 08:00:00

机器学习金融数据科学

2025-02-20 17:19:08

2022-05-18 16:24:36

PythonPyCaret机器学习

2025-03-06 07:28:31

DeepSeek大模型人工智能

2025-02-10 09:42:14

点赞
收藏

51CTO技术栈公众号