一行代码Post-Train任意长序列!360智脑开源360-LLaMA-Factory

人工智能 新闻
360 智脑开源了 360-LLaMA-Factory,支持了序列并行,仅需额外 1 个参数控制。基于 LLaMA-Factory 和 ring-flash-attention 开发,360-LLaMA-Factory 的实现模块化、效果正确且在长序列上有效。

项目核心开发者 Haosheng Zou 本科毕业于清华大学电子系,博士毕业于清华大学计算机系朱军教授组,目前在 360 智脑从事长文本和强化学习等后训练工作。开发者 Xiaowei Lv 目前在人民大学信息学院研二在读。Fenrui Xiao、Junchen Liu、Qi An 和 Xiaodong Sun 等在开发测试中亦有贡献。

大模型长序列的处理能力已越来越重要,像复杂长文本任务、多帧视频理解任务、以及 OpenAI 近期发布的 o1、o3 系列模型的高计算量模式,需要处理的输入 + 输出总 token 数从几万量级上升到了几百万量级。面对模型日益增长的长序列需求,在预训练(Pre-Training)和后训练(Post-Training)阶段,所用的平台框架都需要支持更长序列数据的训练。不同于预训练阶段基于 Megatron-LM 定制开发的常见选择,后训练阶段因后训练算法的多样性(比如仅 DPO 就有几十个变种)和训练需求的灵活性,至今没有一个框架同时在并行策略、后训练算法、GPU 显存优化和简单易用这 4 个方面上全部做到兼容并包。

在所有开源的后训练框架中,LLaMA-Factory 是用户最多的框架之一(GitHub star 数已 37k 多),保持长期迭代更新,支持丰富的模型和后训练算法,有各种 GPU 显存优化技巧和简单易用的方式。然而,LLaMA-Factory 在长序列后训练上支持仍有所欠缺,尚不支持长序列的关键技术 —— 序列并行。

图片

项目主页:https://github.com/Qihoo360/360-LLaMA-Factory

最近,360 智脑基于 LLaMA-Factory 开源了 360-LLaMA-Factory,加入了序列并行功能,一行代码即可支持任意长序列的后训练(Post-Training)—— 仅需额外指定序列并行一个参数:

sequence_parallel_size: 16

按需增加序列并行的 GPU 卡数,即可在任意长度的序列上 SFT 或 DPO。

360-LLaMA-Factory 的实现经过了严格的正确性验证,已在主仓 Pull Request 中审核过。正式合并进 LLaMA-Factory 主仓之前,可先使用 360-LLaMA-Factory。

1、项目背景与项目简介

360 智脑早在 2023 年就开始了长文本大模型的研发,到目前为止已经成功应用于开源并更新了两个版本的 360Zhinao-7B-Chat-360k 模型,以及近日发布的长思维链推理模型 360gpt2-o1。在 360-LLaMA-Factory 中,我们将 360 智脑内部长序列后训练能力系统性地整合进了 LLaMA-Factory 中,用户仅需额外添加一行代码,即可进行理论上任意长度的长序列后训练(增加序列并行的 GPU 卡数即可):

sequence_parallel_size: 16

在原先使用 LLaMA-Factory 的基础上,只需额外增加一个参数

通过这种方式,360-LLaMA-Factory 将 LLaMA-Factory 的序列并行也做到了简单易用和兼容并包,和 LLaMA-Factory 的其他功能完全兼容。

粗粒度地测试 8 卡 80G 的全参数后训练(不考虑除了 zero3-offload 和 gradient checkpointing 外的任何优化技巧),360-LLaMA-Factory 至少可以训到 SFT 210k (7B) / 128k (72B) 和 DPO 84k (7B) / 46k (72B)。若加上注掉 logits = logits.float () 和 DPO 预计算等技巧,2 卡序列并行即可解决诸多常见的训练需求。360-LLaMA-Factory 让序列并行也真正成为了简单好用、效果也好的后训练工具。

作为开源社区的一份子,360-LLaMA-Factory 离不开 LLaMA-Factory、ring-flash-attention 和 EasyContext 等开源项目的开创性工作,我们的底层开发部分依赖了这些工作,但也有我们自己在具体实现方式上的不同和见解。我们相信我们的代码实现已做到尽可能好的模块化和尽可能少的原始代码修改,且严格检查过正确性,因此也已向 LLaMA-Factory 主仓提交了 Pull Request,初步审核通过。我们乐于同开源社区共建完善这项工作。

