MHA -> GQA:提升 LLM 推理效率

发布于 2025-1-13 11:35
浏览
0收藏

一、背景

我们在之前的文章中详细分析过 GQA 相比 MHA 的推理优势(省显存、计算强度高),不过 GQA 有可能导致精度的损失,因此早期的一些不太大的 LLM 会使用 MHA。针对这个问题有两种优化思路:

  • 将 MHA 转换为 GQA,长短序列都适用。
  • 在长序列场景使用 Token 稀疏化方案或者结合投机采样策略。​

本文中我们介绍一个将 MHA 转换为 GQA 的工作,不过论文的实验还偏少,效果也不是非常好;此外,最新的模型基本都在预训练阶段默认采用 GQA(LLaMA3 8B、LLaMA3.2 3B 以及 Microsoft 的 Phi 系列模型等),降低了本文工作的应用场景。

对应的论文:[2412.20677] Align Attention Heads Before Merging Them: An Effective Way for Converting MHA to GQA [1]

相关工作也可以参考我们以前的文章:

二、摘要

LLM 在多种自然语言处理任务中展现出卓越性能。然而,随着模型规模与输入序列长度的增长,KV Cache 的急剧膨胀显著拖慢了推理速度。鉴于此,作为 MHA 的替代方案,GQA 已被广泛引入 LLM。本研究提出了一种低成本方法,可将 MHA 模型按任意 KV Head 压缩比修剪为 GQA 模型。

该方法基于 L0 掩码逐步剔除冗余参数。此外,在不改变模型的前提下,对注意力头施加正交变换,以在修剪训练前提升 Attention Head 间的相似度,从而进一步优化模型性能。本方法兼容RoPE,意味着训练后的模型能完全适配主流标准 GQA 框架。实验表明,仅通过监督微调,提出的策略即可将 LLaMA2-7B 模型的 KV Head 压缩高达 87.5%,且性能损失极小。

三、引言

如下 3.1 和 3.2 部分在我们之前的文章中有相吸介绍:​​​LLM 推理的 Attention 计算和 KV Cache 优化:PagedAttention、vAttention 等​​。

3.1 MHA Attention 计算

如下图所示为标准的 LLM Decoding 阶段的 Multi-Head Attention(MHA)计算,其中的 D 表示 hidden size,H 表示 Head 个数,L 表示当前是在序列的第 L 个 Token。可以看出:

  • Batch Size 为 1时,图中红色绿色蓝色处的矩阵乘法全部为矩阵乘向量,是明显的 Memory Bound,算术强度不到 1。
  • Batch Size 大于 1时(比如 Continuous Batching):
  • 红色蓝色部分:因为是 Weight 乘以 Activation,所以不同的 Request 之间可以共享 Weight。这里变成矩阵乘矩阵,并且 Batch Size 越大,算术强度越大,也就越趋近于 Compute Bound(FFN 层也类似)。
  • 绿色部分:这里 Q、K 和 V 的 Attention 计算,是 Activation 乘以 Activation,所以不同的 Request 之间没有任何相关性。即使 Batching,这里也是Batched 矩阵乘向量,并且因为序列长度可能不同,这里不同 Request 的矩阵乘向量是不规则的。也就是说,这里算术强度始终不到 1,是明显的 Memory Bound。

MHA -> GQA:提升 LLM 推理效率-AI.x社区

从上可以看出,通过 Continuous Batching 可以很好的将 Memory Bound 问题转变为 Compute Bound,但 Q、K 和 V 的 Attention 计算的算术强度却始终小于 1。根据 Amdahl 法则,如果系统中有一部分无法优化,即使把其他部分优化到可以忽略,不可优化的部分也会决定整个系统的性能上限。不幸的是,Sequence Length 越长,这里的计算量就越不可忽略。

根据模型配置信息可以估算出模型中 Q、K 和 V 的 Attention 计算与其他矩阵计算的比例大约为 (L+D)/(12*D)(PS:准确值需要根据具体的模型参数计算)。也就是说,当序列长度 L 等于 12 倍的 hidden size 时,两部分的计算量相当,即使其他矩阵计算优化到 0,加速比也只有 2x。比如 LLaMA 2 7B 的 hidden size 为 4K,当序列长度达到 44K 时,两部分的计算量相当,要优化的重点也会很不一样,这也是很多长序列相关工作会在 Attention 部分采用稀疏 Attention 的一个重要原因。

3.2 GQA Attention 计算

早期通常只有比较大的模型才会采用 GQA([2305.13245] GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints),比如 LLaMA -2 70B,而 LLaMA-2 7B/13B 都没有采用 GQA。然而,LLaMA-3 8B 中也用上了 GQA,甚至其他更小的模型也在将 MHA 替换为 GQA。

  • 使用 GQA 有个非常大的好处:在推理阶段可以显著降低 KV Cache 的大小,比如,相比 32 个 KV Head 的 MHA,32 个 Query Head,8 个 KV Head 的 GQA 的 KV Cache 大小可以降低到 MHA 的 8/32=1/4,这也为更大的 Batch Size 提供了空间,可以进一步提升吞吐
  • 除此之外,还有一个比较大的好处:可以明显提升 Q、K 和 V 的 Attention 计算的算术强度。此时虽然不同的 Request 之间同样不能共享,但是同一个 Request 中的不同 Head 可以共享,比如 4 个 Query Head 共享 1 个 KV Head,则算术强度就会接近于 4,也可以更充分发挥 Tensor Core 的算力。

