MixAttention:跨层 KV Cache 共享 + 滑动窗口 Attention
一、背景
我们之前的文章中介绍过 Character.AI 的 LLM 推理最佳实践,其在 1 年多的时间里将推理成本降低了 33 倍。其中一个关键技术是对 KV Cache 的跨层共享以及与 Local Attention 的结合。本文我们介绍 MixAttention,其思路和上述方案完全一致,不过针对长文本场景做了更多实验和调整。
对应的论文为:[2409.15012] Inference-Friendly Models With MixAttention
LLM 稀疏化相关工作可以参考:
- SnapKV: KV Cache 稀疏化,零微调加速长序列 LLM 推理
- TriForce:KV Cache 稀疏化+投机采样,2.3x LLM 无损加速
- 33 倍 LLM 推理性能提升:Character.AI 的最佳实践
- 微软 MInference:百万 Token 序列,10x 加速
- MLKV:跨层 KV Cache 共享,降低内存占用
- MiniCache 和 PyramidInfer 等 6 种优化 LLM KV Cache 的最新工作
二、方案
2.1 Character.AI 方案
如下图所示为 Character.AI 的方案,左侧为标准的 Transformer Layer,全部是 Global Attentio;右侧为 Character.AI 的方案,结合了跨层 KV Cache 共享和 Sliding Window Attention:
- 蓝色的 1,7,13 使用 Global Attention,并且 7 和 13 共享 1 的 KV Cache。
- 绿色的 2,4,8,10 和红色的 3,5,6,9,11,12 使用 Local Attention,并且红色的 3 会共享绿色2 的 KV Cache,红色的 5 和 6 会共享绿色4 的 KV Cache。
2.2 本文方案
如下图 Figure 2 所示为本文 MixAttention 与标准 Transformer Attention 以及 Sliding Window Attention 的区别。基本与上述的 Character.AI 的方案一致,只不过共享的位置不太一样。其中红点表示被共享的 Global Attention,蓝点表示被共享的 Sliding Window Attention。
- MA:与 Character.AI 方案一致。
- MA-Offset:起始的几个 Layer 先使用 Sliding Window Attention,关注局部;然后才会有 Global Attention。
- MA-EndSlide:和 MA-Offset 相反,在结束的 Layer 也采用 Sliding Window Attention。如下图 Figure 3 所示。这个主要是为了评估最后一层 Global Attention 对长序列的影响有多大。
- MA-Pairs:Global Attention 也采用 Pair 的方式。在 MA 和 MA-Offset 只会有一层的全局 KV Cache,在 MA-Pairs 中会有多层的全局 KV Cache。
作者也探索了更多连续层共享 Global KV Cache 的方案,以 MA-Successive 为前缀,如下图 Figure 9 所示:
除此之外,作者还探索了没有共享 Global KV Cache 的方案,以 MA-NoShare 为前缀,如下图所示:
三、实验和结果
3.1 训练
训练分为 3 个阶段:
- Stage 1:101B Token 预训练,Max Sequence Length 为 4K,RoPE 的 theta 为 0.5M。
- Stage 2:9B Token 自然语言和代码数据,Max Sequence Length 扩展到 32K,RoPE 的 theta 扩展到 8M。
- Stage 3:0.5B 长文本合成数据,Max Sequence Length 依然是 32K。
3.2 评估
所有模型在前两个 Stage 上的 Loss 都非常接近,而在 Stage 3 有较大区别。如下图 Figure 4 所示,MA、Sliding Window Attention 和 MA-EndSlide 的效果明显差于其他模型,在长文本 RULER 评估上也有类似的结论。作者也分析了相关原因,MA 和 MA-EndSlide 的 Global Attention KV Cache(非共享)都是在第 1 层,而 MA-Offset 和 MA-Pairs 至少有一个 Global Attention KV Cache(非共享)在深层。
3.3 推理速度
如下图 Figure 8 所示,作者在单个 H100 GPU 上使用 SGLang 验证了不同模型的推理速度,使用 300 个 Prompt,输入长度 31K,输出长度 1K。可以看出,MA 相关的方案在速度上都有比较明显的提升,大约 2x-3x。此外,支持的最大 Token 数目也更多,不过其中 Sliding Window Attention 还没有优化,所以支持的最大 Token 数和标准 LLM 相同。
PS:这里的实验有点单薄,只在一个单一的数据场景,也没有测试不同压力下的性能。
3.4 总结
如下图所示,从各种评估中可以看出本文的 MA-Offset 和 MA-Pairs 在推理速度,长短文本任务上都获得了不错的结果,而标准的 MA 在长文本任务上性能较差。
四、参考链接
本文转载自 AI闲谈,作者: AI闲谈