Transformers学习上下文强化学习的时间差分方法 原创
上下文学习指的是模型在推断时学习能力,而不需要调整其参数。模型(例如transformers)的输入包括上下文(即实例-标签对)和查询实例(即提示)。然后,模型能够根据上下文在推断期间为查询实例输出一个标签。上下文学习的一个可能解释是,(线性)transformers的前向传播在上下文中实现了对实例-标签对的梯度下降迭代。在本文中,研究人员通过构造证明了transformers在前向传播中也能实现时间差异(TD)学习,并将这一现象称为上下文TD。在训练transformers使用多任务TD算法后展示了上下文TD的出现,并进行了理论分析。此外,研究人员证明了transformers具有足够的表达能力,可以在前向传播中实现许多其他策略评估算法,包括残差梯度、带有资格跟踪的TD和平均奖励TD。
上下文学习已经成为大型语言模型最显著的能力之一。在上下文学习中,模型的输入(即提示)包括上下文(即实例-标签对)和一个查询实例。然后,模型在推断期间(即前向传播)为查询实例输出一个标签。模型输入和输出的一个示例可以是:
其中,“5 → number; a → letter”是包含两个实例-标签对的上下文,“6”是查询实例。根据上下文,模型推断查询“6”的标签为“number”。值得注意的是,整个过程在模型的推断时间内完成,而不需要调整模型的参数。
在(1)中的示例说明了一个监督学习问题。在经典的机器学习框架中,这个监督学习问题通常通过首先基于上下文中的实例-标签对训练一个分类器来解决,使用诸如梯度下降之类的方法,然后要求分类器预测查询实例的标签。值得注意的是,研究表明,transformers能够在前向传播中实现这个梯度下降训练过程,而不需要调整任何参数,为上下文学习提供了一个可能的解释。
超越监督学习,智能涉及到顺序决策,其中强化学习已经成为一个成功的范式。transformers在推断期间能否执行上下文RL,以及如何执行?为了解决这些问题,研究人员从马尔可夫奖励过程MRP中的一个简单评估问题开始。在MRP中,代理程序在每个时间步中从一个状态转换到另一个状态。用(S0,S1,S2,...)表示代理访问的状态序列。在每个状态下,代理程序会接收到一个奖励。用(r(S0),r(S1),r(S2),...)表示代理程序在路途中接收到的奖励序列。评估问题是估计值函数v,该函数计算每个状态未来代理程序将收到的期望总(折扣)奖励。所需的输入输出的一个示例可以是:
引人注目的是,上述任务与监督学习根本不同,因为目标是预测值v(s),而不是即时奖励r(s)。此外,查询状态s是任意的,不必是S3。时间差分学习TD是解决这类评估问题(2)的最常用的RL算法。而且众所周知,TD不是梯度下降。
在这项工作中,研究人员做出了三个主要贡献。首先,通过构造证明transformers具有足够的表达能力来在前向传播中实现TD,这一现象我们称为上下文TD。换句话说,transformers能够通过上下文TD在推断时间内解决问题(2)。超越最直接的TD,transformers还可以实现许多其他策略评估算法,包括残差梯度(Baird,1995)、带有资格跟踪的TD(Sutton,1988)和平均奖励TD(Tsitsiklis和Roy,1999)。特别地,为了实现平均奖励TD,transformers需要使用多头注意力和过度参数化的提示,例如,
这里,“□”充当一个虚拟占位符,在推断期间transformers将使用它作为“记忆”。第二,通过在多个随机生成的评估问题上训练transformers与TD,实证地证明了在推断中出现了上下文TD。换句话说,学习的transformer参数与我们在证明中的构造非常相符。将这种训练方案称为多任务TD。第三,通过展示对于单层transformer,证明了实现上下文TD所需的transformer参数在多任务TD训练算法的不变集合的子集中,来弥合理论和实证结果之间的差距。
论文:https://arxiv.org/pdf/2405.13861
本文转载自公众号AIGC最前线
原文链接:https://mp.weixin.qq.com/s/voNZDTww7E5ec1hUwulztw