MHA -> GQA:提升 LLM 推理效率-AI.x社区

使用 MHA 时,Q、K 和 V 的 Attention 计算可以使用 CUDA Core 也可以使用 Tensor Core。由于 Tensor Core 要求矩阵的 Shape 是 8 的整数倍,如果不满足就只能 Padding:

  • 对于MHA而言,其是矩阵乘向量,则有7/8 的计算是冗余的
  • 对于GQA而言,如果 4 个 Query Head 共享 1 个 KV Head,则 Attention 计算有 4/8 的计算是冗余的,如果8 个 Query Head 共享 1 个 KV Head,则没有计算的冗余。很多框架已经做了相关优化,比如 LMDeploy,TRT-LLM 的 XQA 等。
  • 此外,PagedAttention 的 KV Cache 是非连续存储的,导致即使使用 GQA 也无法利用 Tensor Core。

PS:对于 GQA 而言,理论上也可以期望 GPU 的 L2 Cache 能够缓存到共享的 Key 和 Value Cache,从而缓解 IO Bound 问题,然而实际上无法人为控制,不一定能达到理想的效果。

3.3 动机

作者从 C4 训练集采样了 128 个 Sequence,共 128*2048=262144 个 Token,评估了 LLaMA2-7B 模型中每个 Transformer Block 中 Attention Head 的 KV Cache 的相似性。

如下图 Figure 2 所示,分析发现,大多数 Head 之间的 KV Cache 几乎是正交的,仅有少数 Head 共享较高的相似度。这表明直接对投影矩阵进行均值化会导致性能显著下降,说明 Attention Head 之间存在重要的独特性。

MHA -> GQA:提升 LLM 推理效率-AI.x社区

根据之前 [2406.07056] Effectively Compress KV Heads for LLM [2] 的研究,KV Cache 的低秩性为优化提供了新思路:

  • 可通过正交变换对齐 Key 和 Value 的投影矩阵。
  • 这种方法降低了优化的难度,并为 MHA 转换为 GQA 提供了理论支持。

四、方案

4.1 网络转换

主要目的是:在剪枝训练之前,对模型进行转换,以增加同一组内不同 Attention Head 之间的相似性,从而提高模型优化的效率。具体的过程大概为:

  • 根据前述的方案,使用部分 C4 的训练集来收集相应的 KV Cache。
  • 基于余弦相似性或者欧氏距离,计算最优的正交矩阵。
  • 将计算得到的正交矩阵融合到对应的 Q、K、V 投影矩阵中,保证计算不变性。对于 Q 和 K 的投影矩阵,要考虑 RoPE 的场景,在子空间应用正交变换。

通过正交变换,可以使得同一组内不同 Attention Head 在特征空间中更加接近,从而在后续的剪枝训练过程中更容易找到合适的参数共享方式,提高模型的压缩效果和性能。

如下图 Figure 3 所示,作者展示了不同的 Block 中转换前和转换后的 KV Cache 相似性,可以看出,转换后相似性明显增加:

MHA -> GQA:提升 LLM 推理效率-AI.x社区

4.2 找到更好的分组方法

在获取了每对 Attention Head 之间的相似度评分后,可依据这些评分对 Attention Head 进行重新分组。将一个组的相似度评分定义为该组内每对 Attention Head 之间相似度评分的总和,而每种分组结果的总相似度评分则是所有组相似度评分的累加。

合理的分组方式可以使得同一组内的 Attention Head 在特征空间中更加相似,从而在剪枝时更容易找到合适的参数共享方式,提高模型的压缩效果和性能。

4.3 剪枝训练

主要目的是:通过剪枝训练,逐步将原始的 KV Head 转移到新的 KV Head 上,同时保持模型性能。如下图 Figure 1 所示,具体过程包括:

  • 添加新的投影矩阵:在每组内使用 Mean Pooling  初始化新的投影矩阵。
  • 应用 L0 掩码:引入 L0 掩码来控制原始 KV Head 和新 KV Head 之间的转换。初始时,掩码值为 1,表示使用原始 KV Head;在剪枝过程中,逐步将掩码值约束为 0,表示使用新的 KV Head。
  • 知识蒸馏:使用 KL 损失和 BiLD 损失,鼓励学生模型与教师模型的输出对齐,从而保持模型性能。

MHA -> GQA:提升 LLM 推理效率-AI.x社区

五、实验评估

如下图所示,作者在多个任务上进行评估,GQA-16(32 个 KV Head 变为 16 个) 时平均精度甚至有所提升。但是 GQA-8(压缩 4x)和 GQA-4(压缩 8x)时损失就比较大:

MHA -> GQA:提升 LLM 推理效率-AI.x社区

六、参考链接

  1. https://arxiv.org/abs/2412.20677
  2. https://arxiv.org/abs/2406.07056

本文转载自 AI闲谈​,作者: AI闲谈

已于2025-1-13 11:42:18修改
收藏
回复
举报
回复
相关推荐