StaR | 用少量推理数据让模型学会通用推理能力,显著提升模型复杂推理
一、概述
•Title:STaR: Bootstrapping Reasoning With Reasoning
•URL: https://arxiv.org/abs/2203.14465
•Authors:Eric Zelikman, Yuhuai Wu, Jesse Mu, Noah D. Goodman
•Code: https://github.com/ezelikman/STaR
1 Motivation
•Step-by-step推理步骤生成可以提升语言模型在复杂推理任务(如数学或常识问答)上的性能,但是当前要让LLM能生成rationale推理过程,要么需要构建庞大的推理数据集,要么在只使用少量示例(但推理时牺牲了准确性)。
•需要一种方法来利用少量的推理示例和大量未经过推理的数据来提升模型的推理能力。
2 Methods
1 省流版总结:
- 使用少量推理示例(few-shot)引导语言模型生成多个问题的推理Rational过程。
- 对于模型生成的错误答案,通过提供正确答案(Hint)来生成新的推理过程(称为“rationalization”)。
- 在所有最终生成正确答案的推理上微调模型(Finetune)。
- 重复上述过程,直到performance不再提升(注意每次都使用original的预模型进行continually training来避免overfitting)。
2 专业版总结:
本文提出了一种名为“Self-Taught Reasoner”(STaR)的方法来解决语言模型在复杂推理任务上性能提升的问题。**STaR方法的核心思想是通过迭代地利用少量推理示例(rationales)和大量无推理数据集,逐步引导模型提升进行复杂推理的能力。**具体来说,STaR方法包括以下几个步骤:
- Rationale Generation Bootstrapping:首先,使用少量带有推理过程的示例作为提示,引导预训练的大型语言模型(LLM)生成多个问题的推理过程。这个过程被称为“rationale generation”。
- Filtering and Finetuning:接着,只保留那些生成了正确答案的推理过程,并在这些数据上对模型进行微调(finetune)。这一步骤的目的是强化模型生成高质量推理过程的能力。
- Rationalization:对于模型未能正确回答的问题,STaR采用一种称为“rationalization”的技术。在这个阶段,模型被提供正确答案作为提示,然后生成一个合理的推理过程来解释这个答案。这样做可以让模型从错误中学习,并改进其推理策略。
- Iterative Improvement:重复上述过程,每次都使用上一轮微调后的模型来生成新的训练数据。通过这种方式,模型逐渐学习如何更好地生成推理过程,并解决越来越复杂的问题。
- 5.Performance Evaluation:在每次迭代后,评估模型在测试集上的性能,直到性能达到饱和或不再显著提升。
3 Rationalization指的是什么?
Q1:为什么要用Rationalization?
• 直接让LLM生成推理思考过程,这些思考过程有些是对的,有些是错的,直接拿正确的思考过程,来训练llm生成rational,由于没有增量信息,会导致模型不能从failed example中学习,这样就不能让模型具备对new problems进行推理的能力。
Q2: 如何生成Rational
• 如下图所示,直接让LLM生成推理过程,对于failed的例子,加上label作为hint,基于hint,可以生成正确的推理过程。
3 Conclusion
• STaR显著提升了在多个数据集上的性能,相对于直接预测最终答案的模型,其效果更加突出。
• 在CommonsenseQA数据集上的表现与微调一个大30倍的最先进语言模型相当。
• STaR使得模型能够通过学习自身生成的推理步骤逐步提升推理能力。
二、详细内容
1 实验设计
数据集:
- 算术问题:使用随机生成的加法问题来测试STaR在处理数字运算任务上的性能。
- 常识问答(CommonsenseQA):使用CommonsenseQA(CQA)数据集,这是一个多项选择的常识推理任务,测试STaR在自然语言推理上的能力。
- 小学数学(Grade School Math, GSM8K):使用GSM8K数据集,包含小学水平的数学问题,这些问题以自然语言的形式表述,需要进行多步计算来得出答案。
Baseline:模型采用的是6B的开源模型(GPT-J),其checkpoint和fine-tuning code都开源了。
2 Rationalization能快速提升accuracy(从失败中学习能快速成长!!!)
说明;rationalization指的就是对于failed的example,加上hint,生成正确的推理过程数据并用于训练。
结论:随着STaR算法迭代次数的增加,模型在算术任务上的准确率逐渐提高。特别是在使用rationalization的情况下,准确率提升更加块。
3 STaR + rationalization比直接FT和few-shot效果好很多
• CQA数据集
• GSM8K数据集
说明:
• Direct Finetuned:不输出中间推理过程
• STaR without rationalization:不从失败样例中学习(以label作为hint生成推理过程用于ft)
• STaR with rationalization:从失败中学习
结论1:生成中间推理过程能显著提升最终的精度,例如就算使用100%的数据,不加推理过程,精度只能到60%,加上后用更少的数据却能更高的精度(大于68%)。
结论2:rationalization从失败中学习能进一步提升精度。
三、总结
STaR方法的关键在于,它允许模型通过自我生成的推理过程来自我改进,而不需要人工标注大量的推理数据集。此外,**通过rationalization技术,STaR能够确保模型从其错误中学习,从而提高整体的推理能力。**论文的实验结果表明,STaR在多个数据集上的性能显著优于直接预测答案的模型,并且与使用30倍更大模型的微调性能相当。
本文转载自NLP PaperWeekly,作者: NLP PaperWeekly