被DeepSeek带火的知识蒸馏详解!

人工智能
知识蒸馏是一种模型压缩技术,通过训练一个小而高效的学生模型来模仿一个预训练的大且复杂的教师模型(或一组模型)的行为。这种训练设置通常被称为“教师-学生”模式,其中大型模型作为教师,小型模型作为学生。教师模型的知识通过最小化损失函数传递给学生模型,目标是匹配教师模型预测的类概率分布。

今天来详细了解DeepSeek中提到的知识蒸馏技术,主要内容来自三巨头之一Geoffrey Hinton的一篇经典工作:https://arxiv.org/pdf/1503.02531。

主要从背景、定义、原理、代码复现等几个方面来介绍:

1、背景介绍

训练与部署的不一致性

在机器学习和深度学习领域,训练模型和部署模型通常存在显著差异。训练阶段,为了追求最佳性能,我们通常会使用复杂的模型架构和大量的计算资源,从海量且高度冗余的数据集中提取有用信息。例如,一些最先进的模型可能包含数十亿甚至上百亿的参数,或者通过多个模型集成来进一步提升性能。然而,这些庞大的模型在实际部署时面临诸多问题:

  • 推断速度慢大模型在处理数据时需要更多的时间来完成计算,这在需要实时响应的场景中是不可接受的。
  • 资源要求高大模型需要大量的内存和显存来存储模型参数和中间计算结果,这使得它们难以部署在资源受限的设备上,如移动设备或嵌入式系统。

因此,在部署阶段,我们对模型的延迟和计算资源有着严格的限制。这就引出了模型压缩的需求——在尽量不损失性能的前提下,减少模型的参数量,使其更适合实际应用环境。

模型压缩与知识蒸馏

模型压缩用于解决训练阶段与部署阶段之间的不一致性,特别是在模型规模与实际应用需求之间的矛盾,在尽量不损失模型性能的前提下,减少模型的参数量和计算复杂度,使其更适合在资源受限的环境中部署。

知识蒸馏(Knowledge Distillation)是其中一种非常有效的模型压缩技术。

2、什么是知识蒸馏?

图片

知识蒸馏是一种模型压缩技术,通过训练一个小而高效的学生模型来模仿一个预训练的大且复杂的教师模型(或一组模型)的行为。这种训练设置通常被称为“教师-学生”模式,其中大型模型作为教师,小型模型作为学生。教师模型的知识通过最小化损失函数传递给学生模型,目标是匹配教师模型预测的类概率分布。

知识蒸馏的核心思想是将一个复杂且性能强大的“教师”模型的知识迁移到一个更小、更轻量的“学生”模型中。通过这种方式,学生模型可以在保持较小参数量的同时,尽可能地继承教师模型的性能。

