Transformer 很厉害,但并不完美,尤其是在处理长序列方面。而状态空间模型(SSM)则在长序列上的表现相当不俗。早在去年就有研究者提出可使用 SSM 替代 Transformer,参见文章《预训练无需注意力,扩展到4096个token不成问题,与BERT相当》,前些天基于 SSM 方法的 Mamba 更是异军突起,推理吞吐量达到了 Transformer 的五倍之多,参阅《五倍吞吐量,性能全面包围Transformer:新架构Mamba引爆AI圈》。
但实际上,SSM 和 Transformer 并不是非此即彼的两种架构,它们完全可以组合起来!
近日公布的一篇 NeurIPS 2023 论文《Block-State Transformers》就采用了这种做法,其不仅能轻松支持 65k token 长度的超长输入,而且计算效率还非常高,速度相比使用循环单元的 Transformer 足可提升十倍之多!这篇论文也得到了 Mamba 作者 Tri Dao 的点赞,他表示:「SSM 和Transformer 似乎可以互补。」
但在我们介绍这种新方法前,先简单说说 Transformer。在许多不同的自然语言处理(NLP)任务上,Transformer 的表现都非常出色。可以说 Transformer 已经很大相当程度上替代了循环神经网络。不仅如此,它也正在图像和视频等 NLP 之外的领域大展拳脚。
其成功的原因有很多,包括计算效率和架构层面的归纳偏差,这让它们非常适合在自然语言任务进行大规模训练。在计算方面,Transformer 能以并行方式处理输入序列的 token,从而使其能充分利用现代加速器硬件。此外,注意力机制让 Transformer 可以找到更长序列之间的关系,其方式是在推断下一个 token 时读取从过去 token 提取的所有信息。相比于 RNN 和 LSTM,自注意力有两个优势:(1) 存储信息以及将这些信息直接用作上下文的能力得到了极大提升,(2) 在更长序列上能更稳定地训练。
尽管 Transformer 相比 RNN 有很多优势,但它在输入序列长度的扩展上依然存在问题,其中涉及计算性能和质量等方面的原因。更进一步说,Transformer 的运行时间会随输入序列长度的增长成二次方增长,这会让训练这些模型的成本越来越高。
此外,众所周知使用注意力的 Transformer 在长输入分类任务上表现不佳。最基本的 Transformer 在长序列上训练时可能不稳定,而且其 token 重要度聚焦在当前时间步骤周围约 50 个 token 的局部感受野中。
近来,越来越多的研究表明状态空间模型(SSM)可以替代 Transformer,因为 SSM 可以捕获极长序列之中的依赖关系,同时还有更高的计算效率和更好的并行化能力。
尽管 SSM 依然属于自回归序列模型,但其底层的线性时间不变式动态系统可使用基于快速傅立叶变换(FFT)的可并行化卷积算子来高效地处理序列,而且这个过程的复杂度仅为 𝒪(𝐿 log 𝐿),其中 𝐿 是序列的长度。此外,借用在线函数近似的方法,通过推导循环更新规则,可以确保在长序列上保留过去的信息,甚至可达成千上万个时间步骤。在 Long-Range Arena 基准上,SSM 甚至超过了 Transformer 一大截,参阅机器之心报道《六项任务、多种数据类型,谷歌、DeepMind提出高效Transformer评估基准》。
尽管 SSM 在长程分类任务上很成功,但如果要用作通用语言建模的现成可用序列模型,SSM 还完全赶不上 Transformer。
近期又有研究《Long Range Language Modeling via Gated State Spaces》认为 Transformer 和 SSM 完全可以互补。
DeepMind 等机构提出的新架构 Block-State Transformer(BST)将强大的基于局部注意力的归纳偏差与长期上下文建模能力组合到了一起,做成了单一层。
论文地址:https://arxiv.org/pdf/2306.09539.pdf
据介绍,该模型能在处理长输入序列的同时整合注意力机制来预测下一个 token。相比于基于 Transformer 的层,BST 是完全可并行化的,能扩展用于更长得多的序列,同时速度还能快 10 倍。
在每一层 BST 中,有一个 SSM 将输入的整个序列映射进一个同样长度的「上下文」序列。这个 SSM 子层使用基于 FFT 的卷积。然后将这个上下文序列分成大小相等的上下文块,这个大小即为窗口长度 W;然后再将每个上下文块输入一个 Transformer 层,其注意力关注的是大小为 W 的子序列。之后对输入 token 嵌入块与对应的上下文状态块使用交叉注意力,如图 1 所示。
注意,通过将 SSM 用作一种上下文化的方法,就可以完全不需要序列循环,这样一来就能以完全并行的方式运行这种 SSM-Transformer 混合层。
最后的运行时间复杂度可以表示成一个和:𝒪(𝑊²)+𝒪(𝐿 log 𝐿),其中前一项表示 Transformer 子层的时间复杂度,后一项是 SSM 子层的时间复杂度。
只要有支持并行计算的硬件,相较于 Block-Recurrent Transformer 的 𝒪(𝐿𝑊),这是一个重大提升。此外,由于硬件施加的限制,SSM 在完整序列上的运行时间复杂度与 Block Transformer 在 token 块上的运行时间复杂度相当,这进一步意味着 BST 层不存在速度瓶颈。该团队使用包含数十万 token 的序列通过实验验证了这一点。
方法
这里研究的是通过仅解码器语言模型实现下一 token 预测的问题。
对状态空间的前置说明
状态空间模型可以分为两大类:
状态空间:结构化核S4、S5、S4D、DSS遵循卷积核的一种结构化初始化,方式是展开一种线性时间不变式(LTI)动态系统,如下所示:
其中的参数包括状态矩阵 𝚨∈ℝ^{N×N},向量 𝐁∈ℝ^{N×1}、𝐂∈ℝ^{1×N}、𝐃∈ℝ^{1×1}。SSM 会将一维的输入信号 u_k 映射成一维的输出信号 y_k。
显式参数化的过滤器。不同于结构化核,还可以将卷积核参数化为可训练的权重并优化它们。但是,这会导致性能很差,除非对这些核使用特定类型的正则化方法。替代 Transformer 的无注意力模型中也有使用可训练核的,比如 Hyena 涉及到沿核对权重进行指数衰减。
Block-State Transformer(BST)层
Block-State Transformer 层将 SSM 与 Block Transformer 组合到了一起。在每一次训练迭代中,都会从一个长文档采样一个包含 L 个 token 的序列。然后嵌入该 token 并将其馈送给模型。这个模型由堆叠的 Block-State Transformer 层构成。每一层 BST 都会选择性地包含一个 SSM 子层,其负责为 Block Transformer 层提供长程上下文,这与 Block-Recurrent Transformer(BRECT)单元的工作方式类似。这个 SSM 子层的输入是前一层的 token 嵌入序列,输出则是一个长度同样为 L 的序列。
这个输出经过了上下文编码,也就是说每个时间步骤的项目都可能包含有关该序列中元素之前的所有时间步骤的信息。他们从上下文序列收集一定数量 S 的「上下文状态」,并使得 S ≪ L。
这些上下文状态会被馈送给 Block Transformer,以替代 Block-Recurrent Transformer 中的「循环状态向量」。如图 1 右侧所示,后续操作保持不变,只是无需再运行 BRECT 单元的循环单元,因为现在是通过 SSM 来维护上下文。除了上下文状态,Block Transformer 的输入中还有长度 W 的 token 嵌入的块/窗口;然后在这个窗口与上下文状态上使用交叉注意力。然后将这个交叉注意力操作的输出与自注意力在输入嵌入上的输出连接起来,之后是一个简单的投影。
SSM 不仅能在更长时间尺度上保留信息,而且使用 SSM 来维持上下文状态以替代循环单元,可以得到计算效率更高的层。通过将 SSM 整合进 Transformer 层,可以移除循环部分,从而让 Block-State Transformer 层可以完全并行化。
上下文状态
尽管从技术上看,最新的 SSM 输出包含有关整个序列的信息,但仅从最后的状态检索单个 token 可能是不可行的。为了弥补这一点,该团队将一系列状态连接了起来,对应于最新的 token 块。这与 BRECT 采用的方法类似。这种表征可以通过冗余来确保可检索性和易访问性。
在新提出的方法中,上下文状态是使用 SSM 的输出构建的,并会被馈送给 Transformer 的注意力头。这些上下文状态的构建方式有很多。为了引导设计决策,该团队考虑了多种设计方案,包括使用单头(Single-Head)、多头(Multi-Head)或多过滤器(Multi-Filter)。其中单头设计见图 1。下图 2 则展示了多头和多过滤器的设计方案。
比较下来,多过滤器的记忆状态的冗余最少,多头次之,单头的冗余最大。
结果
该团队在 PG19、GitHub 和 arXiv 三个数据集上进行了实验,检验了新提出的 BST 在不同长度的英语文本、latex 科学文章和源代码上的效果。下表 1 总结了实验结果。
下图 3 则给出了长度泛化分析并报告了困惑度。实验中,新模型和基准模型的参数数量都约为 4 亿,训练时的序列长度为 4k,测试中的序列长度为 {512, 16k, 65k}。
可以看到,在 PG19、GitHub 和 arXiv 上,当序列长度为 65k 时,BST:SH:S4-L 的困惑度最好。
在效率方面,下图 4 左给出了 BST 层在 GPU 上的基准测试结果。
可以看到 SSM 带来了非常显著的增长——比包含循环单元的 Block-Recurrent Transformer 快 6-11 倍;即使在序列长度达到 65k token 时,还依然能有 6 倍的提升,而这时候硬件就已经开始饱和了。当使用结构化的 SSM 时,计算复杂度与 SSM 的内部记忆状态大小 N 紧密相关。对于报告的性能,N = 16。
研究者表示,如果使用其它自动微分框架中近期引入的更快的针对硬件的 I/O 感知型实现,BST 方法的速度还能更快。
更多技术细节和实验结果参阅原论文。