模型知识蒸馏新SOTA!告别传统散度蒸馏|腾讯优图&中科大出品

人工智能 新闻
研究引入了SinKD以解决现有蒸馏方法的局限性。此外,作者们提出了基于批次的重构方法,以捕捉高维空间中样本分布的几何复杂性。

用大模型“蒸馏”小模型,有新招了!

甚至能在不同类型和架构的LLMs(大语言模型)上达到新SOTA。

这就是来自中科大、腾讯优图实验室提出的一种基于Sinkhorn距离的知识蒸馏方法,能把大的、复杂的教师模型的知识“蒸馏”到小的、简单的学生模型中,从而让小模型也能像大模型一样工作。

图片

之所以提出新方法,主要是现有的知识蒸馏(KD)方法都有各自的局限性:

当两个模型的输出差异较大时,它们就不太管用了

  • KL散度:会导致学生模型的输出变得过于平滑,失去了区分性;
  • RKL散度:会让学生的输出变得太简单,不能很好地模仿教师模型;
  • JS散度:会让学生模型低估稀有事件的概率;

而基于Sinkhorn距离的新方法能更准确地衡量和缩小教师模型和学生模型之间的差异,从而提高了学生模型的性能。

此外,研究还提出了一种基于批量的重构方法,从而在高维空间中捕捉跨样本分布的几何复杂性。

最终,通过在两个流行的自然语言处理测试集(GLUE和SuperGLUE)上测试,新方法在编码器、编码器-解码器以及解码器等不同架构的所有类型LLMs上均优于当前的最先进方法。

研究背景

知识蒸馏的提出是为了通过对齐教师模型的软目标(例如输出logits和中间层表示)来将教师模型内在固有的知识传递给学生模型。

给定训练集中的一个样本x_i及其真实标签𝐲𝑖 ∈ ℝ𝑑,来自教师模型𝑓𝐓和学生模型𝑓𝐒的输出logits 𝐭𝑖 ∈ ℝ𝑑和𝐒𝑖 ∈ ℝ𝑑可以由以下式子得到:

图片

其中为softmax函数, τ是温度参数, d是输出logits的维度。基于logit的知识蒸馏的目标是σΤ最小化测量散度J(𝐭𝑖,𝐒𝑖)以实现知识传递。

研究动机

现有研究已经尝试使用Kullback-Leibler(KL)散度、反Kullback-Leibler(RKL)散度和Jensen-Shannon(JS)散度。

所有这些度量都可以被视为f-散度度量的变体,而f-散度度量在量化缺乏实质性交集的任何两个分布时都存在明显局限性。

此外,每种度量都有其自身的缺陷:

KL蒸馏会导致模式平均,使学生学习到一个过于平滑的分布,涵盖了教师的整个支撑集;

RKL会引起模式塌陷,学生仅关注教师分布中高概率的显著区域,而忽视了其余部分;

JS蒸馏会产生模式低估,由于惩罚不足,学生会低估稀有事件的概率。

图片

为了解决传统散度度量的问题,研究做出了以下贡献:

  1. 提出了一种知识蒸馏方法SinKD,采用Sinkhorn距离作为散度度量。它不仅解决了KL、RKL和JS散度在极端场景下的局限性,而且避免了计算Wasserstein距离的负担。
  2. 深入探讨了Sinkhorn距离的性质,并将SinKD重新reformulated为batch-wise OT,扩展了它在NLP任务中的适用性。
  3. 通过大量的可比性、有效性和泛化性实验证明了SinKD相较于目前最先进的方法的优越性。并为实际应用提供了使用SinKD进行蒸馏的实用指导方针。

图片

传统散度度量的缺陷

图片

首先,KL散度是不对称的,表现为JKL(𝐭𝑖,𝐒𝑖)≠ JKL(𝐒𝑖,𝐭𝑖),这一性质违反了距离度量的对称性特性,从而引入了一些不一致性。

其次,由于使用KL损失进行优化,学生模型试图对教师模型的多模态分布进行平均化,从而导致对这些模式的拟合不足。这被称为“模式平均问题”(mode-averaging problem)。

因此,学生模型无法捕获数据中的所有关键模式,最终影响模型性能。

第三,KL散度对应的是一个非平滑函数,这为优化过程带来了挑战。

图片

与KL散度一样,具有内在的不对称性,从而导致在捕捉分布差异时出现不一致性。

此外,优化的学生模型倾向于仅关注教师分布中概率较高的事件,这被称为“模式崩塌问题”(mode-collapsing)。

如果教师对某个事件赋予零概率,学生模型也被迫做出相同的预测。

图片

其中m𝑖 = 1/2(𝐭𝑖+𝐒𝑖)受制于非平滑性,JS损失在优化过程中面临挑战。

另外,由于JS损失在低概率区域的匹配上惩罚不足,学生模型可能会过度低估稀有事件的概率。

对于分布之间重叠较少甚至完全不重叠的情况退化为常数时,还存在梯度消失的风险。

最优传输距离的优势

Wasserstein距离通过求解两个分布之间的最优传输计划来量化它们的差异。

直观地看,它可以被认为是将一个分布(即学生的logits分布)转换为另一个分布(即教师的logits分布)所需的最小“代价”,其中“代价”可以定义为被移动的质量与移动距离的乘积。

与传统的散度度量相比,Wasserstein距离作为蒸馏的成本函数更为合理,因为它不依赖于对被测量分布的隐式假设。此外,它几乎处处可微,从而便于优化。

