深度学习框架Flash如何仅用几行代码构建图像分类器?

译文
人工智能 深度学习
图像分类是我们想要预测哪个类别属于图像的任务。由于图像表示,这项任务很困难。如果我们将图像铺平,它会创建一个长长的一维向量。此外,该表示将丢失相邻信息。因此,我们需要深度学习来提取特征并预测结果。

[[412621]]

【51CTO.com快译】一、简介

图像分类是我们想要预测哪个类别属于图像的任务。由于图像表示,这项任务很困难。如果我们将图像铺平,它会创建一个长长的一维向量。此外,该表示将丢失相邻信息。因此,我们需要深度学习来提取特征并预测结果。

有时,构建深度学习模型会成为一项艰巨的任务。虽然我们为图像分类创建了一个基础模型,但需要花大量时间来创建代码。我们必须准备好用于准备数据、训练模型并测试模型的代码,并将模型部署到服务器上。这时Flash就有了用武之地!

Flash是一种高级深度学习框架,用于快速构建、训练和测试深度学习模型。Flash基于PyTorch框架。所以如果您了解PyTorch,就会很熟悉Flash。

与PyTorch和Lighting相比,Flash易于使用,但不如以前的库灵活。如果您想构建更复杂的模型,可以使用Lightning或直接使用PyTorch。

借助Flash,您可以用几行代码构建深度学习模型!因此,如果您刚接触深度学习,别害怕。Flash可以帮助您构建深度学习模型,不会因代码而感到困惑。

本文将介绍如何使用Flash构建图像分类器。

二、实施

安装库

想安装库,您可以使用pip命令,如下所示:

  1. pip install lightning-flash 

如果该命令不起作用,可以使用其GitHub存储库安装该库。命令如下所示:

  1. pip install git+https://github.com/PyTorchLightning/lightning-flash.git 

在我们可以成功下载软件包之后,现在可以加载库。我们还将种子设为编号42。这是执行此操作的代码:

  1. from pytorch_lightning import seed_everything 
  2.  
  3. import flash 
  4. from flash.core.classification import Labels 
  5. from flash.core.data.utils import download_data 
  6. from flash.image import ImageClassificationData, ImageClassifier 
  7.  
  8. set the random seeds. 
  9. seed_everything(42) 
  10. Global seed set to 42  
  11. 42 

下载数据

安装完库后,现在不妨获取数据。出于演示需要,我们将使用名为Cat和Dog数据集的数据集。

该数据集含有两个类别:猫和狗的图像。想访问数据集,您可以在Kaggle找到该数据集。可以在此处访问数据集。

加载数据

