MemLong:用于长文本建模的记忆增强检索

发布于 2024-9-12 11:21
浏览
0收藏

​一、结论写在前面

论文标题:MemLong: Memory-Augmented Retrieval for Long Text Modeling

论文链接:https://arxiv.org/pdf/2408.16967

LLMs在各个领域的最新进展取得了显著的成功。然而,由于注意力机制的二次时间和空间复杂性以及生成过程中键值缓存的内存消耗不断增加,处理长上下文仍然是LLMs的一个重大挑战。

论文提出了MemLong,一种高效且轻量化的方法,用于扩展大型语言模型(LLMs)的上下文窗口。其核心思想是将过去的上下文和知识存储在一个不可训练的记忆库中,并进一步利用这些存储的嵌入来检索块级别的键值(K-V)对,以输入到模型中。MemLong适用于任何仅解码器的预训练语言模型,通过结合(1)一个额外的记忆检索组件用于记忆和检索,以及( 2 )一个检索因果注意力模块用于整合局部和记忆信息。MemLong的记忆和检索过程如图1(b)所示。在生成过程中,超出模型最大处理长度的文本被存储为上下文信息在记忆库中。随后,在长文档中生成最近的一个文本块时,论文使用检索器显式地检索过去的相关信息,通过索引对齐获取额外的上下文信息。

MemLong提供了几个优势:(1)分布一致性:与之前在信息存储到记忆中时经历分布偏移的模型不同,MemLong确保缓存信息的分布保持一致。(2)训练高效:论文冻结模型的底层,只需微调上层,这大大减少了计算成本。在论文的实验中,在0.5B个token上微调一个3B参数版本的MemLong仅需八块3090 GPU运行八小时。(3)广泛的上下文窗口:由于只需记忆单层的K-V对,MemLong能够在单块3090 GPU上轻松扩展上下文窗口至80k个token。

大量实验表明,与其他领先的LLMs相比,MemLong在多个方面表现出优越的性能。在多个长上下文语言建模数据集上,MemLong优于OpenLLaMA和其他基于检索的模型。在检索增强的上下文学习任务中,MemLong比OpenLLaMA提高了多达10.2个百分点。    

二、论文的简单介绍

2.1 论文的背景

由于传统的注意力机制的二次方时间和空间复杂性,扩展上下文长度具有挑战性,这为涉及长序列任务的应用带来了显著的限制,例如长文档摘要和多轮对话。因此,LLMs通常被期望具备长时间的工作能力(即长上下文LLMs),以有效应对这些苛刻的场景。

MemLong:用于长文本建模的记忆增强检索-AI.x社区

   

图1:检索增强生成(RAG)和MemLong的记忆检索流程示意图。(a) 当检索信息的长度超过模型处理能力时,RAG甚至可能降低生成性能(黄色)。(b) 论文的方法利用外部检索器获取历史信息,然后将其作为K-v对传递给模型,而不是以文本形式传递。

为了解决计算瓶颈问题,已经进行了大量努力。第一类工作侧重于通过采用稀疏注意力操作来减少普通注意力机制的计算量。尽管这些方法可以将计算复杂度降低到大约 O(n),但通常会以模型容量为代价。因此,一些工作转向了记忆选择。这些方法,作为token级别的记忆选择,可能导致语义信息的截断。另一类最近的工作是检索增强语言建模。这些工作通常引入检索机制来增强模型处理长文本的能力。

然而,这些方法有几个缺点。首先,由于训练过程中模型参数的变化,存储在记忆中的信息可能会经历分布偏移。其次,这些方法通常需要重新训练,这在大型模型时代是不切实际的。最后,这些模型往往倾向于以牺牲预训练模型的原始能力为代价来处理长文本输入。为了解决先前研究的局限性,论文提出了以下问题:论文能否利用检索器的显式检索能力来近似模型内部的隐式检索过程?

2.2预备知识

2.2.1 任务定义

语言模型旨在定义一系列token的概率分布,从而有效地预测给定语言中序列的可能性。给定一个序列 x_1, ..., x_n,标准方法建模其概率为 p ( x_1, ..., x_n )= sum pθ ( x_i | x< i ),其中 x< i , := x_1, ..., x_i-1 是 x_i 之前的token序列。

与标准语言建模目标不同,论文不仅使用当前上下文进行下一个token预测,还利用外部检索获取相关信息并在模型的上层进行知识融合。具体来说,给定一个由 l个token组成的序列,每个块的大小为 ν=l/τ,论文将序列划分为 τ个不重叠的块,记为 C=( c_1, ..., cν )。相应地,其文本形式被划分为ν个文本块,记为 T=( t_1, ..., t_ν)。在每一步中,论文在下层对 c_i 进行因果语言建模,而在上层,论文对 t_i 进行细粒度可控检索以融合额外信息。完成此操作后,论文的语言建模目标变为

MemLong:用于长文本建模的记忆增强检索-AI.x社区

其中 R( t_i) 表示检索 t_i 所在块的邻近块。

2.2.2 模块与操作定义

如图 2 所示,Ret-Mem 模块由一个检索器(Retriever)和一个记忆组件(Memory)组成,用于信息交换。首先,论文将记忆组件定义为M,将检索器定义为 R,以及它们对应的操作M(.) 和 R(.)。此外,论文指定模型的维度为 d_model,检索器的维度作为 d_ret 。记忆模块包括两个部分:K-V 对和相应的表示嵌入。键和值的维度表示为R^d_model,而嵌入的维度表示为R^d_ret。必须强调的是,实际的检索过程涉及表示块的嵌入,而不是 K-V 对。检索器本质上是一个预训练的密集嵌入器,具有出色的表示能力。MemLong 使用它将每个块编码为表示嵌入。由于它为每个块生成一维表示向量,即使内存大小很大,内存占用也保持最小。    

MemLong:用于长文本建模的记忆增强检索-AI.x社区

图 2 :MemLong 的一个示例:在较低层,模型保持静态,对整个块 c_i执行因果语言建模,随后,c_i 以嵌入和 K-V 对的形式缓存。最后,上层被微调以协调检索偏好并整合检索到的内容。

2.3 MemLong

2.3.1 概述

如图 2 所示,每个步骤涉及一个块 c_i 的输入,其中该块的原始文本为 t_i。在模型冻结的较低层,对整个 c_i 应用标准因果注意力。对于较低层的最后一层,论文称之为记忆层。在每次遍历记忆层后,执行两个关键操作。第一个操作是检索,由红线表示,其中 t_i 用于获取最相关的 K-V 对。第二个操作,由蓝线表示,涉及缓存获得的 K-V 对及其相关块表示。在模型的上层,检索到的 K-V 对与当前输入上下文整合,随后调整模型参数以校准检索参考。后续部分将探讨 MemLong 框架的各个方面及其细节,包括检索器和动态内存管理 (Dynamic Memory Management),注意力重构 ( Attention Reformulation ),以及使用 MemLong 进行推理 (Inference with MemLong)。

2.3.2 检索器与动态内存管理

论文提供了一个关于检索过程和内存管理动态的全面解释。    

检索过程。鉴于论文的目标是用显式检索取代基于K-V对的传统kNN检索,论文旨在在可行的情况下预取所需信息,然后再缓存模型输入。具体来说,对于每个潜在的查询块c^q=c_i及其对应的文本块t^q=t_i,论文首先将其传递给检索器,然后获得一个表示嵌入r^q = R ( t^q )。随后,论文使用这个表示嵌入对\mathcal{M}中的嵌入进行检索,以获取所需的k个块级索引。论文计算检索表示r^q与存储在内存M中的嵌入之间的余弦相似度。最后,论文得到c^q的top-k索引z^q =Top K{Cos( r^q ) }。

记忆过程。记忆过程同步存储来自记忆层的K-V对以及先前计算的用于检索的表示嵌入,确保\mathrm{K-V}对的索引与其表示嵌入准确对应(见图2,右侧,蓝线)。对于每个可能的块记忆c^m=c_i及其对应的文本块t^m=t_i,论文将记忆过程分为两部分:第一部分详细说明如何缓存K-V对,第二部分解释如何存储相应的表示。首先,论文将c^m输入到MemLong中,并从记忆层获取输出。论文的记忆操作非常高效,因为它仅涉及存储检索所需的表示r^m=r^q,从而避免了冗余。在所有块对检索完成后,记忆操作——表示为M ( k, v ; r^m )——同步更新记忆,包括键值对及其对应的表示。

动态记忆更新。当记忆溢出时,论文使用计数器智能更新记忆。在论文的实验中,论文保留最新10%的记忆内容,因其可能具有相关性,丢弃最旧的10%,因其可能已过时,并根据检索频率优先处理中间的80%,删除最少访问的条目,直到记忆使用量降至50%。这种选择性修剪平衡了时效性和相关性,保留了有价值的信息并删除了不太相关的数据。

MemLong:用于长文本建模的记忆增强检索-AI.x社区

   

图3:检索因果注意力示意图。局部因果注意力应用于最近的上下文,而通过检索方法获得的块级K-V对由于其历史性质而能够实现双向注意力,而不会发生信息泄露。

2.3.3注意力重构

在模型的可训练上层中,论文重构了注意力机制以融合长期记忆。如图3所示,与传统的Transformer解码层使用多头注意力不同,论文提出了一种检索因果注意力机制,将其扩展为联合注意力机制,并提出了一种长期记忆融合过程,使得每个token既能关注局部上下文,也能关注具有完整和连续语义的块级过去上下文。下一层的隐藏状态H^l计算如下:

MemLong:用于长文本建模的记忆增强检索-AI.x社区

为了避免训练初期检索注意力分数o_m的干扰,论文采用了多注意力机制,遵循LLaMA-adapter

MemLong:用于长文本建模的记忆增强检索-AI.x社区

最后,论文将V和V连接起来得到H^l:

MemLong:用于长文本建模的记忆增强检索-AI.x社区

2.3.4 使用MemLong进行推理

当MemLong接收到超过长度的输入时,论文将其视为两个部分:前缀p和主体。论文将分别描述在推理阶段对长输入的编码和长输出的生成。当MemLong接收到长输入时,它首先将前缀分成多个不重叠的块,并从其内存层计算,这确保了参与注意力的token数量等于块大小,远小于输入长度。需要注意的是,每个块是相互关联的(例如,第t个块需要处理前t-1个块的)。

第二步是根据块级检索表示选择与主内容最相关的k个块,并获取它们的关键和值表示。在此之后,对于上层检索层,检索的注意力窗口相当于k * τ,这也小于输入长度。最后,高效地执行长度受限的因果注意和检索注意。    

2.4 实验

论文在需要内存中长上下文处理的各项任务上评估论文提出的MemLong模型:(a) 长上下文语言建模和检索增强语言建模;(b) 能够处理内存中大量演示示例的可扩展上下文学习。

2.4.1 实现细节

训练细节。论文使用OpenLLaMA-3B作为预训练的骨干LLM,采用旋转位置编码(rotation position coding)。由于硬件限制,论文选择使用LoRA技术训练论文的模型。骨干LLM具有L=26, H=32, d=100的架构。除非另有说明,论文使用第13层作为内存层,[14, 18, 22, 26]层作为检索增强层。检索增强适应的训练仅在0.5B个token上迭代,序列长度为1024。Mem-Long的可训练参数来自14到26层。论文利用slimpajama数据集采样作为论文的训练语料库。

位置重映射。在生成过程中,M中检索到的块级别K-V有多个。由于每一步检索的不确定性,论文需要将位置嵌入重映射到检索到的块。与之前的工作(Tworkowski et al., 2024)相同,局部上下文(最多2048个token)接收标准旋转位置编码,而内存键则被编码为在局部上下文窗口中具有位置0。

2.4.2 长上下文语言建模

论文首先在长上下文语言建模基准上评估MemLong,以评估其基本的语言建模能力。由于K-V缓存提供了显著的背景和上下文信息,MemLong能够快速检索相关的K-V缓存并充分利用它,从而在长上下文建模任务中增强模型的性能。

数据集。论文在四个广泛的文本基准数据集上对论文的模型进行了评估:英语书籍PG-19和BookCorpus,维基百科文章Wikitext-103,以及数学论文Proof-Pile。实验结果表明,所有数据集上的困惑度都有显著改善。论文的模型在从1024到32768个token的不同长度范围内进行了测试。通过利用外部检索器和内存,论文的模型在所有数据集上展示了显著的性能提升,且内存开销最小。

设置。按照(Yen et al., 2024),论文计算每个序列最后2048个token的困惑度。此实验设置旨在验证不同检索器大小对模型整体性能的影响。为了实现高效的细粒度检索,论文使用faiss工具包在GPU上构建精确搜索索引,以存储文本块的表示嵌入并执行高效检索。对于MemLong,论文将token分割并放入finetune-length = 1024的M中,用于进一步检索。

基线。在论文的实验中,论文采用OpenLLaMA-3B模型作为基线。为了确保公平比较,论文使用相同的LoRA配置并微调了模型在相同数量的slimpajama数据集上的数据。此外,论文比较了LongLLaMA-3B,该模型使用Focused Transformer(FoT)方法和5B token进行了微调。为了进行更全面的比较,论文还测试了两个7B模型:LLaMA-2-7B和LongLoRA-7B-32K,以及两个位置编码模型:Yarn-7b-128k和Phi3-128k。    

MemLong:用于长文本建模的记忆增强检索-AI.x社区

表1:不同上下文窗口扩展模型在PG19、Proof-pile、BookCorpus、Wikitext-103上的滑动窗口困惑度。所有实验都在一块3090 24GB GPU上进行。LongLLaMA-3B和MemLong-3Btoken为表示在没有Memory的情况下评估,LongLLaMA-3Btoken表示在无限Memory的情况下评估。论文还评估了MemLong在4K/32K Memory场景下的表现。"- / 6.95"表示模型在单GPU上导致内存不足(OOM)错误,而在双GPU上则产生相应结果。*

