如何为您的机器学习问题选择正确的预训练模型

新闻 人工智能
迁移学习是一种机器学习技术,你可以使用一个预训练好的神经网络来解决一个问题,这个问题类似于网络最初训练用来解决的问题。

[[264482]]

 在这篇文章中,我们将简要介绍一下迁移学习是什么,以及如何使用它。

什么是迁移学习?

迁移学习是使用预训练模型解决深度学习问题的艺术。

迁移学习是一种机器学习技术,你可以使用一个预训练好的神经网络来解决一个问题,这个问题类似于网络最初训练用来解决的问题。例如,您可以利用构建好的用于识别狗的品种的深度学习模型来对狗和猫进行分类,而不是构建您自己的模型。这可以为您省去寻找有效的神经网络体系结构的痛苦,可以为你节省花在训练上的时间,并可以保证有良好的结果。也就是说,你可以花很长时间来制作一个50层的CNN来***地区分你的猫和狗,或者你可以简单地使用许多预训练好的图像分类模型。

使用预训练模型的三种不同方式

主要有三种不同的方式可以重新定位预训练模型。他们是,

  1. 特征提取 。
  2. 复制预训练的网络的体系结构。
  3. 冻结一些层并训练其他层。

特征提取:这里我们所需要做的就是改变输出层,以给出cat和dog的概率(或者您的模型试图将内容分类到的类的数量),而不是最初训练它将内容分类到的数千个类。当我们试图训练模型所使用的数据与预训练的模型最初所训练的数据非常相似且数据集的大小很小时,这是理想的。这种机制称为固定特征提取。我们只对添加的新输出层进行重新训练,并保留每一层的权重。

复制预训练网络的架构 :在这里,我们定义了一个与预训练模型具有相同体系结构的机器学习模型,该模型在执行与我们试图实现的任务类似的任务时显示了出色的结果,并从头开始训练它。我们从预训练的模型中丢弃每一层的权重,然后根据我们的数据重新训练整个模型。当我们有大量的数据要训练时,我们会采用这种方法,但它与训练前的模型所训练的数据并不十分相似。

冻结一些层并训练其他层:我们可以选择冻结一个预训练模型的初始k层,只训练最顶层的n-k层。我们保持初始值的权重与预训练模型的权重相同且不变,并对数据的高层进行再训练。当数据集较小且数据相似度较低时,采用该方法。较低的层主要关注可以从数据中提取的最基本的信息,因此可以将其用于其他问题,因为基本级别的信息通常是相同的。

另一种常见情况是数据相似性高且数据集也很大。在这种情况下,我们保留模型的体系结构和模型的初始权重。然后,我们对整个模型进行再训练,以更新预训练模型的权重,以更好地适应我们的特定问题。这是使用迁移学习的理想情况。

下图显示了随着数据集大小和数据相似性的变化而采用的方法。

迁移学习:如何为您的机器学习问题选择正确的预训练模型

PyTorch中的迁移学习

在torchvision.models模块下,PyTorch中有八种不同的预训练模型。他们是 :

  1. AlexNet
  2. VGG
  3. RESNET
  4. SqueezeNet
  5. DenseNet
  6. Inception v3
  7. GoogLeNet
  8. ShuffleNet v2

这些都是为图像分类而构建的卷积神经网络,在ImageNet数据集上进行训练。ImageNet是根据WordNet层次结构组织的图像数据库,包含14,197,122张属于21841类的图像。

迁移学习:如何为您的机器学习问题选择正确的预训练模型

由于PyTorch中的所有预训练模型都针对相同的任务在相同的数据集上进行训练,所以我们选择哪一个并不重要。让我们选择ResNet网络,看看如何在前面讨论的不同场景中使用它。

用于图像识别的ResNet或深度残差学习在pytorch、ResNet -18、ResNet -34、ResNet -50、ResNet -101和ResNet -152上有五个版本。

让我们从torchvision下载ResNet-18。

  1. import torchvision.models as models 
  2. model = models.resnet18(pretrained=True) 
迁移学习:如何为您的机器学习问题选择正确的预训练模型

