from typing import Optional, Dict
import re, logging, os, sys, torch, math
import transformers
from transformers import (
AutoModelForCausalLM,
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
import datasets
from datasets import load_dataset
from trl import ModelConfig, ScriptArguments, GRPOConfig, GRPOTrainer, get_peft_config
from dataclasses import dataclass, field
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
logger = logging.getLogger(__name__)
def verify_answer(contents, solution):
rewards = []
for content, sol in zip(contents, solution):
gold_parsed = parse(
sol,
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
print('-'*100)
print(f'\ncontent:{content}\nsol:{sol}')
if len(gold_parsed) != 0:
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed="all",
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
# Reward 1 if the content is the same as the ground truth, 0 otherwise
reward = float(verify(answer_parsed, gold_parsed))
print('-'*100)
print(f'\nanswer_parsed:{answer_parsed}\ngold_parsed:{gold_parsed}\nreward:{reward}')
else:
# If the gold solution cannot be parsed, reward 1.0 so this example is effectively skipped
reward = 1.0
print(f'Failed to parse gold solution: {sol}')
rewards.append(reward)
return rewards
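# Usage sketch (illustrative only; the strings below are made up, not from the dataset):
# verify_answer takes raw completion strings plus LaTeX gold solutions and returns one
# 0/1 reward per pair, so a call could look like the helper below (never invoked in training).
def _demo_verify_answer() -> list:
    return verify_answer(
        [r"<think>2 + 2 = 4</think><answer>\boxed{4}</answer>"],  # model completion
        [r"$\boxed{4}$"],                                          # ground-truth solution
    )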
def accuracy_reward(completions, solution, **kwargs):
"""Reward function that checks if the completion is the same as the ground truth."""
contents = [completion[0]["content"] for completion in completions]
rewards = verify_answer(contents, solution)
print(f'\naccuracy rewards:{rewards}')
return rewards
def format_reward(completions, **kwargs):
"""Reward function that checks if the completion has a specific format."""
pattern = r"^<think>.*?</think><answer>.*?</answer>$"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, content, re.DOTALL) for content in completion_contents]  # re.DOTALL so the pattern spans multi-line <think> content
rewards = [1.0 if match else 0.0 for match in matches]
print('-'*100)
print('\nformat rewards:', rewards)
return rewards
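# Example (illustrative): "<think>add the numbers</think><answer>4</answer>" matches the
# pattern above and earns 1.0, while a completion missing the <think>...</think> block,
# or with text before the tags, earns 0.0.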
def reasoning_steps_reward(completions, **kwargs):
"""Reward function that checks for clear step-by-step reasoning.
Regex pattern:
Step \d+: - matches "Step 1:", "Step 2:", etc.
^\d+\. - matches numbered lists like "1.", "2.", etc. at start of line
\n- - matches bullet points with hyphens
\n\* - matches bullet points with asterisks
First,|Second,|Next,|Finally, - matches transition words
"""
pattern = r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [len(re.findall(pattern, content, re.MULTILINE)) for content in completion_contents]  # re.MULTILINE so ^\d+\. matches at line starts, as documented above
# Magic number 3 to encourage 3 steps and more, otherwise partial reward
return [min(1.0, count / 3) for count in matches]
def len_reward(completions: list[Dict[str, str]], solution: list[str], **kwargs) -> list[float]:
"""Compute length-based rewards to discourage overthinking and promote token efficiency.
Taken from the Kimi 1.5 tech report: https://arxiv.org/abs/2501.12599
Args:
completions: List of model completions
solution: List of ground truth solutions
Returns:
List of rewards where:
- For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)
- For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))
"""
contents = [completion[0]["content"] for completion in completions]
# First check correctness of answers
correctness = verify_answer(contents, solution)
# Calculate lengths
lengths = [len(content) for content in contents]
min_len = min(lengths)
max_len = max(lengths)
# If all responses have the same length, return zero rewards
if max_len == min_len:
return [0.0] * len(completions)
rewards = []
for length, is_correct in zip(lengths, correctness):
lambda_val = 0.5 - (length - min_len) / (max_len - min_len)
reward = lambda_val if is_correct > 0.0 else min(0, lambda_val)
rewards.append(float(reward))
return rewards
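# Worked example (illustrative, made-up numbers): with two completions of lengths 100 and 300,
# lambda = 0.5 - (len - min_len) / (max_len - min_len), i.e. 0.5 for the short completion and
# -0.5 for the long one; an incorrect long completion keeps min(0, -0.5) = -0.5, while an
# incorrect short completion is clipped from 0.5 down to 0.0.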
def get_cosine_scaled_reward(
min_value_wrong: float = -1.0,
max_value_wrong: float = -0.5,
min_value_correct: float = 0.5,
max_value_correct: float = 1.0,
max_len: int = 1000,
):
def cosine_scaled_reward(completions, solution, **kwargs):
"""Reward function that scales based on completion length using a cosine schedule.
Shorter correct solutions are rewarded more than longer ones.
Longer incorrect solutions are penalized less than shorter ones.
Args:
completions: List of model completions
solution: List of ground truth solutions
This function is parameterized by the following arguments:
min_value_wrong: Minimum reward for wrong answers
max_value_wrong: Maximum reward for wrong answers
min_value_correct: Minimum reward for correct answers
max_value_correct: Maximum reward for correct answers
max_len: Maximum length for scaling
"""
contents = [completion[0]["content"] for completion in completions]
rewards = []
correctness = verify_answer(contents, solution)
lengths = [len(content) for content in contents]
for gen_len, is_correct in zip(lengths, correctness):
# Apply cosine scaling based on length
progress = gen_len / max_len
cosine = math.cos(progress * math.pi)
if is_correct > 0:
min_value = min_value_correct
max_value = max_value_correct
else:
# Swap min/max for incorrect answers
min_value = max_value_wrong
max_value = min_value_wrong
reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)
rewards.append(float(reward))
return rewards
return cosine_scaled_reward
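# Worked example (illustrative, using the default arguments above): at gen_len = 0 the cosine
# term is 1, so a correct answer gets max_value_correct = 1.0 and a wrong one min_value_wrong = -1.0;
# at gen_len = max_len the cosine term is -1, so a correct answer gets min_value_correct = 0.5 and
# a wrong one max_value_wrong = -0.5. Short correct answers are rewarded most, short wrong answers
# penalized most.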
def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
"""
Computes N-gram repetition penalty as described in Appendix C.2 of https://arxiv.org/abs/2502.03373.
Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py
Args:
ngram_size: size of the n-grams
max_penalty: Maximum (negative) penalty for wrong answers
"""
if max_penalty > 0:
raise ValueError(f"max_penalty {max_penalty} should not be positive")
def zipngram(text: str, ngram_size: int):
words = text.lower().split()
return zip(*[words[i:] for i in range(ngram_size)])
def repetition_penalty_reward(completions, **kwargs) -> list[float]:
"""
Reward function that penalizes repetitions.
ref implementation: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py
Args:
completions: List of model completions
"""
contents = [completion[0]["content"] for completion in completions]
rewards = []
for completion in contents:
if completion == "":
rewards.append(0.0)
continue
if len(completion.split()) < ngram_size:
rewards.append(0.0)
continue
ngrams = set()
total = 0
for ng in zipngram(completion, ngram_size):
ngrams.add(ng)
total += 1
scaling = 1 - len(ngrams) / total
reward = scaling * max_penalty
rewards.append(reward)
return rewards
return repetition_penalty_reward
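# Worked example (illustrative, made-up text): for "the cat the cat" with ngram_size=2 the
# bigrams are (the, cat), (cat, the), (the, cat), i.e. 2 unique out of 3 total, so
# scaling = 1 - 2/3 = 1/3 and with max_penalty = -1.0 the reward is about -0.33;
# fully non-repetitive text scores 0.0. The helper below is a sketch and is never called.
def _demo_repetition_penalty() -> list:
    reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
    return reward_fn([[{"content": "the cat the cat"}]])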
SYSTEM_PROMPT = (
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
"<think> reasoning process here </think><answer> answer here </answer>"
)
@dataclass
class R1GRPOScriptArguments(ScriptArguments):
reward_funcs: list[str] = field(
default_factory = lambda: ["accuracy", "format"],
metadata = {
"help": f"List of reward functions. Available options: 'accuracy', 'format', 'reasoning_steps', 'len', 'get_cosine_scaled', 'get_repetition_penalty'"
},
)
cosine_min_value_wrong: float = field(
default=0.0,
metadata={"help": "Minimum reward for wrong answers"},
)
cosine_max_value_wrong: float = field(
default=-0.5,
metadata={"help": "Maximum reward for wrong answers"},
)
cosine_min_value_correct: float = field(
default=0.5,
metadata={"help": "Minimum reward for correct answers"},
)
cosine_max_value_correct: float = field(
default=1.0,
metadata={"help": "Maximum reward for correct answers"},
)
cosine_max_len: int = field(
default=1000,
metadata={"help": "Maximum length for scaling"},
)
repetition_n_grams: int = field(
default=3,
metadata={"help": "Number of n-grams for repetition penalty reward"},
)
repetition_max_penalty: float = field(
default=-1.0,
metadata={"help": "Maximum (negative) penalty for for repetition penalty reward"},
)
@dataclass
class R1GRPOConfig(GRPOConfig):
"""
Arguments for callbacks, benchmarks, etc.
"""
benchmarks: list[str] = field(
default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
)
callbacks: list[str] = field(
default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
)
system_prompt: Optional[str] = field(
default=None, metadata={"help": "The optional system prompt to use for benchmarking."}
)
def main(script_args, training_args, model_args):
# Set seed for reproducibility
set_seed(training_args.seed)
###############
# Setup logging
###############
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Log on each process a small summary
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Model parameters {model_args}")
logger.info(f"Script parameters {script_args}")
logger.info(f"Data parameters {training_args}")
# Check for last checkpoint
last_checkpoint = None
if os.path.isdir(training_args.output_dir):
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
# Load the dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
# Get reward functions
REWARD_FUNCS_REGISTRY = {
"accuracy": accuracy_reward,
"format": format_reward,
"reasoning_steps": reasoning_steps_reward,
"cosine": get_cosine_scaled_reward(
min_value_wrong=script_args.cosine_min_value_wrong,
max_value_wrong=script_args.cosine_max_value_wrong,
min_value_correct=script_args.cosine_min_value_correct,
max_value_correct=script_args.cosine_max_value_correct,
max_len=script_args.cosine_max_len,
),
"repetition_penalty": get_repetition_penalty_reward(
ngram_size=script_args.repetition_n_grams,
max_penalty=script_args.repetition_max_penalty,
),
"length": len_reward,
}
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
# Format into conversation
def make_conversation(example):
return {
"prompt": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": example["problem"]},
],
}
dataset = dataset.map(make_conversation)
for split in dataset:
if"messages"in dataset[split].column_names:
dataset[split] = dataset[split].remove_columns("messages")
logger.info("*** Initializing model kwargs ***")
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
training_args.gradient_checkpointing = True
model_kwargs = dict(
revision = model_args.model_revision,
trust_remote_code = model_args.trust_remote_code,
attn_implementation = model_args.attn_implementation,
torch_dtype = torch_dtype,
use_cache = False if training_args.gradient_checkpointing else True,
)
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path,
load_in_4bit=False, **model_kwargs)
print(model_args.model_name_or_path)
#############################
# Initialize the R1GRPO trainer
#############################
trainer = GRPOTrainer(
model = model,
reward_funcs = reward_funcs,
args = training_args,
train_dataset = dataset[script_args.dataset_train_split],
eval_dataset = dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
peft_config = get_peft_config(model_args),
)
###############
# Training loop
###############
logger.info("*** Train ***")
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
##################################
# Save model and create model card
##################################
logger.info("*** Save model ***")
trainer.save_model(training_args.output_dir)
logger.info(f"Model saved to {training_args.output_dir}")
# Save everything else on main process
kwargs = {
"dataset_name": script_args.dataset_name,
"tags": ["GRPOTrainer-R1"],
}
if trainer.accelerator.is_main_process:
trainer.create_model_card(**kwargs)
# Restore k,v cache for fast inference
trainer.model.config.use_cache = True
trainer.model.config.save_pretrained(training_args.output_dir)
script_config = {
"dataset_name": "AI-MO/NuminaMath-TIR",
"dataset_config": "default",
"reward_funcs": [
"accuracy",
"format",
"reasoning_steps",
]
}
training_config = {
"output_dir": "output/GRPO-R1-1.5B", # 模型输出目录
"overwrite_output_dir": True, # 是否覆盖输出目录
"do_eval": True, # 是否进行评估
"eval_strategy": "steps", # 评估策略,按步数进行评估
"eval_steps": 100, # 每100步进行一次评估
"per_device_train_batch_size": 4, # 每个设备上的训练批次大小
"per_device_eval_batch_size": 4, # 每个设备上的评估批次大小
"gradient_accumulation_steps": 8, # 梯度累积步数
"learning_rate": 1.0e-06, # 学习率
"num_train_epochs": 1.0, # 训练的总轮数
"max_steps": -1, # 最大训练步数,-1表示不限制
"lr_scheduler_type": "cosine", # 学习率调度器类型,使用余弦退火
"warmup_ratio": 0.1, # 预热比例
"log_level": "info", # 日志记录级别
"logging_strategy": "steps", # 日志记录策略,按步数记录
"logging_steps": 100, # 每100步记录一次日志
"save_strategy": "no", # 保存策略,不保存
"seed": 42, # 随机种子
"bf16": True, # 是否使用bfloat16精度
"gradient_checkpointing": True, # 是否使用梯度检查点
"gradient_checkpointing_kwargs": {
"use_reentrant": False# 梯度检查点的额外参数,是否使用reentrant模式
},
"max_prompt_length": 128, # 最大提示长度
"num_generations": 4, # 生成的数量
"max_completion_length": 256, # 最大完成长度
"use_vllm": True, # 是否使用vLLM
"vllm_device": "auto", # vLLM设备,自动选择
"vllm_gpu_memory_utilization": 0.8, # vLLM GPU内存利用率
"resume_from_checkpoint": "output/GRPO-R1-1.5B", # 恢复检查点,如果没有latest文件,需要添加latest文件类似`global_step9055`
}
model_config = {
"model_name_or_path": "Qwen/Qwen2.5-1.5B-Instruct",
"model_revision": "main",
"torch_dtype": "bfloat16",
"attn_implementation": "flash_attention_2",
}
if __name__ == "__main__":
script_args = R1GRPOScriptArguments(**script_config)
training_args = R1GRPOConfig(**training_config)
model_args = ModelConfig(**model_config)
main(script_args, training_args, model_args)