该方法最早由Bucila等人在2006年提出(https://www.cs.cornell.edu/~caruana/compression.kdd06.pdf),并在2015年由Hinton等人推广,成为知识蒸馏领域的奠基之作。Hinton的工作引入了带温度参数的softmax函数,进一步增强了知识转移的有效性。

在 Geoffrey Hinton 等人的论文《Distilling the Knowledge in a Neural Network》中,作者通过昆虫的幼虫和成虫形态来比喻机器学习中的训练阶段和部署阶段。这种比喻强调了不同阶段对于模型的不同需求:就像昆虫的幼虫形态专注于从环境中吸收能量和养分,而其成虫形态则优化于移动和繁殖;同样地,在机器学习中,训练阶段需要处理大量数据以提取结构信息,而部署阶段则更关注于实时性和计算资源的有效利用。

教师与学生模型

教师模型:指的是那些庞大、复杂且可能计算成本高昂的模型或模型集合(ensemble)。这些模型虽然在训练阶段能够提供优秀的性能,但由于其复杂性,不适合直接用于实际部署。教师模型通常经过充分训练,可以很好地泛化到新数据上。

学生模型:相对较小且计算效率更高的模型,旨在模仿教师模型的表现同时保持低延迟和高效能,便于大规模部署。学生模型通过“知识蒸馏”过程从教师模型那里获得知识,从而能够在资源受限的环境下运行。

知识的理解

关于“知识”的理解存在一个常见的误解,即将其简单等同于模型的权重参数。然而,Hinton 等人提倡一种更为抽象的观点,即知识应被视为从输入向量到输出向量的学习映射关系。这不仅包括了对正确答案的概率预测,还涵盖了对不正确答案之间细微差异的理解。例如,在图像分类任务中,即使某个类别的概率非常小,它与其他错误类别相比仍然可能存在显著差异,这些差异反映了模型如何泛化的关键信息。

蒸馏技术

知识蒸馏是一种技术,它允许我们将教师模型中的知识转移到学生模型中。这不仅仅是复制权重的过程,而是涉及到使用教师模型生成的软目标(soft targets)作为指导,帮助学生模型学习到相似的泛化能力。通过调整 softmax 层的温度参数,可以使这些软目标更加平滑,从而让学生模型能够捕捉到更多有用的信息,并减少过拟合的风险。

3、知识蒸馏的工作原理

Soft Target vs Hard Target

Hard Targets是最常见的训练标签形式,通常指的是每个训练样本对应的一个确切的类别标签。例如,在分类问题中,如果任务是对图像进行数字识别(如MNIST数据集),那么hard target就是一个具体的数字(比如1)。在这种情况下,模型被训练来最大化正确类别的概率,而其他所有类别的概率则应尽可能地小。这通常通过最小化交叉熵损失函数来实现,该函数惩罚了模型对正确标签预测的概率不足,并且不考虑错误标签之间的相对概率。

相比之下,Soft Targets提供了一个更加细致的概率分布,不仅包含了正确的类别,还包括了模型认为可能相关的其他类别的概率。这意味着除了给出最有可能的类别之外,soft targets还提供了关于模型对于哪些类别可能是正确的、以及这些类别之间如何相互关联的信息。这种类型的标签可以通过一个已经训练好的教师模型生成,它为每个输入产生一个概率分布而不是单一的类别标签。

图片

基于soft targets训练模型相较于hard targets训练模型,尤其是在资源受限的情况下有效地转移复杂模型的知识,具有以下几个显著的优势和意义:

  1. 传递更丰富的信息:Soft targets不仅包含正确类别的概率,还包括了对其他类别的相对概率估计。这意味着模型可以学习到不同类别之间的相似性和差异性,使得即使是对于错误的类别预测,也能反映出它们之间的细微差别,从而提供比单一正确答案更多的信息。例如,在图像识别任务中,一个特定图像可能与多个类别有一定的相似性,而这些信息在hard target中是丢失的。
  2. 减少梯度方差:当使用较高的温度参数T时,softmax输出的概率分布变得更加平滑,这意味着每个样本提供的信息量相对于hard target来说增加了,同时减少了梯度估计中的方差。这对于小数据集上的训练尤其有益,因为它可以防止过拟合并帮助模型更好地泛化到未见过的数据。
  3. 提高泛化能力:相比于hard targets仅给出一个确切的类别标签,通过模仿教师模型(通常是大型或集成模型)生成的soft targets,学生模型能够学习到教师模型如何进行泛化的细节。这有助于模型更好地学习数据中的潜在结构和模式,特别是在处理复杂或模糊边界的问题时。对于那些难以明确区分的类别,soft targets能够指导模型认识到哪些错误类别之间更加接近,从而改进其泛化能力。
  4. 加速收敛和降低过拟合风险:由于soft targets通常比hard targets拥有更高的熵,它们能提供更多的信息,并且减少梯度估计的方差。这对于小规模数据集特别有用,因为它可以帮助防止模型过度拟合训练数据中的噪声。论文中提到的一个实验显示,使用soft targets的学生模型即使只用3%的数据进行训练,也能够几乎恢复全量数据所能提供的信息,同时不需要早期停止来防止过拟合。这表明soft targets作为一种有效的正则化方法,能够帮助模型更好地泛化。
    图片
  5. 增强模型的鲁棒性:通过利用soft targets进行训练,模型可以学习到输入数据的内在分布特性,而不仅仅是表面特征。这意味着模型对于输入数据的小变化(如轻微的图像变换)会更加稳健,因为它们已经学会了识别那些对最终分类决策影响较小的变化。

带温度的Softmax

传统softmax函数倾向于产生极端的概率分布,导致非正确类别的概率接近于零,这限制了其对模型训练的帮助。

图片

如图所示,当输入一张马的图片时,对于未调整温度(默认为1)的 Softmax 输出,正标签的概率接近 1,而负标签的概率接近 0。这种尖锐的分布对学生模型不够友好,因为它只提供了关于正确答案的信息,而忽略了错误答案的信息。即驴比汽车更像马,识别为驴的概率应该大于识别为汽车的概率。

图片

带温度的Softmax是一种调整softmax函数输出的方法,通过引入一个额外参数——温度(temperature, T),来控制输出概率分布的平滑度。这个概念在知识蒸馏中尤为重要,因为它能够影响教师模型如何向学生模型传递知识。

图片

  • 当T=1时,该公式退化为标准的softmax。
  • 当T>1时,输出的概率分布变得更平滑,即不同类别的概率差异减小,这使得即使是不太可能的类别也会分配到一定的概率值。
  • 当T<1时,结果是相反的,输出的概率分布变得更加尖锐,增加了最有可能类别的概率,同时进一步降低了其他类别的概率。

关于T的取值,原文说是“不要太大,也不要太小”,如果T的值太小,就会导致原本概率值比较小的类别除以后还是比较小,就获取不到数据中有效的信息。如果T的值太大,负标签带的信息会被放大,就会引入噪声,也就是将一些没有用信息给放的很大。这个实际还是调参吧,记得之前看到过文章说一般是20以内。当需要考虑负标签之间的关系时,可以采用较大的温度。例如,在自然语言处理任务中,模型可能需要学习到“猫”和“狗”之间的相似性,而不仅仅是它们的硬标签。在这种情况下,较大的温度可以使模型更好地捕捉到这些关系。反之,如果为了消除负标签中噪声的影响,可以采用较小的温度。

图片

蒸馏过程

准备阶段

  1. 一个已经预训练好的、泛化能力强的教师网络。
  2. 构造好数据集,这个数据集可以使用用来训练教师网络的数据集,也可以专门准备一个数据集来做蒸馏的过程。
  3. 搭建好学生网络。

蒸馏过程

  1. 输入数据通过教师网络:  输入数据首先通过教师网络得到logits(即softmax层之前的输出)。使用较高的温度t对这些logits应用softmax函数,生成soft labels。
  2. 输入数据通过学生网络:  同样的输入数据也通过学生网络得到自己的logits。  这些logits分别通过相同的温度t和标准温度T=1的softmax函数处理,得到soft predictions和hard predictions。Soft predictions用于与教师网络的soft labels比较,而hard predictions则用于与真实标签比较。
  3. 计算损失:
    Distillation Loss:使用交叉熵损失函数计算教师网络的soft labels和学生网络的soft predictions之间的差异。这一步帮助学生网络学习到教师网络对于不同类别间细微差别的理解。
    Student Loss:同样使用交叉熵损失函数计算学生网络的hard predictions与真实标签之间的差异。这确保了学生网络也能直接从数据中学习。
    最终损失:这两个损失(distillation loss和student loss)通常会加权求和形成最终的损失函数。权重的选择可以根据具体任务调整,但一般情况下,distillation loss的权重相对较小,以避免过度依赖教师网络的预测。
    图片

其中,α是介于0和1之间的一个系数,用来平衡两个损失项的重要性。作者发现学生网络的损失的权重小一点比较好,比如可以尝试 0.3、0.4。

最后作者说,由于图片的梯度大约是图片的梯度的图片,因此在图片前乘上图片可以保证两个损失部分的梯度量贡献基本一致。

图片

4、代码复现

使用mnist公开数据集简单复现了上述流程,知识蒸馏主要代码如下:

def train_kd_student(epoch):
    student_kd.train()
    optimizer = torch.optim.SGD(student_kd.parameters(), lr=lr, momentum=momentum)
    
    print(f'\nTraining KD Student Epoch: {epoch}')
    train_loss = 0
    correct = 0
    total = 0
    with tqdm(train_loader, desc=f"Training KD Student Epoch {epoch}", total=len(train_loader)) as pbar:
        for batch_idx, (inputs, targets) in enumerate(pbar):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            logits_student = student_kd(inputs)
            with torch.no_grad():
                logits_teacher = teacher_net(inputs)
            ce_loss = nn.CrossEntropyLoss()(logits_student, targets)
            kd_loss = loss(logits_student, logits_teacher, temperature=T)
            total_loss = ALPHA * ce_loss + BETA * kd_loss
            total_loss.backward()
            optimizer.step()
            train_loss += total_loss.item()
            _, predicted = logits_student.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            # 记录训练损失
            writers['student_kd'].add_scalar('Training Loss', ce_loss.item(), 
                                           epoch * len(train_loader) + batch_idx)
            writers['student_kd'].add_scalar('Training Loss', kd_loss.item(), 
                                           epoch * len(train_loader) + batch_idx)
            writers['student_kd'].add_scalar('Training Loss', total_loss.item(), 
                                           epoch * len(train_loader) + batch_idx)
            pbar.set_postfix(loss=train_loss/(batch_idx+1), acc=f"{100.*correct/total:.1f}%")
    # 记录训练准确率
    acc = 100. * correct / total
    writers['student_kd'].add_scalar('Training Accuracy', acc, epoch)
学生模型和教师模型代码如下:
class LeNet(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)
    def forward(self, x):
        x = self.maxpool(F.relu(self.conv1(x)))
        x = self.maxpool(F.relu(self.conv2(x)))
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
class LeNetHalfChannel(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNetHalfChannel, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=5)   
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=3 * 12 * 12, out_features=num_classes)   
    def forward(self, x):
        x = self.maxpool(F.relu(self.conv1(x)))
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        
        return x
    