以下是我们刚刚下载的模型。

迁移学习:如何为您的机器学习问题选择正确的预训练模型
迁移学习:如何为您的机器学习问题选择正确的预训练模型

现在,让我们看看尝试,看看如何针对四个不同的问题训​​练这个模型。

数据集很小,数据相似性很高

考虑这个kaggle数据集(https://www.kaggle.com/mriganksingh/cat-images-dataset)。这包括猫的图像和其他非猫的图像。它有209个像素64*64*3的训练图像和50个测试图像。这显然是一个非常小的数据集,但我们知道ResNet是在大量动物和猫图像上训练的,所以我们可以使用ResNet作为固定特征提取器来解决我们的猫与非猫的问题。

  1. num_ftrs = model.fc.in_features 
  2. num_ftrs 
迁移学习:如何为您的机器学习问题选择正确的预训练模型

Out: 512

  1. model.fc.out_features 

Out: 1000

我们需要冻结除***一层之外的所有网络。我们需要设置requires_grad = False来冻结参数,这样就不会在backward()中计算梯度。新构造模块的参数默认为requires_grad=True。

  1. for param in model.parameters(): 
  2.  param.requires_grad = False 
迁移学习:如何为您的机器学习问题选择正确的预训练模型

由于我们只需要***一层提供两个概率,即图像的概率是否为cat,我们可以重新定义***一层中的输出特征数。

  1. model.fc = nn.Linear(num_ftrs, 2

这是我们模型的新架构。

迁移学习:如何为您的机器学习问题选择正确的预训练模型
迁移学习:如何为您的机器学习问题选择正确的预训练模型

我们现在要做的就是训练模型的***一层,我们将能够使用我们重新定位的vgg16来预测图像是否是猫,而且数据和训练时间都非常少。

数据的大小很小,数据相似性也很低

考虑来自(https://www.kaggle.com/kvinicki/canine-coccidiosis),这个数据集包含了犬异孢球虫和犬异孢球虫卵囊的图像和标签,异孢球虫卵囊是一种球虫寄生虫,可感染狗的肠道。它是由萨格勒布兽医学院创建的。它包含了两种寄生虫的341张图片。

迁移学习:如何为您的机器学习问题选择正确的预训练模型

这个数据集很小,而且不是Imagenet中的一个类别。在这种情况下,我们保留预先训练好的模型架构,冻结较低的层并保留它们的权重,并训练较低的层更新它们的权重以适应我们的问题。

  1. count = 0 
  2. for child in model.children(): 
  3.  count+=1 
  4. print(count) 
迁移学习:如何为您的机器学习问题选择正确的预训练模型

Out: 10

ResNet18共有10层。让我们冻结前6层。

  1. count = 0 
  2. for child in model.children(): 
  3.  count+=1 
  4.  if count < 7
  5.  for param in child.parameters(): 
  6.  param.requires_grad = False 
迁移学习:如何为您的机器学习问题选择正确的预训练模型

现在我们已经冻结了前6层,让我们重新定义最终输出层,只给出2个输出,而不是1000。

  1. model.fc = nn.Linear(num_ftrs, 2

这是更新的架构。

迁移学习:如何为您的机器学习问题选择正确的预训练模型
迁移学习:如何为您的机器学习问题选择正确的预训练模型

现在,训练这个机器学习模型,更新***4层的权重。

数据集的大小很大,但数据相似性非常低

考虑这个来自kaggle,皮肤癌MNIST的数据集:HAM10000

其具有超过10015个皮肤镜图像,属于7种不同类别。这不是我们在Imagenet中可以找到的那种数据。

这就是我们只保留模型架构而不保留来自预训练模型的任何权重的地方。让我们重新定义输出层,将项目分类为7个类别。

  1. model.fc = nn.Linear(num_ftrs, 7

这个模型需要几个小时才能在没有GPU的机器上进行训练,但是如果你运行足够的时代,你仍然会得到很好的结果,而不必定义你自己的模型架构。

数据大小很大,数据相似性很高

考虑来自kaggle 的鲜花数据集(https://www.kaggle.com/alxmamaev/flowers-recognition)。它包含4242个花卉图像。图片分为五类:洋甘菊,郁金香,玫瑰,向日葵,蒲公英。每个类大约有800张照片。

这是应用迁移学习的理想情况。我们保留了预训练模型的体系结构和每一层的权重,并训练模型更新权重以匹配我们的特定问题。

  1. model.fc = nn.Linear(num_ftrs, 5
  2. best_model_wts = copy.deepcopy(model.state_dict()) 
迁移学习:如何为您的机器学习问题选择正确的预训练模型

我们从预训练的模型中复制权重并初始化我们的模型。我们使用训练和测试阶段来更新这些权重。

  1. for epoch in range(num_epochs): 
  2.   
  3.  print(‘Epoch {}/{}’.format(epoch, num_epochs — 1)) 
  4.  print(‘-’ * 10
  5.  for phase in [‘train’, ‘test’]: 
  6.   
  7.  if phase == 'train'
  8.  scheduler.step() 
  9.  model.train()  
  10.  else
  11.  model.eval() 
  12.  running_loss = 0.0 
  13.  running_corrects = 0 
  14.  for inputs, labels in dataloaders[phase]: 
  15.   
  16.  inputs = inputs.to(device) 
  17.  labels = labels.to(device) 
  18.  optimizer.zero_grad() 
  19.  with torch.set_grad_enabled(phase == ‘train’): 
  20.   
  21.  outputs = model(inputs) 
  22.  _, preds = torch.max(outputs, 1
  23.  loss = criterion(outputs, labels) 
  24.   
  25.  if phase == ‘train’: 
  26.  loss.backward() 
  27.  optimizer.step() 
  28.  running_loss += loss.item() * inputs.size(0
  29.  running_corrects += torch.sum(preds == labels.data) 
  30.   
  31.  epoch_loss = running_loss / dataset_sizes[phase] 
  32.  epoch_acc = running_corrects.double() / dataset_sizes[phase] 
  33.  print(‘{} Loss: {:.4f} Acc: {:.4f}’.format( 
  34.  phase, epoch_loss, epoch_acc)) 
  35.   
  36.  if phase == ‘test’ and epoch_acc > best_acc: 
  37.  best_acc = epoch_acc 
  38.  best_model_wts = copy.deepcopy(model.state_dict()) 
  39. print(‘Best val Acc: {:4f}’.format(best_acc)) 
  40. model.load_state_dict(best_model_wts) 
迁移学习:如何为您的机器学习问题选择正确的预训练模型

这种机器学习模式也需要几个小时的训练,但即使只有一个训练epoch ,也会产生出色的效果。

您可以按照相同的原则在任何其他平台上使用任何其他预训练的网络执行迁移学习。本文随机挑选了Resnet和pytorch。任何其他CNN都会给出类似的结果。希望这可以节省您使用计算机视觉解决现实世界问题的痛苦时间。

责任编辑:张燕妮 来源: 头条科技
相关推荐

2022-10-31 15:04:59

2021-03-15 07:55:55

API网关微服务架构

2022-04-27 18:20:19

综合布线交换机网络

2017-02-28 14:17:03

机器学习算法

2024-11-04 00:24:56

2021-06-25 10:23:34

RPA软件机器人流程自动化机器学习

2017-11-09 08:51:28

2017-12-26 13:53:31

深度学习迁移学习

2017-11-24 09:30:58

数据库微服务云架构

2017-03-24 15:58:46

互联网

2020-03-04 13:53:25

物联网协议物联网IOT

2023-05-29 15:53:32

DevOps架构自动化

2015-06-08 10:07:04

公有云云服务商选择公有云迁移

2019-10-12 10:11:02

数据集聚类算法

2023-08-09 17:43:40

光纤电缆光纤终端盒

2022-09-19 15:37:51

人工智能机器学习大数据

2018-07-03 15:26:35

算法机器学习数据

2021-03-28 17:14:38

数据库APP技术

2012-10-30 09:28:52

2009-03-04 11:29:24

ibmdwJava
点赞
收藏

51CTO技术栈公众号