再掀强化学习变革!DeepMind提出「算法蒸馏」:可探索的预训练强化学习Transformer

人工智能 新闻
该怎么把预训练Transformer范式用到强化学习里?

在当下的序列建模任务上,Transformer可谓是最强大的神经网络架构,并且经过预训练的Transformer模型可以将prompt作为条件或上下文学习(in-context learning)适应不同的下游任务。

大型预训练Transformer模型的泛化能力已经在多个领域得到验证,如文本补全、语言理解、图像生成等等。

图片

从去年开始,已经有相关工作证明,通过将离线强化学习(offline RL)视为一个序列预测问题,那么模型就可以从离线数据中学习策略

但目前的方法要么是从不包含学习的数据中学习策略(如通过蒸馏固定的专家策略),要么是从包含学习的数据(如智能体的重放缓冲区)中学习,但由于其context太小,以至于无法捕捉到策略提升。

图片

DeepMind的研究人员通过观察发现,原则上强化学习算法训练中学习的顺序性(sequential nature)可以将强化学习过程本身建模为一个「因果序列预测问题」

具体来说,如果一个Transformer的上下文足够长到可以包含由于学习更新而产生的策略改进,那它应该不仅能够表示一个固定的策略,而且能够通过关注之前episodes的状态、行动和奖励表示为一个策略提升算子(policy improvement operator)。

这也提供了一种技术上的可行性,即任何RL算法都可以通过模仿学习蒸馏成一个足够强大的序列模型,并将其转化为一个in-context RL算法。

基于此,DeepMind提出了算法蒸馏(Algorithm Distillation, AD) ,通过建立因果序列模型将强化学习算法提取到神经网络中。

图片

论文链接:​https://arxiv.org/pdf/2210.14215.pdf​

算法蒸馏将学习强化学习视为一个跨episode的序列预测问题,通过源RL算法生成一个学习历史数据集,然后根据学习历史作为上下文,通过自回归预测行为来训练因果Transformer。

与蒸馏后学习(post-learning)或专家序列的序列策略预测结构不同,AD能够在不更新其网络参数的情况下完全在上下文中改进其策略。

  • Transfomer收集自己的数据,并在新任务上最大化奖励;
  • 无需prompting或微调;
  • 在权重冻结的情况下,Transformer可探索、利用和最大化上下文的返回(return)!诸如Gato类的专家蒸馏(Expert Distillation)方法无法探索,也无法最大化返回。

实验结果证明了AD可以在稀疏奖励、组合任务结构和基于像素观察的各种环境中进行强化学习,并且AD学习的数据效率(data-efficient)比生成源数据的RL算法更高。

AD也是第一个通过对具有模仿损失(imitation loss)的离线数据进行序列建模来展示in-context强化学习的方法。

算法蒸馏

2021年,有研究人员首先发现Transformer可以通过模仿学习从离线RL数据中学习单任务策略,随后又被扩展为可以在同域和跨域设置中提取多任务策略。

这些工作为提取通用的多任务策略提出了一个很有前景的范式:首先收集大量不同的环境互动数据集,然后通过序列建模从数据中提取一个策略。

把通过模仿学习从离线RL数据中学习策略的方法也称之为离线策略蒸馏,或者简称为策略蒸馏(Policy Distillation, PD)

尽管PD的思路非常简单,并且十分易于扩展,但PD有一个重大的缺陷:生成的策略并没有从与环境的额外互动中得到提升。

例如,MultiGame Decision Transformer(MGDT)学习了一个可以玩大量Atari游戏的返回条件策略,而Gato通过上下文推断任务,学习了一个在不同环境中解决任务的策略,但这两种方法都不能通过试错来改进其策略。

MGDT通过微调模型的权重使变压器适应新的任务,而Gato则需要专家的示范提示才能适应新的任务。

简而言之,Policy Distillation方法学习政策而非强化学习算法。

研究人员假设Policy Distillation不能通过试错来改进的原因是,它在没有显示学习进展的数据上进行训练。

算法蒸馏(AD)通过优化一个RL算法的学习历史上的因果序列预测损失来学习内涵式策略改进算子的方法。

图片

AD包括两个组成部分

1、通过保存一个RL算法在许多单独任务上的训练历史,生成一个大型的多任务数据集;

2、将Transformer使用前面的学习历史作为其背景对行动进行因果建模。

由于策略在源RL算法的整个训练过程中不断改进,AD必须得学习如何改进算子,才能准确模拟训练历史中任何给定点的行动。

