![](https://s5-media.51cto.com/aigc/pc/static/noavatar.gif)
新的训练范式可以防止机器学习模型学习虚假相关性 原创
记忆感知训练(MAT)这种训练范式通过修改模型逻辑以防止机器学习模型学习虚假相关性,提高了泛化能力,缩小了平均准确率(AVG)和最差群组准确率(WGA)之间的差距。
机器学习领域长期存在的问题之一是错误相关性的记忆。例如:假设开发人员正在开发一个深度神经网络对陆地鸟类和海洋鸟类的图像进行分类。他们使用数千张标记过的图像训练模型,这个模型在训练集和测试集上的表现都非常出色。然而,当向模型展示一张在兽医那里接受治疗的受伤海鸟的图片时,却错误地将其归类为陆地鸟类。
开发人员最初利用海鸟在海面飞翔的图像对模型进行训练,这导致了一个意外的结果:模型并没有有效学习到海鸟的独特特征,反而专注于识别图像中是否存在大片水域。因此,当向模型展示这张受伤海鸟的图片时,其模型错误地将其归类为陆地鸟类。
这是机器学习模型学习特征和标签之间的虚假相关性的一个例子。机器学习模型具有“惰性”,通常会选择捷径以实现目标。在这个例子中,模型在其训练数据中记住了错误的特征——海鸟图片中的水域。
记忆虚假相关性的主要缺点是缺乏泛化能力。该模型可能会给人一种进步的假象,但在现实情况下可能无法很好地发挥作用。鸟类分类就是一个很好的例子。但是,当机器学习模型用于医疗保健或自动驾驶汽车等关键应用时,虚假相关性可能会造成危害。
如何检测机器学习模型是否学习了虚假相关性?蒙特利尔大学和Meta公司的研究人员日前发表的一篇新论文探讨了机器学习模型中记忆的动态,以及它是如何导致学习虚假相关性的。他们还提出了一种名为“记忆感知训练”(MAT)的新范式,可以帮助防止机器学习模型在训练过程中学习虚假相关性。
ERM的问题
训练神经网络的标准方法是经验风险最小化(ERM),这是一种学习算法旨在致力减少模型在训练数据集上的损失。用于机器学习和深度学习的随机梯度下降(SGD)算法是一种解决ERM的优化算法。
ERM面临的一个问题是,它可能会促使模型快速捕捉虚假相关性,而不是深入理解并学习问题潜在分布的真正模式。当虚假相关性非常显著时(例如,海鸟示例中的大片水域),模型往往会提前停止学习,错过进一步挖掘真正有用模式(如图像中的鸟类特征)的机会。这会降低泛化能力,因为在实际情况下,虚假的特征可能并不存在,而有用的特征始终存在(例如,远离水域的海鸟)。
如果一个模型有足够的参数,它甚至会记住仅特定于单个数据点的独特特征,而这些特征并不适用于其他示例。这些特征与真正能够预测目标变量的核心属性无关。
图1 ERM会导致机器学习模型记忆虚假的特征,无法推广到少数示例
为了验证模型是否学习了虚假相关性,必须在包含少数示例的保留样本上进行评估,这些示例不符合神经网络从大多数训练数据中学习的简单解释。例如对奶牛和骆驼的图像进行分类的一个模型,如果训练集中的大多数奶牛出现在草地上,大多数骆驼出现在沙地上,那么沙地上的奶牛或草地上的骆驼就是少数示例。
记忆感知训练(MAT)
虽然给出的例子可以帮助发现记忆虚假相关性的迹象,但该论文提出了一种使用少数示例来指导模型学习可推广模式的方法。
这种方法称为记忆感知训练(MAT),通过使用预测来修改模型的逻辑——神经网络在转换为概率之前输出的原始预测。
图2 记忆感知训练(MAT)防止机器学习模型学习虚假相关性,并迫使其对少数示例进行泛化
具体来说,MAT通过引入基于“校准保留概率”的每个示例的逻辑移位来修改ERM目标。这里的校准保留概率旨在通过一种机制,增加那些预测错误且保留概率较高示例的损失,同时降低那些预测正确且保留概率也较高的示例的损失,从而调整训练重点。通过将这些概率添加到损失函数中,训练算法可以防止模型记忆虚假相关性,并优先学习数量较少或难以分类的示例,这些示例的泛化能力通常较差。
为了计算保留概率,MAT使用了一个通过交叉风险最小化(XRM)训练的辅助模型。XRM是一种训练技术,旨在通过在两个网络上对训练数据的随机一半进行训练来发现数据集内的不同环境。关键思想是鼓励每个网络学习一个有偏见的分类器,然后使用一个模型对另一个模型的数据所犯的错误(交叉错误)来注释训练和验证示例。
为了跟踪MAT的有效性,可以比较训练模型的平均准确率和最差群组准确率(WGA)之间的差异 (WGA衡量模型在表现最差的子组上的准确率。这是评估模型稳健性的关键指标,特别是在处理虚假关联性和不平衡数据集时)。
图3 通过缩小平均准确率(AVG)和最差群组准确率(WGA)之间的差距,MAT具有更好的泛化能力
在传统的训练方法中,AVG与WGA之间的差距可能很大。而在MTA中,这一差距减小了(尽管以损失一小部分平均准确率为代价),从而更真实地反映了模型的性能。
尽管大型语言模型(LLM)等领域的发展备受业界瞩目,但机器学习基础领域的持续探索令人耳目一新。MAT等技术对于现实世界的机器学习应用至关重要,因为开发人员希望机器学习模型在这些应用中能够应对各种复杂多变的场景。
原文标题:New training paradigm prevents machine learning models from learning spurious correlations,作者:Ben Dickson
![](https://s5-media.51cto.com/aigc/pc/static/noavatar.gif)