基于 Transformer 架构的大语言模型在 NLP 领域取得了令人惊艳的效果,然而,Transformer 中自注意力带来的二次复杂度使得大模型的推理成本和内存占用十分巨大,特别是在长序列的场景中。
此前,研究者们提出了线性 Transformer、Mamba、RetNet 等。这些方案可以大幅降低 Transformer 计算成本,并且取得媲美原有模型的精度,但是由于架构更换,模型重训练带来的巨大成本令人望而却步。
为了解决这一问题,最近的一篇论文提出了一种基于频域的大语言模型架构 — 帝江(源于山海经的一种神话生物,以跑得快而闻名),同时解决了现有大模型的两大痛点:推理成本和训练成本。
- 论文地址:https://arxiv.org/abs/2403.19928
- 开源链接:https://github.com/YuchuanTian/DiJiang
该论文基于频域自注意力变换核,寻找到一种原始自注意力的线性逼近,使得原有的 Transformer 模型可以经过少量数据(1/10-1/50)的微调,可以近乎无损地变形为论文提出的帝江模型。具体来说,在 LLaMA2-7B 上仅仅需要使用 40B 左右的训练数据,就可以取得最多 5 倍的推理加速,且在各个评测集上取得相当的精度。
DiJIang-7B 模型和 LLaMA-7B 的精度对比
DiJIang-7B 模型和 LLaMA-7B 的速度对比
研究背景
Transformer 架构自从推出以来,彻底革新了自然语言处理(NLP)领域,并在多种任务中取得了杰出成果。这一成功导致了大型语言模型(LLMs)主导的时代的到来,在这个时代中,Transformer 结构被放大以处理越来越复杂的任务。然而,这种规模的扩大也带来了巨大的计算需求,特别是由于需要每个 token 之间的计算的自注意力机制。
面对更高效 Transformer 模型的迫切需求,研究者们提出了线性 Transformer、Mamba、RetNet 等方案,虽然这些方案可以大幅降低 Transformer 计算成本,并且取得媲美原有模型的精度,但是由于架构更换,模型重训练带来的巨大成本令人望而却步。
然而,大多数现有的优化 Transformers 方法,特别是与优化注意力机制有关的,需要对模型从头重新训练。这一重新训练过程是一个巨大的挑战,特别是对于参数庞大的模型,需要大量的计算资源和时间投入。例如,像 LLaMA-7B 这样的大型模型的训练需要大约 8 万多 GPU hours。尽管有部分研究如 Performer 努力寻找注意力机制的快速近似方法,但这些方法在大型语言模型中还没有得到彻底的验证。
为了解决大型语言模型中快速注意力近似的问题,论文对现有的线性注意力方案和自注意力近似方案进行了彻底的分析。论文发现,这些方法中近似误差的主要来源是基于蒙特卡洛方法的采样。因此,论文提出采用加权拟蒙特卡洛采样来代替蒙特卡洛采样进行映射,论文进一步引入频域离散余弦变换(DCT)来作为拟蒙特卡洛采样的值,从而高效且准确地将 Transformer 的 query 和 key 映射到频域。使得注意力机制中的 softmax 操作可以被去除,达到线性的计算复杂度。论文还从理论上证明了,这种频域映射是与原始注意力机制的一个近似等效,从而使得帝江模型可以不需要从头开始训练,只需要少量数据就可以从 Transformer 的参数中进行微调继承。论文的实验表明,论文的方法达到了与原始 Transformer 相当的性能,但训练成本大大减少(<1/10),同时也受益于更快的推理速度(在不同模型上最高约 10 倍)。
方法介绍
论文首先回顾了 Attention 的计算方式:
其中是一句话中 token 的数目,d 是隐藏层的维度,传统的 Attention 计算复杂度是。
为了减少 Attention 的计算复杂度,线性 Attention 方案希望将 softmax 函数去掉,这样 K 和 V 的计算可以提前进行,从而使得计算复杂度变为,由于 n 通常要远大于 d,因此在变化后计算复杂度可以被大幅减小。例如,Performer 采用了 PRF 核来逼近原始 Attention 的计算,具体为:
然而,由于蒙特卡洛方案存在的近似误差,Performer 等方案常常要将隐藏层从维度映射为更大的维度,这导致了线性注意力带来的计算复杂度变为,使得计算加速的收益减少。
为了解决这个问题,论文首先提出一种基于加权拟蒙特卡洛的方案,具体的,论文提出了一种新的 WPFF 核映射:
和 PRF 映射不同,WPFF 核映射在两点上进行了改进:1. 将原有的随机映射 w 变为给定的均匀正交变换 v 和其模长部分 t ,即使用拟蒙特卡洛变换来代替蒙特卡洛变换,减少逼近误差从到。2. 使用加权矩阵 D 来对映射进行加权求和,减少蒙特卡洛映射的误差。
论文提供了理论证明,来表明提出的 WPFF 映射核是一种更优的映射方式,具体的证明内容详见论文附录:
基于 WPFF 核,论文又进一步对其进行改进,由于给定的均匀正交变换 v 可以使用任意的均匀正交变换,论文提出使用频域 DCT 变换来进行计算,由于 DCT 变换具有特殊的形式,其计算复杂度仅为,相比其他的正交变换要来的更低,最终,论文使用的 WDCF 映射为:
最终,帝江模型的自注意力计算被代替为:
帝江模型和传统自注意力计算的区别
上图展示了帝江模型和传统自注意力计算的区别,在 Transformer 的注意力机制中,key 和 value 的计算通过快速离散余弦变换(DCT)高效地映射到频域。这种映射有效地消除了 softmax 操作,从而显著降低了 Transformer 的计算复杂度。
实验结果
不同模型大小的对比
上表展示了提出的帝江模型在不同大小的 scale 上的结果,可以看到,提出的帝江模型可以取得和原始模型基本相同的精度,并且拥有更快的推理速度和更低的训练成本,显著解决了现有 LLM 遇到的训推成本过大的问题。此外,模型在 1B 的模型量级上超越了 1.3B 大小的 Mamba 模型。需要注意的是,尽管传统 Transformer 可以通过 Flash Attention 的方式进行进一步加速,但由于针对帝江模型的加速框架尚未开发,为了公平对比模型本身的速度,推理速度的测试都是在模型都不使用加速框架的前提下进行的。
与不同 Transformer 改进方案精度对比
论文还展示了帝江和其他 Transformer 模型的改进方案进行了进一步的对比,可以发现,帝江模型具有比其他模型更好的效果,这得益于其通过更好的核映射近似了原始的 Transformer 模型计算。
论文还同时提供了帝江 - 7B 模型的续写样例展示,可以看到,帝江 - 7B 的续写结果,和 LLaMA2-7B 相比毫不逊色,甚至条理性上要略胜一筹。
总结
论文提出了一种新的 LLM 架构:帝江,在 7B 以下的模型量级,所提出的模型可以大幅降低 LLM 所需的训练和计算成本,为未来 LLM 的高效部署提出了一种新的思路。帝江架构是否会在更大的模型与多模态 VLM 等其他 Transformer 的应用领域中大放光彩,让我们拭目以待。