LLM 分布式训练六大关键技术介绍 原创 精华
编者按: 本文聚焦于分布式去中心化神经网络训练技术,作者系统阐述了在大规模模型训练中提高硬件使用效率的创新方法。
文章重点阐述了六种关键的分布式训练技术:
- 数据并行训练:通过将数据 mini-batches 分散到多个 workers,实现并行梯度计算和高效训练。
- Butterfly All-Reduce:通过创新的数据分割和汇总方法,有效降低通信成本。
- Gossip-Based Averaging:去中心化的通信策略,提高系统的容错性和可扩展性。
- Moshpit Gradient Descent:允许 workers 在小型独立组内进行梯度平均,增强训练的容错能力。
- DiLoCo:创新的内外优化算法,结合局部和全局参数更新,平衡收敛速度和系统性能。
- SWARM:引入动态任务分配和容错机制,优化异构硬件环境下的资源配置。
作者 | Robert Lange
编译 | 岳扬
随着人工智能技术的发展进步,训练大规模神经网络(包括大语言模型)变得越来越重要。这些模型的规模和复杂度不断提升,不仅增加了训练的成本和能耗,也迫切要求我们提高硬件使用效率。为了应对这些挑战,研究人员和工程师们正在探索分布式去中心化训练方法。本文将探讨多种分布式训练技术,例如数据并行训练方法和 Gossip-Based Averaging 方法,展示这些技术如何在满足该领域不断增长的需求的同时优化模型训练效率。
一幅以简约日式风格绘制的GPU集群图,图中加入了很多小型 GPU(由 OpenAI 的 Dallé-3 API 生成)
01 数据并行训练技术、全归约操作与节点同步
数据并行训练技术通过将数据的 mini-batches 分散到多个工作节点(workers)上,实现了高效的训练。这种方法不仅加快了训练进程,因为多个 workers 可以并行计算梯度,而且还使得我们可以处理比单个设备更大的 batch sizes。为了保持所有 workers 之间的模型更新同步,我们采用了全归约操作。该操作会将所有 workers 的梯度汇总并求平均值,然后统一更新模型,确保整个分布式系统中的模型保持一致。
以下是用 PyTorch 在 Python 中展示这一过程的一个简单示例:
全归约操作之外,还有一种方法是使用参数服务器(parameter server)。在这种架构中,中央服务器负责收集梯度信息并监控优化器的状态。虽然这样做可以简化同步过程,但同时也存在单点故障的风险,并有可能成为系统性能的瓶颈。
分布式训练中,Hogwild(Recht et al., 2011)[1]是另一项著名的技术。它采用异步更新模型参数的方法,无需所有计算节点同步即可进行。这种方法不仅适用于监督学习,也适用于强化学习(RL)场景,如异步演员-评论家算法(A3C, Mnih et al., 2016)[2]。在 A3C 中,多个智能体可以同时与环境互动,并基于各自的经验异步更新同一个模型。这样做不仅提高了资源的使用效率,还能通过多个智能体的不同经验加快收敛速度,从而提高在复杂环境中的性能。
除了数据并行训练方法,还有模型并行和管道并行等其他并行训练方法(详见 Llian Weng 的博客[3])。模型并行是将模型分割到多个计算设备上,使得模型的不同部分可以同时处理,这对于那些单个设备无法承载的超大型模型尤其有用。而管道并行则是将模型分为几个阶段,各个 mini-batches 数据依次通过这些阶段进行处理,这样做可以实现计算与通信的并行,从而提高训练的整体效率和吞吐量。这些技术互为补充,共同优化了大规模训练场景下的资源利用。
02 Butterfly All-Reduce
Butterfly All-Reduce(Zhao和Canny,2013)技术有效地解决了传统全归约方法所面临的挑战。在这种技术中,每个参与的节点(共 N 个)都会将其本地数据分割成 N 份。然后,第 i 个节点会收集所有其他节点发来的第 i 份数据,进行汇总后,再平均分配回各个节点。
这种方法大幅降低了通信的负担,并提升了系统的可扩展性。 在分布式训练中,所谓的“world size”是指参与训练的总进程或设备数。这个参数对于决定如何在各个节点间聚合和同步数据起到了关键作用。
以下是对 Butterfly All-Reduce 技术的一个概念性实现示例:
这段代码展示了 butterfly all-reduce 技术如何在保持分布式系统同步的同时,有效利用并行处理的优势。
butterfly all-reduce 方法的优势在于,与传统的全归约技术相比,它能够显著降低通信成本,并且具有更好的可扩展性,因此非常适合用于大规模分布式系统。然而,这种方法也存在一些不足之处。例如,其实现过程较为复杂,性能可能会受到通信网络拓扑结构和网络状况的影响。另外,如果参与节点中的任何一个发生故障,可能会对整个系统的同步过程造成影响。
在某些特定的应用场景,尤其是联邦学习中,训练过程需要能够适应不稳定的网络带宽和不可靠的工作节点。联邦学习尤为复杂,因为它涉及多个持有敏感隐私数据的独立参与节点。这些情况要求我们必须采用稳健的策略,以确保模型训练的可靠性。接下来,我们将探讨一些方法,这些方法的目的是平衡收敛速度和系统的容错能力。
03 Gossip-Based Averaging
gossip-based averaging(Boyd等人,2005年)是一种去中心化的通信策略,其中的参与节点构建了一个稀疏的通信网络。每个节点定期从邻近节点获取参数,并将其与自己的本地参数进行结合。这种方式减轻了参数服务器(parameter servers)带来的通信压力,但也意味着每个节点可能会使用不同的本地参数进行计算。
gossip-based averaging 的收敛特性深受通信网络结构的影响。以下是一个简单的 gossip-based averaging 实现示例:
gossip-based averaging 具有以下优势:
- 减少通信瓶颈:由于不需要集中的参数服务器,gossip averaging 大幅降低了通信拥堵,使得参数更新更加高效。
- 可扩展性:这种方法的去中心化特点使得它在扩展性方面表现出色,能够轻松应对参与节点数量的增加,而不会产生过多的额外开销。
- 容错性:分布式的设计提升了系统的容错能力,即使有 worker 出现故障,也不会中断整个训练过程;其他 workers 仍可以继续通信和更新参数。
然而,我们也需要注意到这种方法可能带来的几个不足之处:
- 收敛速度降低:与集中式更新方法相比,gossip averaging 的收敛速度可能会较慢,因为参数的聚合并不频繁,每个 worker 可能需要基于不太新的数据进行计算。
- 参数更新存在分歧:由于每个节点使用的是不同的本地参数,这可能会导致参数更新存在分歧,进而影响收敛的稳定性和速度。
- 依赖通信图:gossip averaging 的效果在很大程度上受制于通信图的结构。如果图的连通性不佳或者结构不平衡,可能会影响到算法的整体性能。
综合来看,尽管 gossip-based averaging 这种去中心化的参数更新方法具有很大的潜力,但在实际应用中,我们需要根据具体的训练场景,权衡其利弊。
04 Moshpit Gradient Descent
Moshpit Gradient Descent(Ryabinin et al., 2021)[4]方法进一步发展了去中心化训练的理念,它允许 workers 在小型且独立的组内进行梯度平均。这种设计意味着,即使某个参与节点出现问题,影响的也仅限于其所在的小组,从而提高了整个训练过程的容错性,避免了全局训练的中断。
这些小组的动态构建对于保证训练的有效性至关重要。通过优化小组结构,该方法大幅减少了达到收敛所需的步骤数,因为 workers 可以在较小的团队内更高效地交换和更新梯度信息。这种自适应的分组策略有助于更好地利用现有资源,并在不同的网络环境下实现更优的性能表现。
以下是一个实施 moshpit gradient descent 的概念性框架:
moshpit gradient descent 的优势包括:
- 容错性:个别 worker 的故障只会影响其所在的小组,不会波及整个训练过程,其他小组可以继续正常训练。
- 资源利用效率:在较小的小组内进行更新,该方法能够灵活应对网络状况和 worker 可用性的变化,从而提升训练效率。
- 降低通信负担:由于通信仅限于小组内部,整体的通信量得以减少,这在带宽受限的情况下尤为有利。
然而,这一方法也存在一些不足之处:
- 收敛难题:小组结构的不断变化可能导致参数更新出现不一致,可能会使得训练的收敛和稳定性面临挑战。
- 管理复杂性增加:对小组进行动态管理和调整,无疑增加了训练流程的复杂性。为了找到最佳的小组配置,我们需要开发更复杂的机制。
- 可扩展性问题:较小的小组虽然有助于提高系统的容错性,但如果没有有效的管理,这种方法在大规模训练场景中的可扩展性可能会受限。
综合来看,moshpit gradient descent 作为一种去中心化训练的新方法,其潜力不容小觑。它在容错能力和资源利用效率上的优势,与面临的收敛难题和实施复杂性之间,实现了微妙的平衡。
05 DiLoCo: Inner-Outer Optimization
DiLoCo(Douillard等人,2023年)[5]带来了一种创新的 inner-outer 优化算法,旨在提高去中心化训练的效率。在这种算法中,每个计算节点在内部优化阶段,会利用局部的 AdamW 优化器进行多次参数更新。这样的设计让节点能够基于局部数据独立优化参数,而不必实时与其他节点同步。当完成了一定量(通常是500次左右的)局部更新后,便进入外部优化阶段,此时会同步所有节点的伪梯度(这些梯度是局部更新结果的汇总)。
这种做法巧妙地结合了局部和全局更新的优势,有望加快收敛速度并提升训练表现。DiLoCo 通过让节点先在局部优化参数,再与全局模型同步,充分发挥了两种更新策略的长处。
以下是对 DiLoCo 更新过程的概念性描述:
DiLoCo 最初由 Google DeepMind 实现,而现在一家新兴的初创公司 PrimeIntellect 也成功复现了这一方法。OpenDiLoCo(Jaghouar等人,2024年)[6]已在 GitHub[7] 上公开,借助 Hivemind 库[8]训练了一个 10 亿参数的模型。最近,PrimeIntellect 推出了自家研发的定制化基础设施[9],其中包含了诸多工程创新,如定制的 all-reduce 算法和通信协议。该公司目前正在训练一个名为 Intellect-1[10] 的 100 亿参数模型。我相信这项实验的结果将对我们突破现有模式产生深远影响。目前,大模型的训练还依赖于集中的计算资源。但未来,或许每个人都能为打造下一代领先的基础模型贡献力量。
06 SWARM: Fault Tolerance and Dynamic Task Assignment
SWARM 算法(Ryabinin等,2023年)[11]引入了一种新颖的分布式训练方法,允许每个工作节点在训练过程的后续阶段将其输出发送给其他工作节点。这种灵活的任务分配方式,使得计算能力较强的设备能够承担更多任务,从而在多样化的硬件环境中实现资源的最优配置。这种策略在计算资源波动较大的场景下尤为有效,可实现更均衡的工作量,减少闲置时间。
面对工作节点的故障,SWARM 算法展现了其容错能力,能够迅速将故障节点的任务转交给其他正常运行的节点。这一机制对于维持训练流程的连贯性至关重要,它有效减少了意外中断的影响,并确保了处理能力的及时补充。工作节点间的通信路径是随机且动态调整的,这使得算法能够根据网络状况或节点状态的变动实时调整。
通过这种自适应的通信方式,不仅数据流转更加高效,训练过程的稳定性也得到了加强。下面是 SWARM 通信实现方式的简化示例:
在这个示例中,每个活跃的工作节点随机选取一个相邻节点作为信息传递的对象,这样的去中心化交流模式能够实时适应当前系统的状态。SWARM 算法以其动态任务分配和强大的容错能力,在大规模机器学习场景中显著提高了分布式训练的效率和可靠性。
07 Conclusion
分布式去中心化训练为高效训练大规模神经网络提供了一个强有力的支撑。借助数据并行训练方法、butterfly all-reduce、gossip-based averaging 等手段,从业人员能够在各种环境中应对模型训练的难题。对于任何想要优化大规模 AI 系统性能的人来说,掌握这些技术至关重要。随着该领域研究的不断深入,了解这些方法将是发挥分布式训练全部实力的关键。本文并非涵盖所有分布式训练方法和最新研究进展,而是提供一个粗略的概览——因此,还请读者自行探索更多技术细节🤗。
Thanks for reading!
Hope you have enjoyed and learned new things from this blog!
About the authors
Robert Lange
Deep Learning PhD @TU Berlin. Research Scientist @Sakana.AI. ✍️ 2x Google DeepMind Intern
END
本期互动内容 🍻
❓在分布式训练中,您认为最大的技术瓶颈是什么?是通信开销、收敛速度、还是系统的容错性,或是其他?
🔗文中链接🔗
[2]https://arxiv.org/abs/1602.01783
[3]https://lilianweng.github.io/posts/2021-09-25-train-large/
[4]https://openreview.net/pdf?id=cwWfDHYpb1z
[5]https://arxiv.org/abs/2311.08105
[6]https://arxiv.org/abs/2407.07852
[7]https://github.com/PrimeIntellect-ai/OpenDiLoCo
[8]https://github.com/learning-at-home/hivemind
[9]https://github.com/PrimeIntellect-ai/prime
[10]https://www.primeintellect.ai/blog/intellect-1
[11]https://arxiv.org/abs/2301.11913
原文链接: