【51CTO.com快译】一、简介
图像分类是我们想要预测哪个类别属于图像的任务。由于图像表示,这项任务很困难。如果我们将图像铺平,它会创建一个长长的一维向量。此外,该表示将丢失相邻信息。因此,我们需要深度学习来提取特征并预测结果。
有时,构建深度学习模型会成为一项艰巨的任务。虽然我们为图像分类创建了一个基础模型,但需要花大量时间来创建代码。我们必须准备好用于准备数据、训练模型并测试模型的代码,并将模型部署到服务器上。这时Flash就有了用武之地!
Flash是一种高级深度学习框架,用于快速构建、训练和测试深度学习模型。Flash基于PyTorch框架。所以如果您了解PyTorch,就会很熟悉Flash。
与PyTorch和Lighting相比,Flash易于使用,但不如以前的库灵活。如果您想构建更复杂的模型,可以使用Lightning或直接使用PyTorch。
借助Flash,您可以用几行代码构建深度学习模型!因此,如果您刚接触深度学习,别害怕。Flash可以帮助您构建深度学习模型,不会因代码而感到困惑。
本文将介绍如何使用Flash构建图像分类器。
二、实施
安装库
想安装库,您可以使用pip命令,如下所示:
- pip install lightning-flash
如果该命令不起作用,可以使用其GitHub存储库安装该库。命令如下所示:
- pip install git+https://github.com/PyTorchLightning/lightning-flash.git
在我们可以成功下载软件包之后,现在可以加载库。我们还将种子设为编号42。这是执行此操作的代码:
- from pytorch_lightning import seed_everything
- import flash
- from flash.core.classification import Labels
- from flash.core.data.utils import download_data
- from flash.image import ImageClassificationData, ImageClassifier
- # set the random seeds.
- seed_everything(42)
- Global seed set to 42
- 42
下载数据
安装完库后,现在不妨获取数据。出于演示需要,我们将使用名为Cat和Dog数据集的数据集。
该数据集含有两个类别:猫和狗的图像。想访问数据集,您可以在Kaggle找到该数据集。可以在此处访问数据集。
加载数据
下载数据后,不妨将数据集加载到一个对象中。我们将使用from_folders方法将数据放入到ImageClassification对象中。这是执行此操作的代码:
- datamodule = ImageClassificationData.from_folders(
- train_folder="cat_and_dog/training_set",
- val_folder="cat_and_dog/validation_set",
- )
加载模型
我们加载数据后,下一步就是加载模型。由于我们不会从头开始构建自己的架构,将使用基于现有卷积神经网络架构的预训练模型。
我们将使用已经过预训练的ResNet-50模型。此外,我们基于数据集设置类别的数量。这是执行此操作的代码:
- model = ImageClassifier(backbone="resnet50", num_classes=datamodule.num_classes)
训练模型
加载模型后,现在不妨训练模型。我们需要先初始化Trainer对象。我们将用3个轮次(epoch)训练模型。此外,我们启用GPU以训练模型。这是执行此操作的代码:
- trainer = flash.Trainer(max_epochs=3, gpus=1)
- GPU available: True, used: True TPU available: False, using: 0 TPU cores
初始化对象后,不妨训练模型。为训练模型,我们可以使用一个名为finetune的函数。在函数里面,我们设置模型和数据。此外,我们将训练策略设置为freeze(冻结),这表明我们不想训练特征提取器。换句话说,我们只训练分类器部分。
这是执行此操作的代码:
- trainer.finetune(model, datamodule=datamodule, strategy="freeze")
- LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] | Name | Type | Params ---------------------------------------- 0 | metrics | ModuleDict | 0 1 | backbone | Sequential | 23.5 M 2 | head | Sequential | 4.1 K ---------------------------------------- 57.2 K Trainable params 23.5 M Non-trainable params 23.5 M Total params 94.049 Total estimated model params size (MB)
- Validation sanity check: 0it [00:00, ?it/s]
- Global seed set to 42
- Training: 0it [00:00, ?it/s]
- Validating: 0it [00:00, ?it/s]
- Validating: 0it [00:00, ?it/s]
- Validating: 0it [00:00, ?it/s]
这是评估结果:
从结果中可以看出,我们的模型其准确率达到了约97%。不赖!现在不妨拿几个新数据测试模型。
测试模型
我们将使用针对该模型没有训练过的样本数据。以下是我们将测试模型的样本:
- import matplotlib.pyplot as plt
- from PIL import Image
- fig, ax = plt.subplots(1, 5, figsize=(40,8))
- for i in range(5):
- ax[i].imshow(Image.open(f'cat_and_dog/testing/{i+1}.jpg'))
- plt.show()
为了测试模型,我们可以使用flash库中的predict方法。这是执行此操作的代码:
- model.serializer = Labels()
- predictions = model.predict(["cat_and_dog/testing/1.jpg",
- "cat_and_dog/testing/2.jpg",
- "cat_and_dog/testing/3.jpg",
- "cat_and_dog/testing/4.jpg",
- "cat_and_dog/testing/5.jpg"])
- print(predictions)
- ['dogs', 'dogs', 'cats', 'cats', 'dogs']
从上面的结果可以看出,模型预测了带有正确标签的样本。很好!不妨保存模型以备后用。
保存模型
我们已训练并测试了模型。不妨使用save_checkpoint方法保存模型。这是执行此操作的代码:
- trainer.save_checkpoint("cat_dog_classifier.pt")
如果您想针对其他代码加载模型,可以使用load_from_checkpoint方法。这是执行此操作的代码:
- model = ImageClassifier.load_from_checkpoint("cat_dog_classifier.pt")
三、结语
做得好!您已学习了如何使用Flash构建图像分类器。正如文章开头所说,它只需要几行代码!是不是很酷?
但愿本文可以帮助您根据自己的情况构建自己的深度学习模型。如果您想实施一个更复杂的模型,但愿能开始学习 PyTorch。
原文标题:How to Build An Image Classifier in Few Lines of Code with Flash,作者:Irfan Alghani Khalid
【51CTO译稿,合作站点转载请注明原文译者和出处为51CTO.com】