另外,现有的散度度量只能独立处理每个样本对,进行逐一logit的匹配,对于一批样本,这些方法无法定位来自同一样本的教师和学生的logits对,从而无法实现整体距离的最小化。

由于计算Sinkhorn距离的过程可以实现来自同一样本的两个输出之间的精确逐元素匹配,研究提出了“批量化”的SinKD方法(batchified SinKD)。

通过这种方式,即使通过低维观测,也能够捕捉复杂且隐式分布的几何结构。

方法介绍

这里简要介绍SinKD的核心方法,详细推导过程可以参阅原论文。

批量重构的Sinkhorn距离

对于本问题,Wasserstein距离的定义如下:

图片

其中,

图片

Wasserstein距离本身在解析计算上存在困难,其计算成本对于蒸馏大型语言模型来说高得难以承受。

在这种情况下,研究使用Sinkhorn距离作为一种高效的近似方法。它不仅保留了Wasserstein距离的所有优点,同时也大大缓解了其在在线蒸馏中所面临的成本问题。

Sinkhorn距离的定义如下:

图片

逐样本蒸馏将每个实例独立处理,但忽略了一个批次样本中的整体趋势。

研究摒弃了仅在每对教师-学生样本对上工作的逐样本知识蒸馏方法,转而在教师和学生样本组上执行知识蒸馏。

一个包含b个样本的批次会整体参与散度度量。通过批量重构,这种方法有效地增加了“观测”空间的维度,特别是在d远小于b的情况下表现尤为显著。

对于常规分类任务的蒸馏,研究使用如下“batchified”代价函数:

图片

并初始化如下候选传输矩阵:

图片

通过重构和化简,研究可以使用如下迭代式计算最优传输矩阵(具体推导过程参见论文)

图片

由此,可以算出最优传输距离

图片

SinKD的变体

拓展到回归任务:对于回归任务,模型不会为每个选项生成概率,而是仅生成一个标量(d=1)。对于一个包含b个样本的批次,教师模型和学生模型的输出分别表示为𝐭 ∈ ℝbx1和𝐒 ∈ ℝbx1

为了计算教师和学生之间的批量化Sinkhorn距离,成本矩阵的元素由“批量化”回归输出之间的绝对差值确定:

图片

拓展到独热标签微调:SinKD方法也适用于仅有独热(one-hot)标签且无法获取教师模型logits的模型微调。

在这种情况下,可以将单热标签视为“假想”的单热教师模型的logits。由于单热logits中以零为主,传统的散度度量(例如KL散度)在处理这种极端情况下的散度量化时显得无能为力。

实验与分析

(1)数值结果。与基线和SOTA方法对比,论文方法在大部分任务上均取得了更好的性能。

图片

(2)消融实验。得出的结论如下:

  • Sinkhorn损失在所有损失中对学生模型的收益最大
  • 批量化的SinKD优于逐样本的SinKD
  • SinKD超越了基于f-散度变体的蒸馏方法

图片
图片
图片

(3)生成式大语言模型实验。SinKD可以推广到生成式大语言模型,并在基于类GPT架构的模型的蒸馏上取得不俗的成绩表现。

但同时研究也观察到,蒸馏效果的影响会随着PROMPT模板的变化而改变。

这意味着,同样的任务设置下,更加合理的PROMPT设计能够更充分地利用教师模型的固有知识。

图片

(4)可视化结果如下。

图片

为了增强内在评估,研究还进行了以下附加分析:

  • 隐藏状态的表示
  • 注意力机制的模式
  • 层级性能分析

(5)拓展到独热标签微调。与现有的散度度量方法(例如KL散度)不同,SinKD方法还可以扩展用于使用独热标签 (one-hot label) 微调语言模型。

图片

(6)拓展到计算机视觉领域深度网络。SinKD在所有测试的配置中均稳定地超越了所有基线方法。

图片

总结

研究引入了SinKD以解决现有蒸馏方法的局限性。此外,作者们提出了基于批次的重构方法,以捕捉高维空间中样本分布的几何复杂性。最后,研究在各类任务、数据集和模型架构上进一步验证SinKD的有效性。

更多细节欢迎查阅原论文。

COLING 2024会议论文:https://arxiv.org/abs/2402.17110
IEEE TNNLS期刊论文:https://hal.science/hal-04803835

责任编辑:张燕妮 来源: 量子位
相关推荐

2024-12-02 10:40:00

AI模型

2024-06-17 07:10:00

2022-04-08 14:40:59

框架训练模型

2017-03-23 17:09:45

2024-07-19 08:00:00

深度学习知识蒸馏

2022-06-02 10:29:23

神经网络AI计算机

2024-06-26 14:50:52

2022-12-19 15:16:46

机器学习模型

2022-11-22 10:07:32

研究模型

2024-11-15 10:00:00

2009-11-11 10:09:47

Linux LiveLinux操作系统

2013-06-19 11:32:32

计算性能ISCHPC

2024-08-23 09:20:00

AI语言模型

2024-04-07 09:00:00

数据模型

2024-01-25 10:19:10

2022-12-09 10:19:29

汽车行业数字化转型

2024-09-29 10:40:00

数据模型

2023-12-04 13:23:00

数据训练

2023-09-01 14:49:09

AI微软

2024-12-02 01:10:04

神经网络自然语言DNN
点赞
收藏

51CTO技术栈公众号