
重磅!Unsloth开源新算法:让GRPO训练大模型所需显存降低90%,告别显存焦虑!
图片
在大模型训练领域,显存一直是一个让研究者和开发者头疼的问题。特别是在进行长文本上下文训练时,动辄需要几百GB的显存需求,这让很多研究者望而却步。不过最近,AI基础设施优化团队Unsloth带来了一个重大突破 - 他们推出的新算法可以让GRPO训练所需显存减少高达90%!文章公布了Llama3.1(8B) GRPO在Colab上notebook,见:https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb
1、从510GB到54GB:显存优化的突破性进展
在传统的GRPO训练方案中,要训练一个支持20K上下文长度的Llama 3.1(8B)模型,需要高达510.8GB的显存。这个量级的显存需求,即便是顶配的训练服务器也难以满足。而Unsloth团队通过其创新的算法优化,将这一需求降低到了惊人的54.3GB,这意味着:
训练内存成本:从414GB降至42GB
GRPO内存成本:从78.3GB降至9.8GB
推理内存开销:从16GB降至0GB
20K上下文的推理KV缓存:保持在2.5GB
图片
2、技术创新:三重优化方案
Unsloth团队采用了三个关键的技术创新来实现这一突破:
全新的线性算法:团队为GRPO开发了一个全新的内存高效线性算法,这个优化alone就减少了68.5GB的内存使用。更令人惊喜的是,通过torch.compile的协助,这个算法在性能上还实现了提速。
智能梯度检查点:通过将中间激活值异步卸载到系统RAM,在仅损失1%性能的情况下节省了惊人的372GB显存。这个优化特别适用于需要多次生成的场景。
共享内存空间:与其他实现不同,Unsloth可以与底层推理引擎(vLLM)共享GPU/CUDA内存空间,这又节省了16GB显存。
Unsloth团队从 Horace 的线性交叉熵实现中获得了灵感,并成功使其适用于 GRPO!实际上,我们发现了一些令人惊讶的点:
参考 GRPO 实现使用反向 KL 散度,而不是正向 KL 散度。
天真地实现浮点 16 混合精度(以及浮点 8)上的线性交叉熵,如果没有正确处理,将因自动混合精度缩放机制而崩溃。
我们发现 GRPO 损失函数实现中存在其他问题——主要是在反向 KL 散度的公式表达上。
Unsloth团队进行了 4 个实验:
通过参考实现(红线)进行常规 GRPO
移除断开代码(蓝色线条)
完整反向 KL,如前所述增加一个额外项(黄色线)
前向 KL 散度(绿色线)
图片
一般来说,移除 detach 确实会破坏所有训练,所以我们必须保留它——这很可能需要更多的调查。看起来其他所有实现似乎都很相似?我们可能需要运行模型更长时间以看到不同的效果。
在所有实现中,Unsloth团队还利用了 logsumexp 技巧
3、实践意义:让更多开发者参与AI训练
这项技术突破的意义远不止于数字的优化。它意味着:
- 降低硬件门槛:原本需要多卡集群才能完成的训练任务,现在用单卡就能搞定。比如Qwen2.5 (1.5B)的训练现在只需要5GB显存!
- 提升研究效率:研究人员可以更快速地进行实验验证,加快模型迭代速度。
- 扩大应用场景:更多的小团队和个人开发者现在也能尝试大模型训练,这将极大促进AI技术的普及和创新。
看完这篇文章,是不是对AI训练的未来更有信心了?如果你也对大模型训练感兴趣,不妨关注Unsloth团队的GitHub项目,开启你的AI训练之旅!
文章标题:Long-context GRPO 长上下文 GRPO
文章链接:https://unsloth.ai/blog/grpo
本文转载自 AI帝国,作者: 无影寺
