DeepSeek级AI?训练自己的推理模型仅需七个步骤

译文 精选
人工智能
DeepSeek的R1模型在不需要人类反馈的情况下就能进行更深思熟虑的推理,已颠覆了大语言模型(LLM)领域。这一重大突破背后的关键是群体相对策略优化(GRPO),这是一种帮助模型自主开发推理能力的强化学习技术。

译者 | 布加迪

审校 | 重楼

谁需要超级计算机?仅用15GB VRAM就可以训练你自己的功能强大的AI推理模型!

DeepSeek的R1模型在不需要人类反馈的情况下就能进行更深思熟虑的推理,已颠覆了大语言模型(LLM)领域。这一重大突破背后的关键是群体相对策略优化(GRPO),这是一种帮助模型自主开发推理能力的强化学习技术。与依赖值函数的近端策略优化(PPO)不同,GRPO在不需要值函数的情况下就可以优化响应,从而提高了效率。

开发更好的推理模型的竞赛正如火如荼地进行。但是对于我们这些GPU资源有限的人来说又该如何是好?

多亏了Unsloth,我们现在在消费级GPU上仅用15GB的VRAM就可以训练15B参数模型。本文将介绍如何通过几个步骤使用GRPO训练自己的推理模型。

GRPO简介

 GRPO帮助AI模型通过比较答案来学习更好地思考。下面是它的工作原理:

  • 模型为一个问题编写多个答案。
  • 每个答案都有一个分数(比如答案正确、清晰、结构合理的相应得分)。
  • 得分求平均值,每个答案都与这个平均值进行比较。
  • 超过平均分的答案会得到奖励。
  • 随着时间的推移,模型逐渐学会生成得分更高的答案。

比如以数学为例:

  • 问:“2+2等于多少?”
  • 模型可能会输出:“2+2=5”(错误)或“2+2=4”(正确)。

GRPO奖励正确的答案,所以模型学会避免错误。这项技术允许模型在不需要大量标记数据集的情况下开发结构化推理。

训练自己的推理模型的逐步指南

本指南介绍了如何使用GRPO训练一个针对推理进行优化的LLM,并将其部署在Hugging Face上。我们将为本文使用meta-llama/meta-Llama-3.1-8B-Instruct以及Unsloth提供的参考笔记本:https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb

第1步:环境设置

使用以下代码安装依赖项:

%%capture
# Install base packages
!pip install unsloth vllm
!pip install --upgrade pillow

# Install specific TRL version for GRPO support
!pip install git+https://github.com/huggingface/trl.git@e95f9fb74a3c3647b86f251b7e230ec51c64b72b

关键组件:

  • unsloth:经过优化的训练框架
  • vllm:高吞吐量推理引擎
  • trl:Transformer强化学习库

第2步:模型初始化

在所有函数之前使用PatchFastRL来修补GRPO及其他强化学习算法。这一步通过将特定的算法改进集成到FastLanguageModel中,确保模型针对强化学习任务进行了优化。然后加载使用以下参数的Llama 3.1 8B Instruct,并运用lora适配。

from unsloth import FastLanguageModel, PatchFastRL, is_bfloat16_supported
import torch

# Enable GRPO patches
PatchFastRL("GRPO", FastLanguageModel)

# Configuration
max_seq_length = 512  # Increase for complex reasoning chains
lora_rank = 32        # Balance between capacity and speed

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

关键参数:

  • load_in_4bit:将内存使用减少4倍(量化)
  • fast_inference:启用vLLM的注意力优化
  • gpu_memory_utilization:控制VRAM分配缓冲区(本例中为60%)
  • r = lora_rank:控制允许多少LoRA适配。我们将其设置为32(秩越大=智能化越高,但速度越慢)

第3步:数据集准备

在这一步,我们准备了一个数据集,在生成答案之前逐步训练我们的模型进行推理。数据集的格式很重要,因为它影响模型如何构建响应的结构。基础笔记本最初使用GSM8K,这个数据集包含需要多步骤推理的85000个小学数学单词问题。然而,我们将使用一个不同的数据集,它提供了跨多个领域的更广泛的推理覆盖,可以在这里找到:https://huggingface.co/datasets/KingNish/reasoning-base-20k。

数据字段:

  • 用户:用户的查询或问题语句。
  • 助理:问题的正确答案。
  • 推理:详细的逐步推理过程解释了如何得出正确的答案。
  • 模板:预运用的RChatML聊天模板。