2、长序列及其后训练

2.1 长序列大模型的训练:预训练 vs 后训练

随着大模型训练数据长度的增长,预训练和后训练平台框架都需要支持长序列数据训练。

  • 预训练阶段,英伟达的 Megatron-LM 凭借丰富高效的并行策略与出色的 GPU 显存优化,成为主流框架,基于它的定制开发往往是最通用的解法, Megatron-LM 本身已实现了序列并行(Megatron-LM 称之为 context parallelism,其他工作一般称为 sequence parallelism)。

  • 后训练阶段情况相对复杂。后训练算法多样,如 DPO 就有诸多变种,且训练需求灵活多变,不同场景对算法、资源、并行性等要求各异。因此,至今没有一个框架能在并行策略、后训练算法、GPU 显存优化和易用性这四个关键方面做到近乎完美的兼容。虽有框架在部分方面表现尚可,但总体仍存在短板,这也限制了模型在长序列数据后训练上的进一步发展。

2.2 长序列的通解 —— 序列并行及其难点

长序列后训练面临的关键瓶颈是:序列长度增加时,激活显存会大幅上升。虽然有 unsloth、liger kernel、LoRA 等多种降低显存占用的技巧,但均未从根本上解决序列长度增加的本质问题,其效果存在明确上限。

序列并行(sequence parallelism)被认为是解决长序列训练问题的通解,它通过把一条长序列切分到不同的显卡上进行计算,从而避免了每张显卡处理过长的序列,从根本上解决了 “每张显卡处理的序列长度增加” 的问题。然而,序列并行的实现难度较大,需要在切分后的序列之间进行通信计算 attention,需要侵入修改原始的 attention 函数。在开源的 Megatron-LM 中,序列并行也是所有并行策略中最后才添加的,LLaMA-Factory 之前还没有支持序列并行。

2.3 序列并行后训练的相关工作

我们调研了其他一些支持序列并行的开源框架,有些实现上有错或小 bug、导致支持的后训练算法不全;有些更新维护不及时、训练较新的模型不方便、显示进度条等易用性不足。有的与 LLaMA-Factory 相比继承依赖更少,支持功能较少但更干净、更适合定制开发,有不同的使用场景。此外,各家的序列并行具体实现也不尽相同。详见下面的表 1 和 GitHub README,有未调研到的也请包涵并联系 360-LLaMA-Factory。

图片

表 1:一些支持序列并行的后训练框架对比

3、360-LLaMA-Factory 框架解析

360-LLaMA-Factory 系统性地为 LLaMA-Factory 增加了序列并行的支持。以下将简要介绍 360-LLaMA-Factory 框架中的模块化修改和执行流程。

3.1 360-LLaMA-Factory 的框架和模块化封装

360-LLaMA-Factory 将序列并行的代码做到了尽可能好的模块化和尽可能少的原始代码修改。

我们认为序列并行本质上应认为是对模型的修改,因此在 model_args 中增加了参数并抽象为 apply_sequence_parallel 修改模型的函数。

# src/llamafactory/model/loader.py
sequence_parallel_group = apply_sequence_parallel(model_args)  # 序列并行monkey patch,改动attention计算
...
model.sequence_parallel_group = sequence_parallel_group  # 维护模型的序列并行组,不开则为None

相应地,数据处理部分也要相应地修改,我们将 zigzag ring attention 所需的数据处理抽象成了一个 decorator,装饰原来的数据处理函数。背后,这会将先 shuffle、packing、预处理好的数据进一步做好序列并行的准备:先将每行 pad 或截断到指定的训练长度,再按 zigzag 切分并按顺序写入数据集,最后在训练时用 SequentialSampler 读取训练数据。

# src/llamafactory/data/loader.py
@sequence_parallel_decorator
def get_dataset(...)

loss 计算则需要在 Trainer 中做序列并行组内的 reduce 汇总和计算。

# src/llamafactory/train/sft/trainer.py
dist.all_reduce(loss, op=dist.ReduceOp.SUM, group=sp_group)
dist.all_reduce(label_num, op=dist.ReduceOp.SUM, group=sp_group)
loss /= label_num
# src/llamafactory/train/dpo/trainer.py
dist.all_reduce(policy_chosen_logps, op=dist.ReduceOp.SUM, group=sp_group)
dist.all_reduce(policy_rejected_logps, op=dist.ReduceOp.SUM, group=sp_group)
dist.all_reduce(reference_chosen_logps, op=dist.ReduceOp.SUM, group=sp_group)
dist.all_reduce(reference_rejected_logps, op=dist.ReduceOp.SUM, group=sp_group)

