训练大模型时,显存都哪去了? 原创

发布于 2024-11-19 12:41
浏览
0收藏

GPT-2(XL)有15亿个参数,使用16位精度,一个参数占用2个字节的内存,因此这些参数大约占用3GB的内存。

训练大模型时,显存都哪去了?-AI.x社区

按照如下超参数设置:

  • 优化器 → Adam
  • 批量大小 → 32
  • 变换层数量 → 48
  • 序列长度 → 1000

要想在单个GPU上训练GPT-2,所需的最小内存大概是多少?

训练大模型时,显存都哪去了?-AI.x社区

答案可能会吓到你。

在一个拥有32GB内存的单个GPU上,几乎无法训练一个3GB的GPT-2模型。

训练大模型时,显存都哪去了?-AI.x社区

但这怎么可能呢?内存都去哪了?让我们来了解一下。

模型在训练过程中有很多方面会持续占用内存。

#1)优化器状态,梯度,模型参数

混合精度训练广泛用于加速模型训练。

顾名思义,这个方法的思想是在训练过程中同时使用float16低精度(在卷积和矩阵乘法等操作中)和高精度(如32位浮点数,float32)。

这就是“混合精度”名称的由来。

前向传播和反向传播都使用16位浮点数表示权重和梯度。

因此,如果模型有Φ个参数,那么:

● 权重将占用2 * Φ字节的内存。

● 梯度将占用2 * Φ字节的内存。

这里的“2”表示每个参数占用2个字节的内存(16位)。

Adam 是最受欢迎的模型训练优化器之一。

虽然许多实践者仅仅因为它流行而使用它,但他们没有意识到,在训练过程中,Adam 会存储两种优化器状态来计算更新——梯度的动量和方差。

训练大模型时,显存都哪去了?-AI.x社区

因此,如果模型有Φ个参数,那么这两个优化器状态将消耗:

● 4 * Φ 字节用于动量。

● 另需 4 * Φ 字节用于方差。

这里的“4”表示每个参数占用 4 个字节的内存(32 位)。

训练大模型时,显存都哪去了?-AI.x社区

此外,反向传播结束时的更新仍然在32位精度下进行,以确保有效的计算。这导致:

● 另需 4 * Φ 字节用于模型参数。

让我们把它们加起来:

训练大模型时,显存都哪去了?-AI.x社区

这就是 16 * Φ,或者 24GB 的内存,远远高于 16 位参数所使用的 3GB 内存。

而且我们还没有考虑到所有的因素。

2#)激活值

对于像大型深度学习模型(如大语言模型,LLMs)来说,激活值在训练过程中占用了大量内存。

更确切地说,在GPT-2的一个Transformer块中计算的激活值总数是:

训练大模型时,显存都哪去了?-AI.x社区

因此,在所有的Transformer块中,总计就是:

训练大模型时,显存都哪去了?-AI.x社区

这是 GPT-2-XL 的配置:

训练大模型时,显存都哪去了?-AI.x社区

总共大约是 300 亿个激活值。由于每个激活值使用 16 位表示,所有激活值总共占用 60GB 的内存。

通过使用像梯度检查点(在上一章讨论过的)这样的技术,可以将内存消耗降低到大约 8-9GB,但这也会带来额外 25-30% 的计算开销。

除了可以计算的内存开销外,还有一些额外的开销,例如内存碎片化。

内存碎片化是指在分配的内存块之间存在小的未使用间隙,导致可用内存的低效使用。

训练大模型时,显存都哪去了?-AI.x社区

内存分配请求失败是因为没有足够的连续内存块可用。

在上述讨论中,我们考虑了一个相对较小的模型——GPT-2(XL),它有 15 亿个参数,与如今训练的模型规模相比非常小。

然而,这个讨论可能帮助你反思构建大规模语言模型(LLMs)时的固有挑战。很多人常说,GPT 模型只是简单地堆叠更多的层并使网络变得更大。

如果真是那么简单,大家都会在做了。从这个讨论中,你可能已经理解到,这并不像仅仅添加更多层那么简单。

即便是增加一层,也可能导致额外数 GB 的内存需求。多 GPU 训练是这些模型的核心技术,我们将在另一篇文章中讨论。


本文转载自公众号人工智能大讲堂 

原文链接:​​https://mp.weixin.qq.com/s/PFQZnqJJ-tjFcv6oSjaV-A​


©著作权归作者所有,如需转载,请注明出处,否则将追究法律责任
标签
收藏
回复
举报
回复
相关推荐