我们使用结构化的响应模板为数据集格式化,以确保我们的模型学会将推理与最终答案分开。

import re
from datasets import load_dataset, Dataset
from difflib import SequenceMatcher

SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

现在,加载Reasoning Base 20K数据集。

def get_reasoning_questions(split="train") -> Dataset:
    data = load_dataset("KingNish/reasoning-base-20k", split=split)

    data = data.map(lambda x: {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": x["user"]}
        ],
        "reasoning": x["reasoning"],
        "answer": x["assistant"]
    })

    return data

# Load dataset
dataset = get_reasoning_questions()

第4步:奖励函数设计(最重要的步骤)

奖励函数在训练针对推理加以优化的模型中至关重要,因为它们指导模型“良好”表现的含义是什么。正确的奖励设计确保了模型生成逻辑合理、格式良好且高质量的响应。我们的数据集需要一种与GSM8K不同的方法,因为我们的响应包含详细的推理步骤,而不仅仅是数字答案。因此,我们的奖励函数评估以下多个方面:

  • 内容质量→与参考答案在语义上的一致性
  • 结构合规→XML样式的格式化
  • 过程质量→推理步骤的复杂性

在下面的示例代码中,你将发现几个奖励函数,每个函数专注于响应的不同方面。下面更详细地介绍这些函数:

(1)答案相关性奖励

这个函数测量模型的响应在问题提示和参考答案(如果有)两个方面涵盖关键术语有多到位。这确保了模型至少提到或解决问题中的关键主题。

  • 从问题、响应和参考答案中提取关键术语。
  • 如果超过30%的问题术语出现在响应中,则加0.5分。
  • 如果超过30%的参考答案出现在响应中,则加0.5分。
  • 确保模型正确且合乎逻辑地回答问题。
def answer_relevance_reward(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]["content"] for completion in completions]
    questions = [prompt[-1]["content"] for prompt in prompts]

    def check_relevance(response, question, reference):
        score = 0.0
        # Extract key terms from question
        question_terms = set(question.lower().split())
        response_terms = set(response.lower().split())
        reference_terms = set(reference.lower().split())

        # 1) Check if response addresses key terms from question
        if len(question_terms) > 0:
            common_qr = question_terms.intersection(response_terms)
            if len(common_qr) / len(question_terms) > 0.3:
                score += 0.5

        # 2) Check if response uses similar key terms as reference
        if len(reference_terms) > 0:
            common_rr = response_terms.intersection(reference_terms)
            if len(common_rr) / len(reference_terms) > 0.3:
                score += 0.5

        return score

    return [check_relevance(r, q, a) for r, q, a in zip(responses, questions, answer)]

(2)严格格式合规奖励

这个函数确保输出严格遵循所需的XML样式结构,以保持结构化推理的一致输出格式。如果格式正确,则奖励0.5,否则奖励0.0。

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    pattern = r"^\n.*?\n\n\n.*?\n\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

(3)软格式合规奖励

这种更灵活的奖励函数允许较小的偏差,但仍然需要适当的XML样式格式。如果匹配,奖励0.5分,否则奖励0.0分。如果严格格式过于僵硬,并且可能惩罚不影响可用性的小差异,这可能会有所帮助。

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    pattern = r".*?\s*.*?"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.search(pattern, r, re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

(4)XML标记计数奖励(启发式示例)

这个函数通过计算所需的标记来评估响应遵守预期XML结构的程度。如果出现额外的内容,它会惩罚,并提供部分给分,而不是二元奖励。

def count_xml(text) -> float:
    count = 0.0
    if text.count("\n") == 1:
        count += 0.125
    if text.count("\n\n") == 1:
        count += 0.125
    if text.count("\n\n") == 1:
        count += 0.125
        count -= len(text.split("\n\n")[-1]) * 0.001
    if text.count("\n") == 1:
        count += 0.125
        count -= (len(text.split("\n")[-1]) - 1) * 0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

实际上,你常常希望将一些或所有这些不同的信号结合起来计算最终的奖励分数。最初的笔记本使用int和正确性奖励函数,因为数据集包含单个数字答案。然而,鉴于我们的一般推理模型,更广泛的评估方法必不可少。因此,我们使用了以下奖励函数:

reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        answer_relevance_reward
    ]

第5步:GRPO训练配置与执行