3.2 360-LLaMA-Factory 的 SFT 和 DPOTrainer

除了统一的模块化抽象,序列并行也需要对 360-LLaMA-Factory 的 Trainer 稍做定制化的修改,以适配各底层库。针对最普遍的后训练需求 SFT 和 DPO(及其变种),我们对 360-LLaMA-Factory 中的 SFT 和 DPOTrainer 做了尽可能少且清晰的修改。

其中,dummy_forward 是因为我们发现基于目前的底层序列并行实现,在第一次 forward 时 DPO loss 不等于 log (sigmoid (0)),但学习率设为 0 时之后的 DPO loss 全都等于。因此,训练最开始时先做且仅做一次假前传,不对正式训练循环造成任何影响。

从 SFT 和 DPO 的序列并行对比图中,可以清晰地看出 360-LLaMA-Factory 序列并行带来的改动。

图片

图 3:360-LLaMA-Factory SFT 序列并行对比


图片

图 4:360-LLaMA-Factory DPO 序列并行对比

4、360-LLaMA-Factory 效果验证

内部 360-LLaMA-Factory 的早期版本已训练了开源的 360Zhinao2-7B-Chat-360k。

为验证本次开源的 360-LLaMA-Factory 的正确性,我们用总量为 30 条的小数据集,验证了序列并行开与不开的对比情况下,训练曲线的差别,以此来确保 360-LLaMA-Factory 所有实现的正确性。从下图可见,序列并行对训练曲线的影响几乎可以忽略不计,DPO 稍有一定数值误差,但我们也仔细检查了该误差与 DeepSpeed Ulysses 的误差范围一致,很可能部分是并行计算本身的随机性导致的,亦可参考 ring-flash-attention 的详细说明。

图片

图 5:360-LLaMA-Factory SFT 和 DPO 序列并行开关对比

为便于对比效果,我们基于第三方全尺寸开源模型粗粒度压测了最大训练长度,如下表 2、表 3 所示,可见 8 卡 80G 的序列并行上限已可满足几十至几百 k 超长序列的需求:

图片

表 2:第三方开源模型多尺寸 SFT 长度压测


图片

表 3:第三方开源模型多尺寸 DPO 长度压测

5、总结

360 智脑开源了 360-LLaMA-Factory,支持了序列并行,仅需额外 1 个参数控制。基于 LLaMA-Factory 和 ring-flash-attention 开发,360-LLaMA-Factory 的实现模块化、效果正确且在长序列上有效。

欢迎开发者们使用和开发。在本仓库(https://github.com/Qihoo360/360-LLaMA-Factory)下提交序列并行相关的 issue 或 PR 即可。

也欢迎研究者们,尤其是依赖长序列大模型的研究者们,在研究中使用我们的代码,可以这样引用我们的工作:

@software{360-llama-factory,
  author = {Haosheng Zou, Xiaowei Lv, Shousheng Jia and Xiangzheng Zhang},
  title = {360-LLaMA-Factory},
  url = {https://github.com/Qihoo360/360-LLaMA-Factory},
  year = {2024}
}
责任编辑:张燕妮 来源: 机器之心
相关推荐

2023-06-13 17:40:49

360360智脑大模型

2023-09-05 10:21:03

人工智能

2016-12-02 08:53:18

Python一行代码

2010-12-03 12:57:23

2024-08-13 15:40:00

2014-02-12 13:43:50

代码并行任务

2017-04-05 11:10:23

Javascript代码前端

2022-04-09 09:11:33

Python

2013-11-27 09:25:20

2020-09-09 16:00:22

Linux进程

2021-11-02 16:25:41

Python代码技巧

2020-08-19 10:30:25

代码Python多线程

2017-04-13 19:20:18

Python代码并行任务

2021-08-31 09:49:37

CPU执行语言

2023-09-12 10:10:57

开发者工具开源

2010-04-23 21:42:14

信息安全产品360安全中心

2010-09-27 14:22:23

2020-09-28 12:34:38

Python代码开发

2020-12-08 06:20:00

Python自动化工具开源
点赞
收藏

51CTO技术栈公众号