# 初始化模型
teacher_net = LeNet().to(device=device)
student_plain = LeNetHalfChannel().to(device=device)  # 单独训练的学生
student_kd = LeNetHalfChannel().to(device=device)     # 知识蒸馏的学生
teacher_net.load_state_dict(torch.load('./model/model.pt'))
对比仅使用学生模型、仅使用教师模型和使用教师模型蒸馏学生模型的模型性能,结果如下:
Teacher Model Accuracy: 97.99% 
Plain Student Best Accuracy: 85.77% 
KD Student Best Accuracy: 96.70%

可以看到,使用知识蒸馏得到的模型性能十分接近教师模型,远超仅训练学生模型的性能。

完整工程代码地址:https://github.com/jinbo0906/awesome-model-compression/blob/main/knowledge%20distillation/kd.ipynb

5、DeepSeek-R1的蒸馏方法

DeepSeek-R1 的蒸馏过程基于其自身生成的合成推理数据。由DeepSeek-R1 模型生成的 800,000 个数据样本,对较小的基础模型(例如 Qwen 和 Llama 系列)仅进行监督微调,从而将大型模型的推理能力高效地迁移到小型模型中。

从这里可以了解到,在大模型时代,蒸馏不仅仅完全通过学习教师模型的软标签实现,也可以通过学习教师模型的输出结构化数据,从而学习到教师模型中一些强大的性能。这种方式对算力的需求大大减少,推动未来AI能力的普惠化。

