文字中貌似不起眼的标点符号,竟然可以显著加速大模型的训练和推理过程?
来自华为、港大、KAUST和马普所的研究者,就提出了一种新的自然语言建模视角——SepLLM。
起因是团队发现某些看似无意义的分隔符,在注意力得分中占据了不成比例的重要地位。
于是,SepLLM通过将一段文本中的信息压缩进分隔符(比如逗号,句号等)中,真的实现了加速效果,并且可以让KV缓存减少一半。
自注意力机制的平方级复杂度,给计算存储需求和训练推理速度带来了不小的挑战。
为了降低推理的复杂度,大量节约KV Cache的稀疏化方法被提出。
然而这些方法大多是基于用户的问题或者提示来筛选有用的KV Cache。
这使得如果用户再提出一个新的问题,模型回答的精度可能下降,因为包含答案信息的KV已经在上一次压缩过程中被抛弃。
除此之外,免训练方法通常无法相应地从头训练或者后训练,导致了训练和推理的流程差异性。
更重要的是现在主流的稀疏注意力改进方法,本质上更多是一种针对KV Cache存储与计算的稀疏化管理,而不是对自然语言的自然且高效的建模。
用分隔符实现自然语言高效建模
SepLLM通过将一段文本中的信息压缩进分隔符(比如逗号,句号等)中,显著加速了大型语言模型的训练和推理过程。
这一发现基于一个新颖且关键的模式:某些看似无意义的分隔符,在注意力得分中占据了不成比例的重要地位。
如下图所示,注意力可视化显示出一定的稀疏性,并且在分隔符处注意力明显更大。
由此,可以将这些自然语言中分隔符所自然分割的语义段的信息有效地压缩进分隔符中,其他tokens直接丢弃,而不会造成信息损失。
除此之外,一般一个分割符所分割的语段的长度是有限且相对均衡的,因此用分割此语段的分隔符去浓缩语段信息,可以避免类似RNN当序列过长时而出现遗忘的问题。
因为这种基于分割符的语言建模视角反映了自然语言的自然而内在的稀疏性,而不是人为用类似block/cluster等概念预先定义的稀疏性,作者认为SepLLM可以作为大语言模型的原生稀疏注意力机制和原生基线模型。
具体来说,SepLLM的基础设计包含下列三种tokens:
- 初始tokens:使用稀疏注意力机制时,保留初始tokens可避免生成tokens的困惑度(ppl)显著增加。
- 分隔符tokens:看似“无意义”的分隔符tokens在给定输入上下文中比有语义意义的tokens获得更高的注意力分数。因此假设这些分隔符可压缩其分割的文本片段信息,在免训练(training-free)的场景中,基于此策略能在许多任务上取得与原始模型相似的结果;
- 相邻tokens:由于语言任务通常具有局部依赖性,相邻tokens有助于形成局部平滑和连贯的上下文,所以在模型中考虑相邻tokens。
在预训练或者后训练的过程中,强迫模型当前的token只能看到前文每个片段中代表该片段的分隔符,使片段信息被强制浓缩到分隔符中。
实际上,每个分隔符(逗号、句号、分号、问号等)都是具备其特有的语义的,它们是对其分割段落的最原生和最细粒度的收尾与总结。
训练阶段,不需要将输入上下文中所有tokens对应的Query向量与所有Key向量相乘,只需乘以掩码矩阵中突出显示元素对应的Key向量;
生成阶段对KV缓存的管理较为直观,只保留初始、分隔符和相邻tokens的KV Cache。
研究者还针对Streaming场景还提出了定制的设计,包括同时维护的四个专用缓存块(初始缓存、分隔符缓存、过去窗口缓存和局部窗口缓存)及其功能,定义了四个缓存的运行时使用量和相邻tokens数量的相关变量,并详细说明了缓存系统的预设超参数。
在Streaming序列生成过程中,SepLLM会按照一定规则填充和管理这些缓存,当缓存达到一定条件时会触发压缩操作。
算力缓存消耗均减少,推理速度也更快了
作者分析了KV Cache的平均使用情况,结果,SepLLM在免训练、预训练和后训练场景中都展现出了卓越的效率,首先进行一个简单总结:
- 训推效率提升:SepLLM在免训练、从头预训练和后训练中都展现出了卓越的效率。特别是在使用Llama-3-8B模型时,SepLLM在GSM8K和MMLU基准测试中减少了超过50%的KV缓存,同时保持了相当的性能表现。
- 无限长的流式处理能力:在无限长输入的流式的场景中,SepLLM能够有效处理高达400万甚至更多tokens的序列,同时保持一致的语言建模能力。
- 广泛的实验验证与理论分析:通过在多种任务,各种基础模型(Llama,Falcon, GPTNeoX等)和多种数据集上的广泛实验,SepLLM证明了其在不同设置下的有效性,包括免训练、预训练和后训练。除此之外,作者还提供了对SepLLM架构通用近似(Universal Approximation)的详细理论分析。
接下来看一下具体的实验数据。
KV缓存减少50%
基于Llama-3-8B模型,SepLLM实现了超过50%的KV缓存减少,推理开销/显存压力大大降低,同时下游任务的性能几乎没有损失。
SepLLM的数学逻辑推理能力(GSM8K)/综合知识面广度(MMLU)在免训练的场景下即可达到和Llama-3-8B几乎一样的性能。
基于Pythia模型的更多下游任务上的结果,也验证了SepLLM的优秀的计算和存储效率与卓越的推理精度。
支持400万+Tokens流式长序列生成
同时,SepLLM可以轻松处理400万+Tokens以上的超长流式(streaming)序列生成。
推理速度更快,困惑度更低
并且由于SepLLM优化了推理过程,生成速度更快,同时语言模型的困惑度也更低了,运行时的平均KV Cache同样有所减小。
训练FLOPs更低,速度/吞吐率更大
除了推理,训练过程也用更低的FLOPs消耗,实现了更大的速度和吞吐率。
预训练中,达到相同Loss的时间缩短1.26倍,并且达到1.53倍的训练吞吐率和训练加速比。
后训练中,SepLLM也可以在较短时间内通过后训练恢复到原始Full Attention的训练loss,为基于大模型的高效后训练提供了可能。
适配不同backbone模型架构
同时,SepLLM可以适配各种backbone模型架构。
其中包括比如Llama、Pythia、GPTNeoX、GPT2以及Falcon等等。
对于这些架构,SepLLM均能实现更低的平均运行时KV Cache、更短的推理时间,以及更低的困惑度。
各种参数量模型均适配
SepLLM还可以适配各种大小的模型。
从Pythia-160M到Pythia-1.4B、6.9B,Llama3-8B,Falcon-40B等等,SepLLM均能实现更低的平均运行时KV Cache、更短的推理时间和更低的困惑度。
最近,DeepSeek的NSA与月之暗面的MoBA让稀疏注意力机制受到了较大的关注,相较于上述工作采用固定token数来划分压缩区间,SepLLM根据原生语义来划分动态数量的token数。
研究者也针对静态和动态token数压缩做了讨论,在免训练场景中,基于SepLLM的动态压缩能在下游任务中达到更好的准确率。
目前SepLLM的代码库已经公开,支持高效的多节点分布式训练,并采用了加速注意力机制的模块Sep-Attention。
此外,它还支持多种现有的Fusion Operators,如fused rope和fused layer norm,以加速训练过程。
项目地址:https://sepllm.github.io/
论文地址:https://arxiv.org/abs/2412.12094
代码:https://github.com/HKUDS/SepLLM