下载数据后,不妨将数据集加载到一个对象中。我们将使用from_folders方法将数据放入到ImageClassification对象中。这是执行此操作的代码:

  1. datamodule = ImageClassificationData.from_folders( 
  2.     train_folder="cat_and_dog/training_set"
  3.     val_folder="cat_and_dog/validation_set"

加载模型

我们加载数据后,下一步就是加载模型。由于我们不会从头开始构建自己的架构,将使用基于现有卷积神经网络架构的预训练模型。

我们将使用已经过预训练的ResNet-50模型。此外,我们基于数据集设置类别的数量。这是执行此操作的代码:

  1. model = ImageClassifier(backbone="resnet50", num_classes=datamodule.num_classes) 

训练模型

加载模型后,现在不妨训练模型。我们需要先初始化Trainer对象。我们将用3个轮次(epoch)训练模型。此外,我们启用GPU以训练模型。这是执行此操作的代码:

  1. trainer = flash.Trainer(max_epochs=3, gpus=1) 
  2. GPU available: True, used: True TPU available: False, using: 0 TPU cores 

初始化对象后,不妨训练模型。为训练模型,我们可以使用一个名为finetune的函数。在函数里面,我们设置模型和数据。此外,我们将训练策略设置为freeze(冻结),这表明我们不想训练特征提取器。换句话说,我们只训练分类器部分。

这是执行此操作的代码:

  1. trainer.finetune(model, datamodule=datamodule, strategy="freeze"
  2. LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] | Name | Type | Params ---------------------------------------- 0 | metrics | ModuleDict | 0 1 | backbone | Sequential | 23.5 M 2 | head | Sequential | 4.1 K ---------------------------------------- 57.2 K Trainable params 23.5 M Non-trainable params 23.5 M Total params 94.049 Total estimated model params size (MB)  
  3. Validation sanity check: 0it [00:00, ?it/s] 
  4. Global seed set to 42  
  5. Training: 0it [00:00, ?it/s] 
  6. Validating: 0it [00:00, ?it/s] 
  7. Validating: 0it [00:00, ?it/s] 
  8. Validating: 0it [00:00, ?it/s] 

这是评估结果:

从结果中可以看出,我们的模型其准确率达到了约97%。不赖!现在不妨拿几个新数据测试模型。

测试模型

我们将使用针对该模型没有训练过的样本数据。以下是我们将测试模型的样本:

  1. import matplotlib.pyplot as plt 
  2. from PIL import Image 
  3.  
  4. fig, ax = plt.subplots(1, 5, figsize=(40,8)) 
  5. for i in range(5): 
  6.     ax[i].imshow(Image.open(f'cat_and_dog/testing/{i+1}.jpg')) 
  7. plt.show() 

为了测试模型,我们可以使用flash库中的predict方法。这是执行此操作的代码:

  1. model.serializer = Labels() 
  2.  
  3. predictions = model.predict(["cat_and_dog/testing/1.jpg"
  4.                              "cat_and_dog/testing/2.jpg"
  5.                              "cat_and_dog/testing/3.jpg"
  6.                              "cat_and_dog/testing/4.jpg"
  7.                              "cat_and_dog/testing/5.jpg"]) 
  8. print(predictions) 
  9. ['dogs''dogs''cats''cats''dogs'

从上面的结果可以看出,模型预测了带有正确标签的样本。很好!不妨保存模型以备后用。

保存模型

我们已训练并测试了模型。不妨使用save_checkpoint方法保存模型。这是执行此操作的代码:

  1. trainer.save_checkpoint("cat_dog_classifier.pt"

如果您想针对其他代码加载模型,可以使用load_from_checkpoint方法。这是执行此操作的代码:

  1. model = ImageClassifier.load_from_checkpoint("cat_dog_classifier.pt"

三、结语

做得好!您已学习了如何使用Flash构建图像分类器。正如文章开头所说,它只需要几行代码!是不是很酷?

但愿本文可以帮助您根据自己的情况构建自己的深度学习模型。如果您想实施一个更复杂的模型,但愿能开始学习 PyTorch。

原文标题:How to Build An Image Classifier in Few Lines of Code with Flash,作者:Irfan Alghani Khalid

【51CTO译稿,合作站点转载请注明原文译者和出处为51CTO.com】

 

责任编辑:华轩 来源: 51CTO
相关推荐

2021-10-18 09:09:16

数据库

2018-06-19 08:35:51

情感分析数据集代码

2022-12-30 08:00:00

深度学习集成模型

2018-04-09 10:20:32

深度学习

2023-02-28 08:00:00

深度学习神经网络人工智能

2016-12-27 15:33:25

softmax分类器课程

2020-08-10 06:36:21

强化学习代码深度学习

2015-02-09 10:43:00

JavaScript

2017-09-09 06:04:22

深度学习人物图像神经网络

2018-05-28 13:12:49

深度学习Python神经网络

2022-04-01 09:30:00

开源AutoXGBAPI

2024-09-11 08:34:28

2019-04-01 05:42:24

JavaScript视觉程序代码

2018-07-19 15:13:15

深度学习图像

2017-05-12 16:25:44

深度学习图像补全tensorflow

2024-09-29 09:32:58

2021-11-02 11:48:39

深度学习恶意软件观察

2022-11-11 15:07:50

深度学习函数鉴别器

2022-09-29 23:53:06

机器学习迁移学习神经网络

2019-12-05 09:50:54

GitHub 技术深度学习
点赞
收藏

51CTO技术栈公众号