1、为什么需要DPO
Rafailov等人在2023年发表了一篇论文《Direct Preference Optimization: Your Language Model is Secretly a Reward Model》,该论文提出了一种新的训练方法,称为直接偏好优化(DPO),该论文介绍:
由此可见,DPO 主要解决RLHF不稳定的问题,直接使用人类偏好数据训练模型。
2、DPO的训练原理
DPO 的训练原理如下图所示(出自原论文):
DPO
主要包括两个步骤:
- 数据收集:收集一个偏好数据集,其中包含给定提示的生成结果的正负选择对;
- 优化:直接最大化 DPO 损失的对数似然函数,该损失函数是偏好数据集上的交叉熵损失和模型生成结果的对数似然性之间的加权平均值;
具体公式推导可以参考这篇博客:https://www.cnblogs.com/lemonzhang/p/17910358.html。
3、DPO的代码实现
3.1 收集数据
DPO 训练器对数据集的格式有具体的要求,包括三个部分:
- 提示(prompt):提示的格式为:prompt: 文本;
- 选中(chosen):选中文本的格式为:chosen: 文本;
- 拒绝(rejected):拒绝选中文本的格式为:rejected: 文本;
- 示例:
DPO的数据可以搜索huggingface的DPO数据集,地址为:https://huggingface.co/datasets?sort=trending&search=dpo 。
比如 https://huggingface.co/datasets/Anthropic/hh-rlhf 的数据集如下:
hh-rlhf
3.2 TRL
引入 TRL 库,支持 DPO 训练器,训练样例代码:
如上训练默认是保存 safetensors 格式的模型,如果想保存 pytorch 格式的模型, 可以改为如下代码:
3.3 训练
Transformer的代码和前面的一样,可以参考预训练的代码,如下就是初始化模型和 DPO 训练的代码:
- init_model 函数主要是注册和加载预训练的模型,并将 tokeinzer 的一些配置文件都拷贝到 ./my_checkpoint 方便后续的训练;
- DPOConfig 主要是配置训练的一些参数,比如保存的模型路径、学习率等;
- DPOTrainer 是 DPO 训练器,将模型载入后调用 train 进行训练,参数说明如下:
model: transformers.PreTrainedModel,预训练模型
ref_model: transformers.PreTrainedModel,参考模型
args: DPOConfig,用于训练的 DPO 配置参数
train_dataset: datasets.Dataset,训练数据集
tokenizer: transformers.PreTrainedTokenizerBase,分词器
model_init: 用于训练的模型初始化器,如果指定为 None,则将使用默认的模型初始化器
optimizer: torch.optim.Optimizer,优化器
callbacks: 用于训练的回调函数
- dpo_trainer.save_model 保存模型,传入 output_dir 参数,指定保存的模型路径
4、总结
至此,训练系列按照步骤写完了,现在总结训练流程:
模型训练流程
不过验证下来,训练效果不是很好,这个也是从0开始训练会遇到的问题,因此接下来会完成几个事项:
- 模型迭代优化,解决训练效果不好的问题;
- 模型尝试新的模型和解决方案,解决训练速度问题;
- 加入多模态训练集,将语言大模型改进为多模态模型;
- 最后将整个模型训练完成后,将代码开源。