结果。结果如表1所示。论文采用困惑度(PPL)作为语言模型的评估指标。较低的PPL表示更强的语言建模能力。与两个完全微调的模型相比,OpenLLaMA-3B和LLaMA-2-7B,论文的模型在测试长度在其预训练限制内(OpenLLaMA-3B为2048,LLaMA-2-7B为4096)时,在多个数据集上表现出相当的性能。

然而,一旦测试长度超过这些预训练限制,论文的模型即使在微调长度1024和预训练长度2048之后,仍能继续降低困惑度,展示了其优越的泛化能力。

相比之下,OpenLLaMA-3B和LLaMA-2-7B模型无法泛化到超出其预训练长度的输入,并且由于注意力机制的二次复杂性,表现出显著增加的内存开销。论文还与LongLoRA进行了比较。尽管LongLoRA中提出的Shifted Sparse Attention显著减少了内存使用,但它也削弱了模型在短文本上的性能。

相比之下,LongLLaMA由于其内存使用无限增长,在测试长度变得过长时也会遇到OOM问题。位置编码模型具有强大的泛化能力。然而,这种方法的性能只能保证长距离生成性能不下降。与这些方法相比,MemLong利用外部检索器处理更长的输入token,并实现了更好的困惑度改进。同时,由于高存储效率,MemLong可以有效控制GPU的使用,避免OOM问题。

2.4.3 上下文学习

传统的上下文学习(ICL)将少量的非参数化示例与查询一同输入模型。然而,这些方法通常受限于模型的输入长度。在本实验中,由于MemLong可以将示例以参数化形式存储在其记忆中,论文主要研究MemLong是否能有效利用其记忆中存储的知识以增强其突现能力。结果如表2所示。    

与仅依赖非参数化知识的OpenLLaMA相比,在相同数量的上下文示例下,MemLong可以利用其记忆中存储的额外示例。随着记忆中示例数量的增加,性能进一步提升或保持稳定。在与LongLLaMA的对比分析中,论文观察到在相同的保留内存示例条件下,论文的模型在大多数数据集上表现优于LongLLaMA。值得注意的是,与LongLLaMA相比,论文的模型在训练参数(2亿对比0.3亿)和微调数据量(0.5亿对比5亿)方面显著减少。这凸显了论文模型在利用外部检索器进行信息获取方面的效率,展示了在资源大幅减少的情况下,能够更有效地综合和利用知识的能力。

MemLong:用于长文本建模的记忆增强检索-AI.x社区

2.5 Ablation Study

2.5.1 训练设置

在训练阶段,论文探讨了不同检索层对模型的影响,并考察了论文的方法是否能充分解决MemTrm中讨论的分布偏移问题。如前所述,论文的方法为分布偏移提供了一种低成本的解决方案。如图4所示,棕色线(图片顶部的线条;训练方法类似于MemTrm,微调模型的所有参数,并且在内存层之后的所有层都参与检索)在性能和拟合速度方面明显劣于论文所有其他方法(即使是设置最不合理的方法)。论文将在稍后分析推理阶段的性能。    

2.5.2 推理性能

MemLong:用于长文本建模的记忆增强检索-AI.x社区

图 4:训练阶段PPL的变化程度。y轴的指标为PPL。论文主要关注训练参数和检索层。

Q1:记忆长度是否影响模型的性能?如图5所示,论文对同一模型在不同记忆大小下的性能进行了检查,结果表明记忆容量与模型效率之间存在明显的相关性。趋势表明,记忆大小的增加会逐渐提升性能。此外,在记忆大小为65536时,模型能力经历了一个显著的飞跃。这表明,虽然扩展记忆提供了实质性的好处,但其有效性存在一个实际的上限,这可能受到数据分布细微差别的影响。

Q2:论文需要引入多少层额外的记忆信息?如图4所示(粉色线)和表3(RPL+TH)中展示的,当检索层的数量设置为[13,17,21,25]时,模型表现最佳。根据经验认为,如果在模型的所有上层都引入检索信息,会导致模型对局部上下文的注意力下降。因此,以适当的间隔选择检索层实际上可以增强模型的能力。    

MemLong:用于长文本建模的记忆增强检索-AI.x社区

表3:不同的检索层会影响MemLong的性能。token为的MemLong表示在没有Memory的情况下进行评估。所有使用Memory的方法的大小设置为32768。RA表示跨所有上层检索;TA表示不冻结参数的训练;RP表示跨较少上层检索,RPL表示跨更少上层检索。    

MemLong:用于长文本建模的记忆增强检索-AI.x社区

图 5:在不同记忆大小下评估不同数据集。在每个子图中,除记忆大小外,所有参数均相同。

本文转载自​AI帝国​,作者: 无影寺 ​​

收藏
回复
举报
回复
相关推荐