上海交大、复旦、上海 AI Lab引入渐进学习框架来验证弱到强的推理

发布于 2024-7-22 10:21
浏览
0收藏

​一、结论写在前面

论文标题:Weak-to-Strong Reasoning

论文链接:​​https://arxiv.org/pdf/2407.13647​

代码等:​​https://github.com/GAIR-NLP/weak-to-strong-reasoning​

当大型语言模型 (LLMs) 超越人类水平能力时,为这些模型提供全面且准确的监督变得愈发困难。弱到强学习,即利用能力较弱的模型来解锁更强大模型的潜在能力,在此背景下被证明是有价值的。然而,这种方法在复杂推理任务中的有效性仍未得到验证。此外,在弱到强设置下解决推理任务目前缺乏有效方法来避免盲目模仿弱监督者及其错误。

本文探讨了弱到强框架在复杂推理任务中的效能。论文引入了一种新方法,该方法利用弱监督激发强大能力,无需依赖人类标注或更高级模型的注释。该方法侧重于强模型自主精炼其训练数据的能力,即使它之前未曾学习过该任务。通过迭代扩展其学习范围,强模型不断拓宽其推理技能。这种自我导向的数据治理对于扩大AI推理能力提升的规模至关重要,使模型在其发展轨迹中更加独立和高效。

论文使用Llama2-70b作为强模型,测试了三个独立的弱模型:Llama2-7b、Gemma-2b和Mistral-7b,并在常用的数学推理数据集GSM8K和MATH上进行实验。实验结果显示:

1.完全弱微调虽然在分类任务中有效,但在复杂推理任务中表现不佳。

2.论文提出的方法显著优于完全弱微调方法,在第一阶段训练(M → Mplus)后,仅由弱模型(即Gemma-2b)监督时,在GSM8K上实现了26.99点的改进,并通过偏好优化(Mplus → Mpro)进一步提高了8.49点的性能,而无需知道金标准答案。    

3.论文提出的偏好优化阶段使强模型能够从弱监督者的错误中学习,最终在具有挑战性的场景(如4-5级MATH问题)中超越了在金标准解决方案上微调的强模型(即强上限)。

为更准确地模拟未来场景,论文在OlympicArena上进行了额外的实验,这是一个极具挑战性的数据集,没有明确的标准答案。尽管规模较小,但Llama3-8binstruct(AI@Meta,2024)已经经过对齐,并被证明可以有效地监督更大的Llama3-70b,后者的潜力尚未被充分发挥。此外,论文提出的两阶段训练方法比完全弱微调高出3.19点。

上海交大、复旦、上海 AI Lab引入渐进学习框架来验证弱到强的推理-AI.x社区

图1:( a ):使用 Llama2-7b 监督 Llama2-70b 在 GSM8K 上的测试准确率。(b):使用 Llama3-8b-instruct 监督 Llama3-70b 在 OlympicArena 上的测试准确率。"弱基础" 指的是弱模型的结果。"全弱微调" 指的是基线结果,其中强模型在弱模型生成的完整数据集上进行简单微调。"论文的阶段I" 表示使用论文提出的弱到强方法进行监督微调的第一阶段结果。请注意,论文的方法在阶段I产生了三种增强的强模型变体,论文在这里展示最佳结果。"论文的阶段II" 表示使用论文的方法进行偏好优化的第二阶段结果

二、论文的简单介绍

2.1 论文的背景

"学生不必不如老师;老师不必比学生更聪明。" ——《On Teachers》

随着人工通用智能(AGI)研究的推进,创造超越人类认知能力的超智能系统一直是该领域的一个关键目标)。这一追求带来了一系列挑战,尤其是在这些高级AI模型的监督和学习范式方面。传统的监督方法通常依赖于人类监督或来自更高级模型的指导(即知识蒸馏,distilled knowledge)),但当AI的能力超越其监督者时,这些方法变得不足。

为解决这个问题,论文关注弱到强学习范式(weak-tostrong learning paradigm),该范式在一个独特的任务设置下运作,即只有一个能力较弱的模型和一个更强大但未充分利用的模型可用。弱到强学习的核心问题是,能力有限的模型是否能有效指导更先进、更强大模型的发展。Burns等人(2023)的先前研究已经证明了这种方法在分类、国际象棋和奖励建模任务中的可行性。然而,这种设置是否适用于更复杂的推理任务仍是一个开放性问题,这些任务需要的不仅仅是简单的外推或模式识别。