最重要的是,Transformer的上下文大小必须足够大(即跨周期),以捕捉训练数据的改进。

图片

在实验部分,为了探索AD在in-context RL能力上的优势,研究人员把重点放在预训练后不能通过zero-shot 泛化解决的环境上,即要求每个环境支持多种任务,且模型无法轻易地从观察中推断出任务的解决方案。同时episodes需要足够短以便可以训练跨episode的因果Transformer。

图片

在四个环境Adversarial Bandit、Dark Room、Dark Key-to-Door、DMLab Watermaze的实验结果中可以看到,通过模仿基于梯度的RL算法,使用具有足够大上下文的因果Transformer,AD可以完全在上下文中强化学习新任务。

图片

AD能够进行in-context中的探索、时间上的信用分配和泛化,AD学习的算法比产生Transformer训练的源数据的算法更有数据效率。

PPT讲解

为了方便论文理解,论文的一作Michael Laskin在推特上发表了一份ppt讲解。

图片

算法蒸馏的实验表明,Transformer可以通过试错自主改善模型,并且不用更新权重,无需提示、也无需微调。单个Transformer可以收集自己的数据,并在新任务上将奖励最大化。

尽管目前已经有很多成功的模型展示了Transformer如何在上下文中学习,但Transformer还没有被证明可以在上下文中强化学习。

为了适应新的任务,开发者要么需要手动指定一个提示,要么需要调整模型。

如果Transformer可以适应强化学习,做到开箱即用岂不美哉?

但Decision Transformers或者Gato只能从离线数据中学习策略,无法通过反复实验自动改进。

图片

使用算法蒸馏(AD)的预训练方法生成的Transformer可以在上下文中强化学习。

图片

首先训练一个强化学习算法的多个副本来解决不同的任务和保存学习历史。

图片

一旦收集完学习历史的数据集,就可以训练一个Transformer来预测之前的学习历史的行动。

由于策略在历史上有所改进,因此准确地预测行动将会迫使Transformer对策略提升进行建模。

图片

整个过程就是这么简单,Transformer只是通过模仿动作来训练,没有像常见的强化学习模型所用的Q值,没有长的操作-动作-奖励序列,也没有像 DTs 那样的返回条件。

在上下文中,强化学习没有额外开销,然后通过观察 AD 是否能最大化新任务的奖励来评估模型。

Transformer探索、利用、并最大化返回在上下文时,它的权重是冻结的!

另一方面,专家蒸馏(最类似于Gato)不能探索,也不能最大化回报。

图片

AD 可以提取任何 RL 算法,研究人员尝试了 UCB、DQNA2C,一个有趣的发现是,在上下文 RL 算法学习中,AD更有数据效率。

图片

用户还可以输入prompt和次优的demo,模型会自动进行策略提升,直到获得最优解!

而专家蒸馏ED只能维持次优的demo表现。

图片

只有当Transformer的上下文足够长,跨越多个episode时,上下文RL才会出现。

AD需要一个足够长的历史,以进行有效的模型改进和identify任务。

图片

通过实验,研究人员得出以下结论:

  • Transformer可以在上下文中进行 RL
  • 带 AD 的上下文 RL 算法比基于梯度的源 RL 算法更有效
  • AD提升了次优策略
  • in-context强化学习产生于长上下文的模仿学习
责任编辑:张燕妮 来源: 新智元
相关推荐

2022-10-08 09:53:17

AI算法

2021-09-10 16:31:56

人工智能机器学习技术

2024-12-09 08:45:00

模型AI

2023-03-09 08:00:00

强化学习机器学习围棋

2020-08-10 06:36:21

强化学习代码深度学习

2023-06-25 11:30:47

可视化

2022-10-28 15:08:30

DeepMind数据

2023-11-07 07:13:31

推荐系统多任务学习

2021-09-17 15:54:41

深度学习机器学习人工智能

2022-11-03 14:13:52

强化学习方法

2020-11-12 19:31:41

强化学习人工智能机器学习

2024-10-12 17:14:12

2017-03-28 10:15:07

2023-09-21 10:29:01

AI模型

2020-02-21 15:33:44

人工智能机器学习技术

2024-04-03 07:56:50

推荐系统多任务推荐

2023-07-20 15:18:42

2020-06-05 08:09:01

Python强化学习框架

2023-01-24 17:03:13

强化学习算法机器人人工智能

2020-12-02 13:24:07

强化学习算法
点赞
收藏

51CTO技术栈公众号