只有4k窗口长度的大模型,也能阅读大段文本了!
普林斯顿的华人博士生的一项最新成果,成功“突破”了大模型窗口长度的限制。
不仅能回答各种问题,而且整个实现的过程全靠prompt就能完成,不需要任何的额外训练。
研究团队创建了一种名为MemWalker的树形记忆策略,可以突破模型本身的窗口长度限制。
测试过程中,模型阅读的最长文本包含了1.2万+token,成绩相比LongChat大幅提高。
相比于相似的TreeIndex,MemWalker可以进行推理并回答任何问题,而不是只做概括。
MemWalker的研发利用到了“分而治之”的思想,就此有网友这样评论:
每次我们让大模型的思考过程更像人类,它们的表现就会越好
那么,具体什么是树形记忆策略,又是如何用有限的窗口长度阅读长文本的呢?
一个窗口不够,就多开几个
模型上,MemWalker使用Stable Beluga 2作为基本模型,它是Llama 2-70B经过指令调优得到的。
在选择该模型之前,开发者对比了其与原始Llama 2的表现,并最终确定选用。
就像MemWalker这个名字一样,它的工作过程就像记忆流在行走。
具体来看,大致分为记忆树构建和导航检索两个阶段。
构建记忆树时,长文本会被分割成多个小段(seg1-6),并由大模型分别针对每一段做出总结,得到“叶子节点”(leaf nodes,summ1-6)。
分割时,每段的长度越长,层级就会越少,有利于后续检索,但其本身过长又会导致准确度下降,因此需要综合考虑确定每段长度。
作者认为,每一段合理的长度是500-2000token,而实验中使用的为1000token。
然后,模型递归地对这些叶子节点的内容再次进行总结,形成“非叶节点”(non-leaf nodes,summ7-8)。
二者的另一个区别是,叶子节点包含了原始信息,非叶节点只有概括得到的二级信息。
作用上,非叶节点用于导航定位答案所在的叶子节点,而叶子节点则用于推理出答案。
其中的非叶节点可以有多个层级,模型逐步进行总结概括,直到得到“根节点”,形成完整的树形结构。
记忆树建立完毕后,就可以进入导航检索阶段来生成答案了。
这一过程中,模型从根节点开始,逐一读取下一级子节点的内容,然后推理出应该进入这个节点还是返回。
决定进入这个节点之后,再次重复这样的过程,直到读取到叶节点。如果叶节点的内容合适则生成答案,否则返回。
为了确保答案的完整性,这个过程的结束条件并非发现了一个合适的叶节点,而是模型认为得到了完整答案,或者达到最大步数。
导航过程中,如果模型发现进入了错误的路径,还可以导航回退。
此外,MemWalker中还引入了工作记忆机制来来提高准确度。
该机制会将已经访问过的节点内容加入到当前内容的上下文中。
当模型进入一个新节点时,当前节点内容都会被加入到记忆中。
这一机制让模型在每一步都可以利用访问过的节点内容,避免重要信息的丢失。
实验结果显示,工作记忆机制可以将MemWalker的准确率提升10%左右。
而且,上面所说的过程只依靠prompt就能完成,不需要进行额外的训练。
理论上,只要有足够的算力,MemWalker可以阅读无限长的文本。
不过,记忆树构建时的时间和空间复杂度随着文本长度的增长是呈指数型的。
作者简介
论文第一作者是普林斯顿大学NLP实验室华人博士生Howard Chen。
清华姚班校友陈丹琦是Howard的导师,她今年在ACL上的学术报告也与搜索有关。
这项成果是Howard在Meta实习期间完成的,Meta AI实验室的Ramakanth Pasunuru,Jason Weston和Asli Celikyilmaz三位学者也参与了本项目。