随着大型语言模型(LLM)规模和复杂性的持续增长,高效推理的重要性日益凸显。KV(键值)缓存与分页注意力是两种优化LLM推理的关键技术。本文将深入剖析这些概念,阐述其重要性,并探讨它们在仅解码器(decoder-only)模型中的工作原理。
随着大型语言模型(LLM)规模和复杂性的持续增长,高效推理的重要性日益凸显。KV(键值)缓存与分页注意力是两种优化LLM推理的关键技术。本文将深入剖析这些概念,阐述其重要性,并探讨它们在仅解码器(decoder-only)模型中的工作原理。

常规推理机制
首先,我们通过一个简单的例子来理解Transformer模型中典型的推理过程。假设我们需要生成短语:
“The quick brown fox jumped”
以下是常规推理的简化实现:
以下是逐步生成的过程:
冗余计算:观察上述代码可以发现对于每个新生成的token:
- 需要为所有先前的token重新计算K和V矩阵。
- 矩阵的大小随着token数量的增加而增大。
- 存在大量不必要的重复计算。
KV缓存机制
当使用Transformer模型生成文本时,通过缓存键(K)和值(V)矩阵,可以显著优化推理过程。下图展示了KV缓存的工作原理:

在上图中:
- q_new表示最新token的查询向量。
- K_prev和V_prev是从先前计算中缓存得到的键和值矩阵。
- k_new和v_new仅为当前新token计算。
- 蓝色箭头表示如何利用缓存值和新值计算注意力。
以下是KV缓存的实现示例:
以下是逐步生成的过程:

比较有无KV缓存的推理计算
内存需求与挑战
我们来看一个使用典型模型参数的实际例子:
- 序列长度: 4096
- 层数: 32
- 注意力头数: 32
- 头维度: 128
- 精度: FP16 (2 bytes)
每个token所需的内存:
KV缓存的低效性
尽管KV缓存显著提高了计算效率,但它也带来了内存管理方面的挑战。以下是三种主要的内存低效类型:

