大模型同样的上下文窗口,只需一半内存就能实现,而且精度无损?
前苹果ASIC架构师Nils Graef,和一名UC伯克利在读本科生一起提出了新的注意力机制Slim Attention。
它以标准多头注意力(MHA)为基准,对其中的value缓存处理过程进行了调整,实现了更少的内存占用。
具体来说,Slim Attention既可以让KV缓存大小减半,也可以在KV缓存大小不变的情况下让上下文翻倍,都不会带来精度损失。
此外,在内存带宽受限的场景下,它还可以将模型的推理过程加速1.5-2倍。
网友评价,Slim Attention虽然简单,但却是一个很酷的想法。
还有AI创业者评论说,这是一项重大突破,可能重塑对模型训练和部署的看法。
K-Cache is All You Need
在标准的MHA机制当中,对于输入X会通过线性变换,经由三个投影矩阵W_Q、W_K、W_V得到Q(query)、K(key)和V(value)三个矩阵。
在推理阶段,每个输入token计算得到的K和V向量都需要缓存起来,形成KV cache供后续token计算时使用。
Slim Attention的核心思路是,利用MHA中W_K和W_V通常都是方阵的性质,只存储K而不直接存储V,然后实时利用K计算出V。
△原始MHA(左)与改进版(右)对比
在训练阶段,Slim Attention与标准MHA一样,会对输入X计算Q、K、V三个矩阵,注意力计算和梯度回传也与标准MHA完全一致。
在W_K可逆的前提下,Slim Attention引入一个新的参数矩阵W_KV:
W_KV = W_K^(-1)·W_V
据此,可以得到:
V = X·W_V = X·W_K·W_K^(-1)·W_V = K·W_KV
推理过程则主要分为两个阶段——提示阶段(并行计算)和生成阶段(自回归)。
提示阶段与标准MHA一样,将输入的所有token并行计算Q、K矩阵,但不同的是,这里不直接计算V,而是将中间结果K缓存供后续使用。
生成阶段每个时间步生成一个新token,首先计算该时间步的Q向量q,然后基于q和之前时间步缓存的K矩阵,计算注意力得(即softmax的输入)。
在softmax之前,Slim Attention通过公式V = K · W_KV实时计算V矩阵。具体有两种方式:
- 直接计算V,然后将softmax结果与V相乘(矩阵乘法)得到注意力输出;
- 先将softmax结果与K相乘,然后再与W_KV相乘,当序列较长时这种方式更高效。
剩余流程(残差连接、前馈层等)与标准MHA一致,最后将当前步的k向量添加到K缓存中,供下一时间步使用。
总之,Slim Attention是标准MHA的精确数学重写,因此与近似方法不同,可确保准确率不会下降。
以此为前提,Slim Attention实现了KV缓存减半或上下文翻倍的效果。
前苹果架构师与UC伯克利本科生成果
Slim Attention的作者是AI初创公司OpenMachine的创始人兼CEO Nils Graef,以及UC伯克利在读本科生Andrew Wasielewski。
Nils的主业是机器学习加速器的架构和设计,曾发表两篇IEEE期刊论文和30多项专利,引用次数超过900次。
创立OpenMachine前,Nils在知名推理加速平台Groq(注意不是马斯克的Grok)担任芯片架构师。
更早的时候,他先后担任过谷歌ML加速器架构&设计工程师和苹果ASIC架构师。
Andrew Wasielewski是UC伯克利在读本科生,专业是物理和EECs(电气工程与计算机科学),预计将于明年毕业。
根据论文署名信息显示,Slim Attention的工作是Andrew在OpenMachine完成的。
去年7月,Nils和Andrew还与其他人合作,发表了一篇名为Flash normalization的论文,提出了一种更快的RNS归一化方式。
此外在Slim Attention的致谢中还提到,艾伦实验室的Dirk Groeneveld,以及SGLang三作谢志强,对其工作提供了有益讨论;Transformer作者之一、Character.AI创始人Noam Shazeer给出了积极反馈。
论文地址:https://arxiv.org/abs/2503.05840
参考链接:https://x.com/rohanpaul_ai/status/1901092052282339474