复杂推理是人类认知的一个关键方面,对于评估大语言模型是否能模仿或超越人类理解世界、做出决策和解决问题的能力至关重要。鉴于这些任务的复杂性和关键性,将弱到强学习框架应用于高级推理挑战是至关重要的,特别是在实现超智能的更广泛背景下。    

上海交大、复旦、上海 AI Lab引入渐进学习框架来验证弱到强的推理-AI.x社区

尽管Burns等人(2023)建议,在弱模型产生的全部噪声数据上简单地微调强模型(称为完全弱微调)可以持续提高其性能超过较弱的对应模型,但这种方法仍远未恢复强模型的全部能力,而且论文的实验表明,在面对更复杂的推理挑战时,它失去了效果。他们还提出了一种辅助置信度损失,以缓解强模型模仿其监督者错误的问题。然而,这种方法是为具有一组固定标签的分类任务量身定制的,不能自然地扩展到包括推理在内的开放式生成任务。目前,在弱到强推理框架内,除了简单的微调之外,缺乏有效的方法来防止过度拟合弱错误并进一步激发强模型的内在推理能力。    

为实现上述目标,论文引入了一个渐进式改进学习框架,遵循的原则是模型可以通过最初关注较小、更可靠的数据子集,然后逐步扩大学习范围来更有效地提高其能力,如图2所示:

•在第一阶段,论文假设利用可能更准确的较小数量的数据更有利。论文通过结合弱模型生成的数据和更高级模型通过上下文学习自生成的数据来实现这一点。然后将这种混合用于有选择地策划后续监督微调的数据集。

•在第二阶段,在开发出具有改进推理能力的强模型后,论文利用其构建对比样本进行偏好优化的能力,使模型能够有效地从较弱模型的错误中学习。

2.2 预备知识

2.2.1 大语言模型的典型学习范式

论文概述了大型模型训练中的常见学习范式,主要特征是需要标准答案和来自更强大模型的监督,如表1所示。

通用监督学习 在训练大语言模型时,理想情况是拥有足够数量的带有标准答案的训练数据,论文称之为通用监督学习范式。然而,获取这样的数据往往需要大量的标注工作,有时甚至是不可能的。因此,各种学习范式应运而生,以减少数据质量和数量的影响,同时仍能提高性能。

上海交大、复旦、上海 AI Lab引入渐进学习框架来验证弱到强的推理-AI.x社区

    

表1:LLMs的典型学习范式。“V”和“X”表示是否需要监督,“-”表示可选。“G.T:”代表真实答案

基于蒸馏的学习 在当前背景下,即使没有标准答案,要提升像Llama2-70b这样的强大模型,仍可以通过寻求像GPT-4这样更强大的模型的帮助来实现改进。因此,许多现有工作建议让一个更强大的模型充当教师模型,为目标模型提供具体反馈以改进。这种范式可以被视为蒸馏更强大教师模型的知识。然而,仅仅模仿教师模型并非长期解决方案;在模仿数据中未充分代表的任务上,模仿模型只能略微缩小与教师模型的性能差距。此外,蒸馏学习主要有益于那些不如教师模型能力强的模型。

自我改进学习 考虑到由人类或更强大的专有模型标注训练数据的高昂成本,一系列工作依赖于模型自身生成的正确响应来更新它。例如,Zelikman等人(2022)、Yuan等人(2023)、Singh等人(2023)、Hosseini等人(2024)根据最终答案的正确性筛选解决方案,而Lightman等人(2023)、Lin等人(2024)则使用在金标准标注上训练的奖励模型来评分自生成内容。显然,无论是使用二元标签还是细粒度反馈,这种范式仍然需要标准答案来评估模型自生成响应的可用性。没有标准答案,自我改进只能带来最小的性能提升,甚至可能降低性能。

半监督学习 从传统机器学习领域的半监督学习中获得启发,另一种大语言模型学习不依赖于大量标注,而是依赖于一个小型的高质量种子数据集。Tong等人(2024)通过学习自生成响应与专家标注响应之间的差异,展示了改进。论文还将当前流行的研究主题——易到难泛化纳入这一类别,其中模型通过学习人类对较简单任务的标注来解决复杂任务。这一系列研究不可避免地需要获取一小部分高质量的标准答案。