内部碎片
- 由因未知输出长度而导致的过度分配引起。
- 示例:在图像中,2040个槽位从未被使用。
- 影响:可能浪费高达60-80%的已分配内存。
- 解决方案:更精确的输出长度估计或动态分配策略。
预留浪费
- 为将来的token生成而预留的内存。
- 在图像中显示为“3 slots future used (reserved)”。
- 维持生成连续性的必要措施。
- 可以通过更好地预测所需的未来槽位来优化。
外部碎片
- 由处理具有不同序列长度的多个请求导致。
- 在不同请求之间创建内存间隙。
- 解决方案包括内存碎片整理和智能请求批处理。
如上图所示,通常仅有20-40%的KV缓存被用于存储实际的token状态。
分页注意力:解决内存低效的方案
为了应对这些内存挑战,可以采用分页注意力机制。
分页注意力是一种用于有效处理Transformer模型中长序列的技术,它通过将注意力计算分解为更小、更易于管理的“页”或“块”来实现。这种方法降低了内存消耗和计算复杂度,从而能够处理原本因过大而无法放入内存的序列。
以下是逐步生成的过程:
为何需要分页注意力?
- 内存约束:由于注意力矩阵的规模与序列长度呈平方关系,Transformer模型在处理长序列时面临严重的内存限制。
- 长序列处理:在诸如语言建模或文档摘要等任务中,序列可能非常长。
- 效率:通过以分页的方式处理注意力计算,可以将内存使用量保持在一个常量水平,从而不受序列长度的影响。
分页注意力如何工作?
- 分割序列:将输入序列分割成更小的块或页。
- 局部注意力:在每个页内计算注意力。
- 跨页注意力:可选地,允许有限的跨页注意力,以捕获页之间的依赖关系。
- 滑动窗口:使用重叠的页来确保连续性。
上述实现仅限于局部注意力,跨页注意力和滑动窗口的实现超出了本文的范围,将在后续文章中详细介绍。
分页注意力的讨论
优势
- 内存效率:注意力计算被限制在页大小内,内存使用量保持恒定,不受总序列长度的影响。
- 计算效率:降低了注意力计算的复杂度。
- 可扩展性:能够处理原本无法放入内存的超长序列。
权衡与考虑
- 上下文信息受限:模型会丢失跨页的一些依赖关系,这对于需要全局上下文的任务可能很重要。
可能的解决方案:
- 重叠页:允许页之间重叠一定数量的token,重叠区域的token可以关注前一页的token。
- 分层注意力:使用更高层次的注意力机制来连接跨页的信息。
重叠页、分层注意力、跨页注意力和滑动窗口的完整实现超出了本文的范围。
以下实现仅捕获局部注意力,作为示例不应在实际应用中使用:
import numpy as np
embeddings = {
'The': np.array([1, 0, 0, 0]),
'quick': np.array([0, 1, 0, 0]),
'brown': np.array([0, 0, 1, 0]),
'fox': np.array([0, 0, 0, 1]),
'jumped': np.array([1, 1, 0, 0])
}
W_Q = W_K = W_V = np.array([[1, 0],
[0, 1],
[0, 0],
[0, 0]])
PAGE_SIZE = 2
class AttentionWithCache:
def __init__(self):
self.cached_K = None
self.cached_V = None
self.cached_K_pages = []
self.cached_V_pages = []
def softmax(self, x, axis=-1):
"""
为x中的每组分数计算Softmax值。
包含数值稳定性改进。
"""
x_max = np.max(x, axis=axis, keepdims=True)
exp_x = np.exp(x - x_max)
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
def compute_attention(self, input_words):
E = np.array([embeddings[word] for word in input_words])
K = E @ W_K
V = E @ W_V
Q = E[-1] @ W_Q
scale = np.sqrt(2)
scores = (Q @ K.T) / scale
attention_weights = self.softmax(scores)
output = attention_weights @ V
return output
def compute_attention_with_cache(self, input_words):
"""使用KV缓存计算注意力"""
new_word = input_words[-1]
e_new = embeddings[new_word]
K_new = e_new @ W_K
V_new = e_new @ W_V
if self.cached_K is None:
self.cached_K = K_new.reshape(1, -1)
self.cached_V = V_new.reshape(1, -1)
else:
self.cached_K = np.vstack([self.cached_K, K_new])
self.cached_V = np.vstack([self.cached_V, V_new])
Q = e_new @ W_Q
scale = np.sqrt(2)
scores = (Q @ self.cached_K.T) / scale
attention_weights = self.softmax(scores)
output = attention_weights @ self.cached_V
return output
def compute_attention_with_paging(self, input_words):
"""使用分页KV缓存计算注意力"""
new_word = input_words[-1]
e_new = embeddings[new_word]
K_new = e_new @ W_K
V_new = e_new @ W_V
total_tokens = sum(len(K_page) for K_page in self.cached_K_pages) + 1
current_page_idx = (total_tokens - 1) // PAGE_SIZE
if len(self.cached_K_pages) <= current_page_idx:
self.cached_K_pages.append([])
self.cached_V_pages.append([])
self.cached_K_pages[current_page_idx].append(K_new)
self.cached_V_pages[current_page_idx].append(V_new)
Q = e_new @ W_Q
K_current_page = np.array(self.cached_K_pages[current_page_idx])
V_current_page = np.array(self.cached_V_pages[current_page_idx])
scale = np.sqrt(2)
scores = (Q @ K_current_page.T) / scale
attention_weights = self.softmax(scores)
output = attention_weights @ V_current_page
return output
def compare_implementations():
print("原始实现:")
attention1 = AttentionWithCache()
for i in range(len(['The', 'quick', 'brown', 'fox'])):
words = ['The', 'quick', 'brown', 'fox'][:i + 1]
output = attention1.compute_attention(words)
print(f"处理 {words} 后的输出:")
print(f"Output: {output}")
print("\nKV缓存实现:")
attention2 = AttentionWithCache()
for i in range(len(['The', 'quick', 'brown', 'fox'])):
words = ['The', 'quick', 'brown', 'fox'][:i + 1]
output = attention2.compute_attention_with_cache(words)
print(f"处理 {words} 后的输出:")
print(f"Output: {output}")
print("\n分页注意力实现:")
attention3 = AttentionWithCache()
for i in range(len(['The', 'quick', 'brown', 'fox'])):
words = ['The', 'quick', 'brown', 'fox'][:i + 1]
output = attention3.compute_attention_with_paging(words)
print(f"处理 {words} 后的输出:")
print(f"Output: {output}")
print(f"页数: {len(attention3.cached_K_pages)}")
print(f"当前页大小: {len(attention3.cached_K_pages[-1])}\n")
if __name__ == "__main__":
compare_implementations()
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
- 31.
- 32.
- 33.
- 34.
- 35.
- 36.
- 37.
- 38.
- 39.
- 40.
- 41.
- 42.
- 43.
- 44.
- 45.
- 46.
- 47.
- 48.
- 49.
- 50.
- 51.
- 52.
- 53.
- 54.
- 55.
- 56.
- 57.
- 58.
- 59.
- 60.
- 61.
- 62.
- 63.
- 64.
- 65.
- 66.
- 67.
- 68.
- 69.
- 70.
- 71.
- 72.
- 73.
- 74.
- 75.
- 76.
- 77.
- 78.
- 79.
- 80.
- 81.
- 82.
- 83.
- 84.
- 85.
- 86.
- 87.
- 88.
- 89.
- 90.
- 91.
- 92.
- 93.
- 94.
- 95.
- 96.
- 97.
- 98.
- 99.
- 100.
- 101.
- 102.
- 103.
- 104.
- 105.
- 106.
- 107.
- 108.
- 109.
- 110.
- 111.
- 112.
- 113.
- 114.
- 115.
- 116.
- 117.
- 118.
- 119.
- 120.
- 121.
- 122.
- 123.
- 124.
- 125.
- 126.
- 127.
- 128.
- 129.
- 130.
- 131.
- 132.
- 133.
- 134.
- 135.
- 136.
- 137.
- 138.
- 139.
- 140.
- 141.
- 142.
- 143.
- 144.
- 145.
- 146.
- 147.
- 148.
- 149.
- 150.
- 151.
- 152.
- 153.
- 154.
- 155.
- 156.
- 157.
- 158.
- 159.
- 160.
- 161.
- 162.
- 163.
- 164.
- 165.
- 166.
- 167.
- 168.
- 169.
- 170.
- 171.
- 172.
- 173.
总结
KV缓存和分页注意力是提升LLM推理效率和可扩展性的重要技术。KV缓存通过消除冗余计算来优化计算过程,而分页注意力则解决了处理长序列时面临的内存限制。
随着模型规模和复杂性的不断增长,这些优化技术对于实际应用变得至关重要。深入理解和有效实施这些技术,可以显著提升LLM部署的性能和效率。