MHA -> GQA:提升 LLM 推理效率
一、背景
我们在之前的文章中详细分析过 GQA 相比 MHA 的推理优势(省显存、计算强度高),不过 GQA 有可能导致精度的损失,因此早期的一些不太大的 LLM 会使用 MHA。针对这个问题有两种优化思路:
- 将 MHA 转换为 GQA,长短序列都适用。
- 在长序列场景使用 Token 稀疏化方案或者结合投机采样策略。
本文中我们介绍一个将 MHA 转换为 GQA 的工作,不过论文的实验还偏少,效果也不是非常好;此外,最新的模型基本都在预训练阶段默认采用 GQA(LLaMA3 8B、LLaMA3.2 3B 以及 Microsoft 的 Phi 系列模型等),降低了本文工作的应用场景。
对应的论文:[2412.20677] Align Attention Heads Before Merging Them: An Effective Way for Converting MHA to GQA [1]
相关工作也可以参考我们以前的文章:
- 微软 RetrievalAttention: LLM+ANN, LLM 推理速度与精度的平衡
- LLM 推理的 Attention 计算和 KV Cache 优化:PagedAttention、vAttention 等
二、摘要
LLM 在多种自然语言处理任务中展现出卓越性能。然而,随着模型规模与输入序列长度的增长,KV Cache 的急剧膨胀显著拖慢了推理速度。鉴于此,作为 MHA 的替代方案,GQA 已被广泛引入 LLM。本研究提出了一种低成本方法,可将 MHA 模型按任意 KV Head 压缩比修剪为 GQA 模型。
该方法基于 L0 掩码逐步剔除冗余参数。此外,在不改变模型的前提下,对注意力头施加正交变换,以在修剪训练前提升 Attention Head 间的相似度,从而进一步优化模型性能。本方法兼容RoPE,意味着训练后的模型能完全适配主流标准 GQA 框架。实验表明,仅通过监督微调,提出的策略即可将 LLaMA2-7B 模型的 KV Head 压缩高达 87.5%,且性能损失极小。
三、引言
如下 3.1 和 3.2 部分在我们之前的文章中有相吸介绍:LLM 推理的 Attention 计算和 KV Cache 优化:PagedAttention、vAttention 等。
3.1 MHA Attention 计算
如下图所示为标准的 LLM Decoding 阶段的 Multi-Head Attention(MHA)计算,其中的 D 表示 hidden size,H 表示 Head 个数,L 表示当前是在序列的第 L 个 Token。可以看出:
- 当Batch Size 为 1时,图中红色、绿色、蓝色处的矩阵乘法全部为矩阵乘向量,是明显的 Memory Bound,算术强度不到 1。
- 当Batch Size 大于 1时(比如 Continuous Batching):
- 红色和蓝色部分:因为是 Weight 乘以 Activation,所以不同的 Request 之间可以共享 Weight。这里变成矩阵乘矩阵,并且 Batch Size 越大,算术强度越大,也就越趋近于 Compute Bound(FFN 层也类似)。
- 绿色部分:这里 Q、K 和 V 的 Attention 计算,是 Activation 乘以 Activation,所以不同的 Request 之间没有任何相关性。即使 Batching,这里也是Batched 矩阵乘向量,并且因为序列长度可能不同,这里不同 Request 的矩阵乘向量是不规则的。也就是说,这里算术强度始终不到 1,是明显的 Memory Bound。
从上可以看出,通过 Continuous Batching 可以很好的将 Memory Bound 问题转变为 Compute Bound,但 Q、K 和 V 的 Attention 计算的算术强度却始终小于 1。根据 Amdahl 法则,如果系统中有一部分无法优化,即使把其他部分优化到可以忽略,不可优化的部分也会决定整个系统的性能上限。不幸的是,Sequence Length 越长,这里的计算量就越不可忽略。
根据模型配置信息可以估算出模型中 Q、K 和 V 的 Attention 计算与其他矩阵计算的比例大约为 (L+D)/(12*D)(PS:准确值需要根据具体的模型参数计算)。也就是说,当序列长度 L 等于 12 倍的 hidden size 时,两部分的计算量相当,即使其他矩阵计算优化到 0,加速比也只有 2x。比如 LLaMA 2 7B 的 hidden size 为 4K,当序列长度达到 44K 时,两部分的计算量相当,要优化的重点也会很不一样,这也是很多长序列相关工作会在 Attention 部分采用稀疏 Attention 的一个重要原因。
3.2 GQA Attention 计算
早期通常只有比较大的模型才会采用 GQA([2305.13245] GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints),比如 LLaMA -2 70B,而 LLaMA-2 7B/13B 都没有采用 GQA。然而,LLaMA-3 8B 中也用上了 GQA,甚至其他更小的模型也在将 MHA 替换为 GQA。
- 使用 GQA 有个非常大的好处:在推理阶段可以显著降低 KV Cache 的大小,比如,相比 32 个 KV Head 的 MHA,32 个 Query Head,8 个 KV Head 的 GQA 的 KV Cache 大小可以降低到 MHA 的 8/32=1/4,这也为更大的 Batch Size 提供了空间,可以进一步提升吞吐。
- 除此之外,还有一个比较大的好处:可以明显提升 Q、K 和 V 的 Attention 计算的算术强度。此时虽然不同的 Request 之间同样不能共享,但是同一个 Request 中的不同 Head 可以共享,比如 4 个 Query Head 共享 1 个 KV Head,则算术强度就会接近于 4,也可以更充分发挥 Tensor Core 的算力。
使用 MHA 时,Q、K 和 V 的 Attention 计算可以使用 CUDA Core 也可以使用 Tensor Core。由于 Tensor Core 要求矩阵的 Shape 是 8 的整数倍,如果不满足就只能 Padding:
- 对于MHA而言,其是矩阵乘向量,则有7/8 的计算是冗余的。
- 对于GQA而言,如果 4 个 Query Head 共享 1 个 KV Head,则 Attention 计算有 4/8 的计算是冗余的,如果8 个 Query Head 共享 1 个 KV Head,则没有计算的冗余。很多框架已经做了相关优化,比如 LMDeploy,TRT-LLM 的 XQA 等。
- 此外,PagedAttention 的 KV Cache 是非连续存储的,导致即使使用 GQA 也无法利用 Tensor Core。
PS:对于 GQA 而言,理论上也可以期望 GPU 的 L2 Cache 能够缓存到共享的 Key 和 Value Cache,从而缓解 IO Bound 问题,然而实际上无法人为控制,不一定能达到理想的效果。
3.3 动机
作者从 C4 训练集采样了 128 个 Sequence,共 128*2048=262144 个 Token,评估了 LLaMA2-7B 模型中每个 Transformer Block 中 Attention Head 的 KV Cache 的相似性。
如下图 Figure 2 所示,分析发现,大多数 Head 之间的 KV Cache 几乎是正交的,仅有少数 Head 共享较高的相似度。这表明直接对投影矩阵进行均值化会导致性能显著下降,说明 Attention Head 之间存在重要的独特性。
根据之前 [2406.07056] Effectively Compress KV Heads for LLM [2] 的研究,KV Cache 的低秩性为优化提供了新思路:
- 可通过正交变换对齐 Key 和 Value 的投影矩阵。
- 这种方法降低了优化的难度,并为 MHA 转换为 GQA 提供了理论支持。
四、方案
4.1 网络转换
主要目的是:在剪枝训练之前,对模型进行转换,以增加同一组内不同 Attention Head 之间的相似性,从而提高模型优化的效率。具体的过程大概为:
- 根据前述的方案,使用部分 C4 的训练集来收集相应的 KV Cache。
- 基于余弦相似性或者欧氏距离,计算最优的正交矩阵。
- 将计算得到的正交矩阵融合到对应的 Q、K、V 投影矩阵中,保证计算不变性。对于 Q 和 K 的投影矩阵,要考虑 RoPE 的场景,在子空间应用正交变换。
通过正交变换,可以使得同一组内不同 Attention Head 在特征空间中更加接近,从而在后续的剪枝训练过程中更容易找到合适的参数共享方式,提高模型的压缩效果和性能。
如下图 Figure 3 所示,作者展示了不同的 Block 中转换前和转换后的 KV Cache 相似性,可以看出,转换后相似性明显增加:
4.2 找到更好的分组方法
在获取了每对 Attention Head 之间的相似度评分后,可依据这些评分对 Attention Head 进行重新分组。将一个组的相似度评分定义为该组内每对 Attention Head 之间相似度评分的总和,而每种分组结果的总相似度评分则是所有组相似度评分的累加。
合理的分组方式可以使得同一组内的 Attention Head 在特征空间中更加相似,从而在剪枝时更容易找到合适的参数共享方式,提高模型的压缩效果和性能。
4.3 剪枝训练
主要目的是:通过剪枝训练,逐步将原始的 KV Head 转移到新的 KV Head 上,同时保持模型性能。如下图 Figure 1 所示,具体过程包括:
- 添加新的投影矩阵:在每组内使用 Mean Pooling 初始化新的投影矩阵。
- 应用 L0 掩码:引入 L0 掩码来控制原始 KV Head 和新 KV Head 之间的转换。初始时,掩码值为 1,表示使用原始 KV Head;在剪枝过程中,逐步将掩码值约束为 0,表示使用新的 KV Head。
- 知识蒸馏:使用 KL 损失和 BiLD 损失,鼓励学生模型与教师模型的输出对齐,从而保持模型性能。
五、实验评估
如下图所示,作者在多个任务上进行评估,GQA-16(32 个 KV Head 变为 16 个) 时平均精度甚至有所提升。但是 GQA-8(压缩 4x)和 GQA-4(压缩 8x)时损失就比较大:
六、参考链接
本文转载自 AI闲谈,作者: AI闲谈