弱到强学习 在模型超越人类能力的场景中,为复杂任务提供全面和精确监督的挑战变得更加严峻,特别是在没有标准答案,也没有更高级模型提供监督指导的情况下。这种缺失凸显了弱到强学习方法的关键重要性。这些方法独特地利用较弱的监督信号来恢复已经强大的模型中的潜在知识。例如,用GPT-2级别的监督者对GPT-4进行微调,可以在某些任务上恢复接近GPT-3.5级别的性能。这一策略对推动人类社会进步具有深远意义,它使大语言模型具备解决当前无法解决的数学和物理挑战的能力。与其他学习范式不同,弱到强学习在相对宽松的条件下运作,为探索和创新开辟了广阔的机会。    

2.2.2 弱到强推理设置

论文在弱到强的设置下处理推理任务,如表2所示。首先,论文研究数学推理任务,如GSM8k和MATH中的任务。这些任务要求推理过程的每一步都展示基本的数学问题解决技能,包括问题理解和代数运算,并在前几步的基础上继续推进。这对模型的学习和泛化能力提出了更高的要求。与分类任务不同,模型可以依赖于表面模式的外推或识别,而推理任务几乎无法从猜测中获益。

然后,论文使用一个具有一定数学问题解决能力的弱模型(例如Llama2-7b),记为m。这个模型类似于超智能时代中具有有限专业知识的人类监督者。此外,论文只有一组没有标准答案的问题Q = {qi,目标是提高强模型M(例如Llama2-70b)的推理能力。

为了实现这一点,论文遵循Burns等人(2023)的方法,将原始训练集随机分成两个相等的部分,Dgold,1和Dgold,2。弱模型最初使用Dgold,1进行微调,其中有可用的标准解决方案,从而得到一个具有一定问题解决能力的弱模型,即m。相比之下,强模型只能访问来自Dgold,2的问题,没有推理链或最终答案,即Q。

2.3 方法论

在本节中,论文提出了一种弱到强的训练方法,旨在最大限度地利用弱数据并激发强模型的内在潜力。首先,在没有标准答案和外部信号的情况下,论文识别出潜在的正样本。在第一阶段,论文仅利用这部分数据进行监督式微调。然后,一旦强模型达到了一定的推理水平,论文就在第二阶段使用全部弱数据,特别是通过基于偏好学习的方法(如 DPO,)来利用潜在的负样本,鼓励强模型从弱模型的错误中学习。整个框架如图 3 所示。

2.3.1 阶段I:从“正样本”中学习

给定一个弱模型m 和一系列没有真实标签的数学问题Q,m 生成弱数据D_weak = {q_i, C_weak,i, a_weak,i },其中q_i ∈ Q,C_weak,i 表示推理链,a_weak,i 表示最终答案。a_weak,i 的正确性是未知的。核心挑战在于:论文如何最大化利用m 和D_weak 来充分增强和恢复一个更强模型M 的数学推理能力?    

2.3.1.1 全面弱数据微调

论文的初始策略是对更强模型M 在整个弱数据集Dweak 上进行微调。尽管先前研究(Burns et al., 2023)已验证了这种方法在文本分类任务中的有效性,但其在推理任务中的效果尚未探索。因此,论文着手研究弱到强泛化现象是否也能在此较少探讨的领域增强M 的推理能力。

2.3.1.2 弱上下文学习

另一种直接的方法是上下文学习(ICL, in-context learning),它仅需要几个训练样本作为提示中的演示。具体来说,论文从D_weak 中随机选择四个样本作为演示。由于论文无法访问真实标签,这些演示不能被证明是正确的。      

上海交大、复旦、上海 AI Lab引入渐进学习框架来验证弱到强的推理-AI.x社区

图3:论文的方法概览,从M 演进为Mplus 再到Mpro。左侧:论文利用最终答案一致性来有选择地从多样化的来源中过滤弱数据和ICL数据,这些数据用于微调强模型M 并获得具有增强数学推理能力的Mplus。右侧:论文利用Mplus 的置信度来识别对比样本以进行性能优化,从而得到更稳健的强模型Mpro。    

2.3.1.3 弱-ICL微调

鉴于模型可以通过监督微调模仿弱错误,论文建议在使用前对Dweak进行精炼,而不是盲目使用所有数据。此外,论文还寻求利用通过上下文学习激活的强模型的固有能力。基于这两个想法,论文引入了弱-icl微调,同时使用弱数据D_weak和"icl数据"D_icl = {q_i, c_icl,i, a_icl,i},其中qi ∈ Q,c_icl,i和a_icl,i是由M通过少样本示例生成的,作为更高质量的监督信号。需要注意的是,对于D_weak和D_icl,论文无法确定某个答案是否正确。

尽管如此,当两个采用不同数据表示的模型在开放式任务中得出相同答案时,这表明准确性的可能性更高。这种现象支持了在不同方法之间观察到一致性时结果的可靠性。因此,论文比较由弱模型和强模型分别生成的D_weak和D_icl,并在a_weak,i = a_icl,i时选择Dˆweak和Dˆicl用于后续的监督微调。论文称这种方法为最终答案一致性。考虑到这两组数据的组合,论文可以得到三个版本的增强微调强模型:

•M_weak-ft:在Dˆweak上微调的M。

•M_icl-ft:在Dˆicl上微调的M。

•M_hybrid-ft:在Dˆweak和Dˆicl的并集上微调的M。

迭代训练 仔细观察M_weak-ft和M_icl-ft,论文发现它们仍然满足具有不同数据表示的条件,因为它们是在来自不同来源的数据上训练的——Dˆweak由弱模型生成,而Dˆicl主要源自强模型本身。因此,论文可以进行迭代训练以提升性能。论文将初始轮次的监督微调数据表示为Dˆ1weak和Dˆ1icl,得到模型M1weak-ft、M1icl-ft和M1hybrid-ft。在第二次迭代中,论文将M1weak-ft应用于Q以构建D2weak,将M1icl-ft应用于构建D2icl。这里,下标"weak"和"icl"表示初始数据来源。然后论文应用最终答案一致性来获得Dˆ2weak和Dˆ2icl。经过另一轮监督微调后,论文得到:

•M2weak-ft:在Dˆ2weak上微调的M。

•M2icl-ft:在Dˆ2icl上微调的M。

•M2hybrid-ft:在Dˆ2weak和Dˆ2icl的并集上微调的M。    

上海交大、复旦、上海 AI Lab引入渐进学习框架来验证弱到强的推理-AI.x社区

需要注意的是,迭代训练步骤是可选的;当数据质量过低或模型过拟合时,可能会导致性能下降。

2.3.2 第三阶段:从“负面”样本中学习

论文将第一阶段的最终迭代模型表示为 Mplus,该模型已学习了双重数学解决方案,并具有进一步增强的潜力。接下来,论文应用偏好优化技术,战略性地利用由m 生成的原始弱数据集Dweak={q_i, c_weak, a_weak,i}中的潜在错误子集,使得强模型能够识别并避免在未来的推理过程中出现类似的错误。关键在于如何构建用于学习的对比样本。

在没有访问真实答案的情况下,当前具备增强推理能力的强大模型会根据其置信度识别最可能正确的答案。具体而言,对于每个问题q_i 属于 Q,论文从模型Mplus 中抽取n 个回答,并将这些回答中出现频率最高的答案的概率定义为置信度。当置信度低于指定阈值τ 时,论文认为模型对这一问题的判断不可靠,因此将其舍弃。相反,如果置信度不低于τ,论文则认为模型能够解答该问题,并继续构建对比样本,具体步骤如下:

上海交大、复旦、上海 AI Lab引入渐进学习框架来验证弱到强的推理-AI.x社区

进一步在样本上训练M_plus使其能够区分正确与错误的解决方案,从而得到一个更强的模型M_pro。

2.4 实验

2.4.1 数据集

GSM8K和 MATH是两个广泛使用的数学推理数据集,其中 MATH 包含更具挑战性的竞赛问题。论文使用的数据统计信息如表 4 所示。特别是,为了确保弱模型有足够的训练数据来培养初步的数学技能,论文通过 Chern 等人(2023)构建的数据增强了 GSM8K 训练集。

上海交大、复旦、上海 AI Lab引入渐进学习框架来验证弱到强的推理-AI.x社区

表 4:数据统计。Dg o l d, 1 和Dg o l d, 2 是训练集的子集。弱模型使用Dg o l d, 1 来培养初始数学技能,而强模型只能访问Dg o l d, 2 中的问题,没有正确答案    

上海交大、复旦、上海 AI Lab引入渐进学习框架来验证弱到强的推理-AI.x社区

图4:第一阶段的主要结果。第0^m 轮展示了两个基线的性能,其中“weak”表示完全弱微调,即在全部弱数据上进行简单微调,“icl”指的是不进行微调的弱ICL。连线表示模型共享相同的训练数据源。低于“强上限”的结果显示了通过贪婪解码的测试准确率,而高于“强上限”的结果显示了pass@k分数( k=10 和温度=1.0 )。为简洁起见,论文仅展示了通过贪婪解码超越的Mhybrid-tad 检查点的pass@k分数,完整结果在A.4.2 中提供

2.4.2实验设置

论文使用Llama2-70b作为强模型,并采用来自不同家族的三种弱模型:Llama2-7b、Gemma-2b和Mistral-7b。论文对弱模型在D_gold,1 上进行全参数微调,并一致采用LoRA对强模型进行微调。在第一阶段,论文根据迭代原则,在GSM8K上进行两轮迭代,在MATH上进行一轮迭代。在第二阶段,论文采用基于偏好学习的两种方法,DPO及其变体ORPO。

论文在测试集上评估准确性。弱模型m 的性能与通过Dgold,2 中的黄金解决方案数据微调的强模型M 的性能相结合,代表了强模型与弱模型结合的最佳性能。    

2.4.3 第一阶段结果

GSM8K和MATH数据集上第一阶段的主要结果如图4所示。值得注意的是,在MATH实验中,由于可用数据量较小,论文随机抽取了未根据最终答案一致性选择的数据。根据图4,论文有以下观察结果。

弱ICL微调显示出显著提升。使用论文提出的方法,仅由在GSM8K上准确率为25.17 的弱Gemma-2b监督的强模型性能,可以提升至60.12,超过简单全弱微调31.08,并且超过Mplus(即Mhybrid-ft^2)。随着弱模型的改进,这一结论在分类任务上得到了Liu和Alahi(2024)的验证。具体而言,GSM8K上的性能逐渐提升,从Gemma-2b到Llama-7,再到Mistral-7b(25.17 -> 33.81 -> 59.51)。因此,通过这些模型生成的数据训练的强模型的最大性能也逐步提升(60.12 -> 63.76 -> 68.39)。

Mhybrid-rt 实现了最高的 pass@k 分数。正如预期,Mhybrid-t 在弱到强设置中取得了最高的 pass@k 分数,这得益于其训练数据融合了两种类型的解决方案——一种来自弱模型,另一种来自强模型。这种多样性通过降低过拟合的可能性增强了模型的鲁棒性。此外,Mia-t 的表现通常优于 Mweak-ft,这可以归因于过程级精度的变化和可能的解决方案格式。

简单的微调不足以应对弱到强的推理任务。当使用 Gemma-2b 作为 MATH 数据集上的弱模型时,完全弱微调的表现不如弱基准(10.0 对比 11.6)。这表明,尽管简单的微调在分类、国际象棋和奖励建模任务中成功应用(Burns et al., 2023),但对于复杂的推理任务,尤其是像 MATH 中的高难度问题,这种方法显得力不从心。相比之下,论文的弱-icl 微调方法有效地弥合了这一差距,为弱到强推理挑战提供了一种有效且可扩展的解决方案。

ICL性能的影响 考虑到弱-icl微调的有效性部分取决于弱ICL的效果,论文进一步探讨了通过谨慎选择示例来增强ICL性能如何影响弱-icl微调的表现。图5展示了使用Gemma-2b作为弱模型,在不同示例集下GSM8K测试的准确率。结果表明,使用这组特定示例的弱ICL性能从原始的56.48提高到了64.06。

随后,论文在提示中使用这些示例重新生成Dicl,并在通过最终答案一致性精选的Dˆicl上微调强模型。这进一步将性能从64.06提升到64.75,证实了自主数据筛选的有效性。    

值得注意的是,尽管弱ICL具有高性能的潜力,但在弱到强的框架中选择有效的示例并非易事,这超出了本文的讨论范围。

上海交大、复旦、上海 AI Lab引入渐进学习框架来验证弱到强的推理-AI.x社区

2.4.4 第二阶段结果

论文采用Mhybrid-ft的最终迭代作为Mplus进行后续的偏好学习。实验结果验证了该检查点达到了更高的pass@k,并具有进一步提升的显著潜力。

如表5所示,论文构建正负样本的方法有效地增强了强模型的数学推理能力。在GSM8K上,DPO和ORPO使用论文构建的数据集都持续取得显著改进,特别是在由Gemma-2b监督时,增加了8.49个百分点。尽管MATH问题本质上具有挑战性,这影响了强模型的判断并在训练数据中引入了不准确性,但论文的方法通过ORPO仍然在MATH上取得了至少1个百分点的改进。    

上海交大、复旦、上海 AI Lab引入渐进学习框架来验证弱到强的推理-AI.x社区

数据构建方法 在构建偏好数据时,论文始终使用由弱模型生成的弱响应作为被选择/拒绝的响应之一,而不是完全依赖自生成的数据。论文还在GSM8K上使用Llama2-7b作为弱模型测试了自生成设置,其中被选择和被拒绝的响应都由强模型自身生成。在这种设置下,DPO测试准确率为62.40(-0.22),表明性能略有下降。在没有真实标签的情况下,构建的正负样本实际上分别对应于更频繁和较少出现的答案,并与模型倾向于选择的答案相关。由于偏好优化本质上执行排序,这种自生成设置的潜在收益是最小的。因此,在偏好数据构建过程中引入弱数据信号被证明是一种更好的方法。

2.4.5 分析

为进行进一步分析,论文检查了MATH测试集中不同难度级别的准确率。

如图6所示,强模型在较简单的问题上表现出更好的泛化能力。具体来说,尽管Llama2-7b在1级问题上只达到了6.98点的准确率,但Llama2-70b在使用这种弱监督训练后,可以在1级问题上达到超过30点的准确率。对于更具挑战性的问题(4-5级),经ORPO增强的Mpro甚至超过了仅通过金标准解决方案监督微调获得的强模型上限。这一现象验证了从不正确数据中学习的有效性。    

上海交大、复旦、上海 AI Lab引入渐进学习框架来验证弱到强的推理-AI.x社区

2.4.6 更接近未来场景的实验

在对Llama3-70b(AI@Meta,2024)的初步测试中,论文观察到在GSM8K和MATH上,Llama3-70b可以通过上下文学习在很大程度上释放其潜力,而参数更新由于训练不稳定性而产生边际甚至负面影响。因此,论文聚焦于Llama3-70b发布后开发的更具挑战性的数据集OlympicArena,以模拟更真实的未来场景。

论文仅考虑OlympicArena中的英语问题,排除了需要基于案例或专家评估的CODE(代码生成)和OT(其他)问题类型。这样得到了6,020个没有解决方案和最终答案的训练数据,以及313个有最终答案的测试数据,用于评估不同方法的性能。论文使用Llama3-8b-instruct(未在训练数据子集上进行初始微调)作为弱模型,Llama3-70b作为待改进的强模型。超参数与GSM8K中使用的一致。这种配置更接近未来真实世界的弱到强场景。    

上海交大、复旦、上海 AI Lab引入渐进学习框架来验证弱到强的推理-AI.x社区

实验结果如表6所示。"Weak Floor"代表Llama3-8b-instruct的零样本性能,"Full Weak FT"表示Llama3-70b在训练集上由Llama3-8b-instruct生成的全部(即6,020个)弱解决方案上监督微调后的性能,"Weak ICL"表示Llama3-70b在Llama3-8b-instruct生成的4-shot弱示例下的性能。尽管参数更多,但由于挖掘能力不足,Llama3-70b在上下文学习下的表现仍低于Llama3-8b-instruct的零样本性能。

通过论文提出的弱-icl微调方法获得的M1 weak-ft,以更少的训练数据(即746个)达到了比Full Weak FT更高的性能,超过了0.32个百分点。经过第二阶段的偏好优化,进一步利用弱模型和没有答案的训练问题,强模型的性能比Full Weak FT又提高了3.19个百分点。这证明了论文的方法在更接近未来条件的场景中的稳健性和泛化能力。

本文转载自 AI帝国​,作者: 无影寺

收藏
回复
举报
回复
相关推荐