这种方式降低了对模型训练方法的要求,但是对数据质量的要求则大大增加,从这里也可以看到,DeepSeek-R1模型生成数据的质量非常高,可能远远超过了人工标注,在DeepSeek-R1的论文中也提到模型的self-evolution能力,可能是未来非常值得关注的一个研究方向。

6、总结

知识蒸馏不仅能够有效地降低模型的复杂度和计算成本,还能保持较高的模型性能。通过对教师模型知识的巧妙迁移,学生模型能够在资源受限的环境中展现出色的表现。未来知识蒸馏将在更多领域发挥重要作用。

责任编辑:庞桂玉 来源: 小白学AI算法
相关推荐

2025-02-07 15:10:00

模型AI语音

2025-02-12 10:06:25

2022-12-19 15:16:46

机器学习模型

2023-11-15 09:24:00

数据训练

2025-02-17 08:00:00

DeepSeek模型AI

2023-12-11 14:21:00

模型训练

2024-07-19 08:00:00

深度学习知识蒸馏

2024-12-02 01:10:04

神经网络自然语言DNN

2023-04-17 07:33:05

ChatGPTAI工程师

2024-01-25 10:19:10

2025-02-14 09:17:16

2021-01-25 10:36:32

知识图谱人工智能

2024-01-22 09:02:00

AI训练

2015-03-13 11:26:57

两会云计算云概念

2021-10-14 06:29:56

薪资举报机制

2024-12-04 09:15:00

AI模型

2025-02-13 11:00:30

2023-09-01 14:49:09

AI微软

2024-02-23 09:02:21

前端开源项目
点赞
收藏

51CTO技术栈公众号