现在,设置GRPO训练器和所有配置。我将max_steps从250减少到150以节省时间,并将num_generations从6减少到4以节省内存。然而,Unsloth建议至少跑300个步骤才能观察到明显的改善。所有其他配置保持不变,如下所示:

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 4, # Decrease if out of memory
    max_prompt_length = 256,
    max_completion_length = 200,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 150,
    save_steps = 150,
    max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)

现在,不妨初始化并运行GRPO训练器:

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = reward_funcs,
    args = training_args,
    train_dataset = dataset,
)
trainer.train()

训练日志让我们得以深入了解奖励趋势、损失值和响应质量改进。最初,奖励会因随机探索而波动,但随着时间的推移会逐渐改善。我在Colab T4 GPU上运行这个笔记本大约花了2小时7分钟,150个步骤后的最终训练损失为0.0003475。

第6步:模型评估

鉴于我们已经训练了模型,不妨比较基准LLaMA 3.1 8B Instruct与使用GRPO训练的模型各自的性能。

GRPO训练前

text = tokenizer.apply_chat_template([
    {"role" : "user", "content" : "How many r's are in strawberry?"},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text

output

输出:

There are 2 'r's in the word "strawberry".

基准模型错误地识别了“strawberry”中“r”的数量,暴露了事实推理方面的不足。

GRPO训练后

现在我们加载LoRA并进行测试:

model.save_lora("grpo_saved_lora")


text = tokenizer.apply_chat_template([
    {"role" : "system", "content" : SYSTEM_PROMPT},
    {"role" : "user", "content" : "How many r's are in strawberry?"},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    text,
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text

output

输出:

<reasoning>
To determine the number of 'r's in the word "strawberry," we need to spell it out and count the occurrences of 'r'. The word "strawberry" is spelled as S-T-R-A-W-B-E-R-R-Y. The letter 'r' appears in the 3rd, 8th, and 9th positions. 
</reasoning>
<answer> 
There are 3 'r's in the word "strawberry." 
</answer>

GRPO训练后,模型的准确率和推理能力有所提高,但仍然不够完美。由于它在T4 GPU上训练仅用了2小时,因此延长序列长度和训练时间将进一步提升其表现。

第7步:部署和扩展

一旦对模型进行了微调和评估,下一步就是将其部署到实际使用场景中,确保它可以有效地扩展。部署需要将模型转换成经过优化的格式,将其整合到推理服务器中,并通过API或应用程序让其可以访问。为了确保有效的推理,我们保存了训练好的LoRA适配器,并将它们推送到Hugging Face Hub以便访问。这允许其他人加载经过微调的模型,不需要大量的计算资源。

# Just LoRA adapters
if True: model.save_pretrained_merged("model", tokenizer, save_method = "lora",)
if True: model.push_to_hub_merged("kanwal-mehreen18/Llama3.1-8B-GRPO", tokenizer, save_method = "lora", token = "YOUR_HF_KEY")

将lora模型保存到https://huggingface.co/kanwal-mehreen18/Llama3.1-8B-GRPO。

Unsloth给出的最佳实践

  • 使用参数数量>1.5B的模型进行可靠推理。
  • 接受进行12小时的训练以处理复杂任务。
  • 结合多种奖励信号(3-5种函数最好)。

原文标题:DeepSeek-Level AI? Train Your Own Reasoning Model in Just 7 Easy Steps!,作者:Kanwal Mehreen

责任编辑:姜华 来源: 51CTO内容精选
相关推荐

2024-05-07 08:00:00

自然语言处理机器学习

2025-03-06 09:55:49

2022-08-02 20:22:01

SaaS安全网络攻击

2014-03-12 15:23:20

2010-04-09 09:55:43

Oracle sqlp

2025-01-21 11:53:53

2023-12-21 18:01:58

Docker容器部署

2025-03-06 10:14:39

2025-02-24 08:40:00

开源模型训练

2015-12-23 09:48:32

2023-03-06 08:48:52

2023-06-01 13:09:09

智能建筑数字孪生

2022-02-15 11:03:40

SD-WAN软件定义WAN

2023-07-10 13:28:43

智能建筑工具

2023-04-25 12:45:09

2024-02-19 00:21:45

开源图片

2025-02-08 14:03:25

2025-03-05 00:22:00

2022-05-30 15:44:33

模型训练GAN

2023-11-01 18:01:02

改进WakaTime编程
点赞
收藏

51CTO技术栈公众号