MLKV:跨层 KV Cache 共享,降低内存占用
一、背景
LLM 中 KV Cache 占据的显存越来越大,有很多方案开始尝试跨层共享 K 和 V,比如我们之前介绍的 YOCO、CLA 以及 Layer-Condensed KV Cache 等,本文介绍的方案也极其类似。
对应的论文为:[2406.09297] MLKV: Multi-Layer Key-Value Heads for Memory Efficient Transformer Decoding
对应的代码库为:https://github.com/zaydzuhri/pythia-mlkv
PS:感觉本文创新度明显不足,相关实验也非常少,只在一个 160M 模型测试,甚至没有测试 7B 模型。
二、摘要
Transformer 模型的自回归推理因为 KV Cache 的存在可以大幅降低计算量,但随着模型、Batch Size 以及序列长度的增长,KV Cache 大幅增加,导致可能存在内存瓶颈。本文中,作者引入了多层 KV(Multi-Layer Key-Value,MLKV)Cache,可以跨 Transformer Layer 实现 KV Cache 共享,以减少内存占用,甚至可以比 MQA 和 GQA 节约更多的内存占用。作者使用经过训练的 Pythia-160M 变体,针对各种 NLP 基准和推理能力的指标进行评估,表明 MLKV 可以以最小的性能损失显著降低内存使用量(???),与 MQA 相比,可以将 KV Cache 大小减少 6 倍。这些结果凸显了 MLKV 在部署大规模 LLM 模型方面的潜力。
三、方法
如下图 Figure 2 所示,其思路很简单,也和我们之前介绍过的几个工作很类似,主要区别如下:
- MHA:原始的 Multi Head Attention,每一层的每一个 Head 都有独立的 K 和 V。
- MQA:Multi Query Attention,每一层的所有 Head 共享 K 和 V.
- GQA:Grouped Query Attention,MHA 和 MQA 的折衷,每一层的 Head 分为多组,每一组共享 K 和 V.
- MLKV:多个层共享 K 和 V,并且可以与上述 MQA 和 GQA 兼容。
如下图 Table 2 所示为不同配置下总共 KV Head 的个数,参数量,以及 Loss:
四、结果
如下图所示为不同配置下在各种评估任务上的结果,可以看出在同等配置下是弱于 GQA 的,甚至弱于 MQA:
如下图是相应的显存占用,同样 Head 数的方案内存占用相同,Head 越少,内存占用越少:
如下图 Figure 5 所示,同样 Head 下 MLKV 的速度会更快一些,不过差距都不大:
四、参考链接
- [2406.09297] MLKV: Multi-Layer Key-Value Heads for Memory Efficient Transformer Decoding
- https://github.com/zaydzuhri/pythia-mlkv