微软 RetrievalAttention: LLM+ANN, LLM 推理速度与精度的平衡
一、背景
本文我们继续介绍一个针对超长上下文的 LLM 推理加速工作,同样是 Token 稀疏化的方案,来解决 LLM 在超长序列场景计算量大、GPU 显存消耗大的问题,不过结合了 ANN 检索,可以实现更高的精度。
对应的论文为:[2409.10516] RetrievalAttention: Accelerating Long-Context LLM Inference via Vector Retrieval
二、摘要
本文中作者提出了 RetrievalAttention,无需训练就可以加速 Attention 计算。为了利用 Attention 的动态稀疏特性,RetrievalAttention 在 CPU 内存中使用 KV Cache 构建近似检索(ANN)索引,并在生成过程中通过向量检索识别最相关的索引。由于 Query 向量和 Key 向量之间存在 Out-Of-Distribution(OOD)问题,现成的 ANN 检索仍然需要扫描 O(N) 数据(通常占所以 Key 的 30%)进行准确检索,无法利用高稀疏性。
为了解决这个挑战,RetrievalAttention 采用注意力感知向量检索算法,可以调整 Query 只访问 1-3% 的数据,从而实现亚线性时间复杂度。RetrievalAttention 大幅降低了长上下文 LLM 推理的成本,大幅降低 GPU 显存需求,同时保持模型准确性。特别的,RetrievalAttention 只需要 16GB 内存就可以在具有 8B 参数的 LLM 上支持 128K Token的推理,在单个 RTX4090(24GB)上可以在 0.188s 内生成一个 Token。
如下图 Figure 1 所示为本文方法与几种常见方案的对比(PS:可以看出,本文方案相比之前 Token 稀疏化方案,是在牺牲一定推理速度的情况下提升精度):
三、方法
3.1 背景
使用 ANN 来识别关键 Token 有个独特的挑战:当前大部分的 ANN 引擎都假设 Query 向量和 Key 向量满足相同的分布,以此来实现高召回率。作者在这篇论文中首次提出这种假设在 Attention 机制中不成立。Query 的这种 OOD 特性损坏了 ANN 的预期检索质量,从而导致不得不访问更多的数据来保持正确性,作者实验表明,为了维持可接受的准确率,至少需要扫描 30% 的 Key 向量。
如下图 Figure 2 所示:
- (a)Attention Score 具有非常高的稀疏性,64000 个 Token,只有不到 500 Token 的 Score 大于 10-6。
- (b)Q 和 Q 或者 K 和 K 的相关性很高,而 Q 和 K 的相关性很差,需要扫描 30% 左右的 Token 才能保证 0.8 左右的召回率。
- (c)同样说明了 Q 和 K 的距离比较远。
3.2 概览
本文的工作主要聚焦于 Token Decoding 阶段,会假设 Prefill 阶段已经执行完成,比如通过 Context Caching 方案或 Prefill 和 Decoding 分离方案。
如下图 Figure 3(a)所示为本文方案 RetrievalAttention 的概览,其利用 CPU 侧的 ANN 检索来实现近似 Attention 计算,为了支持长序列,也会将所有 KV Cache Offload 到 CPU 内存以便构建索引。如图(b)所示是为了解决 OOD 问题而采用的索引机制。
3.3 近似 Attention
具体来说,不使用完整的 Attention Score,而是采用最相关的 KV 向量来近似 Attention Score:
3.4 Attention 感知向量检索
对于每对 Key 和 Value,首先确定是放在 CPU Memory 还是 GPU Memory(方法见下一小节)。然后 Offload 到 CPU 内存的 Key 和 Value 会使用 Key 来构建索引,并使用 Query 来检索。
为了加速 Token 生成过程中的向量检索速度,RetrievalAttention 利用 Prefill 阶段的现有 Query 来指导 Key 向量的索引构建。如上图 Figure 3(b)所示,RetrievalAttention 显式的建立从 Query 向量到其最近的 Key 向量的连接(即精确的 K 个最近邻,或 KNN)。KNN 结果可以通过 GPU 高效计算,形成从 Query 向量分布到 Key 向量分布的映射。使用这种结构,Decoding 的 Query 向量查询时可以首先查询最近的 Query 向量,然后将其映射为 Key 向量。
因此,之前的 Query 向量充当了解决 OOD 问题的桥梁。然而,这种结构在内存开销和搜索效率方面仍然存在缺陷,因为除了 Key 向量之外,还需要存储和访问 Query 向量。为了解决这个问题,作者利用先进的跨模态 ANN 索引 RoarGraph 中的投影技术来消除 Query 向量。具体来说,通过使用 Query 向量和 Key 向量的连接关系,将 KNN 连接投影到 Key 向量中,从而有效地简化搜索。此外,此方法也允许对未来的 Query 向量进行高效的索引遍历。
作者实验结果表明,通过这种 Query 和 Key 的连接关系进行有效建模,向量数据库只需扫描 1-3% 的 Key 向量即可达到高召回率,与 IVF 索引相比,索引搜索延迟大幅降低 74%。
3.5 CPU 和 GPU 协同执行
为了利用 GPU 并行性加速注意力计算,RetrievalAttention 将注意力计算分解为两组不相交的 KV Cache 向量:GPU 上的可预测向量和 CPU 上的动态向量,然后将两部分 Attention 输出合并在一起作为完整的 Attention 输出。
具体来说,利用 Prefill 阶段观察到的模式来预测 Token 生成过程中持续激活的 KV 向量。与 StreamingLLM 类似,作者将固定的几个初始 Token 和最近窗口内的 Token 作为静态 Token,持久化在 GPU 上。RetrievalAttention 也可以适配更复杂的静态模式,以便实现低推理成本和高准确性的平衡。为了最大限度减少通过慢速 PCIe 的数据传输,RetrievalAttention 在 CPU 和 GPU 上独立计算 Attention,然后将其组合起来,这个灵感来自 FastAttention([2205.14135] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness)。
四、评估
4.1 实验配置
机器包含三个:
- RTX 4090 GPU(24G 显存),Intel i9-10900X CPU(20 Core),128 GB 内存。
- A100 GPU(80GB 显存),AMD EPYC CPU(24 Core)。
- A100 GPU(80GB 显存),AMD EPYC 7V12 CPU(48 Core),1.72TB 内存。
模型包含三个:
- LLaMA-3-8B-Instruct-262K
- Yi-6B-200K
- Yi-9B-200K
对比框架包括:
- Full Attention 的 vLLM
- StreamingLLM
- SnapKV
- InfLLM
基准测试包括:
- ∞-Bench
- RULER
- Needle-in-a-haystack
4.2 长文本任务精度
如下图 Table 2 所示,本文提出的 RetrievalAttention 明显优于之前的方案,平均精度非常接近 Full Attention。当然,部分模型上 Flat(暴露检索索引数据) 会略好于 RetrievalAttention,不过差距不大。
4.3 时延评估
如下图 Table 6 所示,作者首先验证了本文提出的检索方式的有效性,可以看出,提出的 RetrievalAttention 相比 Flat 和 IVF 可以提供 4.9x 和 1.98x 的加速,证明了检索机制的有效性:
如下图 Table 4 所示,作者也与之前的其他稀疏化方案进行对比,可以看出,之前的方案往往采用固定的 Token 数,因此随着序列变长并没有明显增加时延,而本文的方法会略微增加。同时,本文方法推理 Latency 相比之前方法明显增加,大概是之前方法 Latency 的 3x-6x。然而其仍然明显低于 Full Attention(FlexGen) 的结果。相当于在效果和速度之间的折衷。
如下图 Table 7 和 Table 8 为在 A100 上的结果,结论类似,不过在 100K 和 200K 时其 Latency 会超过 vLLM:
五、参考链接
本文转载自 AI闲谈,作者: AI闲谈