PyTorch中的数据集Torchvision和Torchtext

人工智能 深度学习
对于PyTorch加载和处理不同类型数据,官方提供了torchvision和torchtext。之前使用 torchDataLoader类直接加载图像并将其转换为张量。现在结合torchvision和torchtext介绍torch中的内置数据集。

[[421061]]

对于PyTorch加载和处理不同类型数据,官方提供了torchvision和torchtext。

之前使用 torchDataLoader类直接加载图像并将其转换为张量。现在结合torchvision和torchtext介绍torch中的内置数据集

Torchvision 中的数据集

MNIST

MNIST是一个由标准化和中心裁剪的手写图像组成的数据集。它有超过 60,000 张训练图像和 10,000 张测试图像。这是用于学习和实验目的最常用的数据集之一。要加载和使用数据集,使用以下语法导入:torchvision.datasets.MNIST()。

Fashion MNIST

Fashion MNIST数据集类似于MNIST,但该数据集包含T恤、裤子、包包等服装项目,而不是手写数字,训练和测试样本数分别为60,000和10,000。要加载和使用数据集,使用以下语法导入:torchvision.datasets.FashionMNIST()

CIFAR

CIFAR数据集有两个版本,CIFAR10和CIFAR100。CIFAR10 由 10 个不同标签的图像组成,而 CIFAR100 有 100 个不同的类。这些包括常见的图像,如卡车、青蛙、船、汽车、鹿等。

  1. torchvision.datasets.CIFAR10() 
  2. torchvision.datasets.CIFAR100() 

COCO

COCO数据集包含超过 100,000 个日常对象,如人、瓶子、文具、书籍等。这个图像数据集广泛用于对象检测和图像字幕应用。下面是可以加载 COCO 的位置:torchvision.datasets.CocoCaptions()

EMNIST

EMNIST数据集是 MNIST 数据集的高级版本。它由包括数字和字母的图像组成。如果您正在处理基于从图像中识别文本的问题,EMNIST是一个不错的选择。下面是可以加载 EMNIST的位置::torchvision.datasets.EMNIST()

IMAGE-NET

ImageNet 是用于训练高端神经网络的旗舰数据集之一。它由分布在 10,000 个类别中的超过 120 万张图像组成。通常,这个数据集加载在高端硬件系统上,因为单独的 CPU 无法处理这么大的数据集。下面是加载 ImageNet 数据集的类:torchvision.datasets.ImageNet()

Torchtext 中的数据集

IMDB

IMDB是一个用于情感分类的数据集,其中包含一组 25,000 条高度极端的电影评论用于训练,另外 25,000 条用于测试。使用以下类加载这些数据torchtext:torchtext.datasets.IMDB()

WikiText2

WikiText2语言建模数据集是一个超过 1 亿个标记的集合。它是从维基百科中提取的,并保留了标点符号和实际的字母大小写。它广泛用于涉及长期依赖的应用程序。可以从torchtext以下位置加载此数据:torchtext.datasets.WikiText2()

除了上述两个流行的数据集,torchtext库中还有更多可用的数据集,例如 SST、TREC、SNLI、MultiNLI、WikiText-2、WikiText103、PennTreebank、Multi30k 等。

深入查看 MNIST 数据集

MNIST 是最受欢迎的数据集之一。现在我们将看到 PyTorch 如何从 pytorch/vision 存储库加载 MNIST 数据集。让我们首先下载数据集并将其加载到名为 的变量中data_train

  1. from torchvision.datasets import MNIST 
  2.  
  3. # Download MNIST  
  4. data_train = MNIST('~/mnist_data', train=True, download=True
  5.  
  6. import matplotlib.pyplot as plt 
  7.  
  8. random_image = data_train[0][0] 
  9. random_image_label = data_train[0][1] 
  10.  
  11. # Print the Image using Matplotlib 
  12. plt.imshow(random_image) 
  13. print("The label of the image is:", random_image_label) 

DataLoader加载MNIST

下面我们使用DataLoader该类加载数据集,如下所示。

  1. import torch 
  2. from torchvision import transforms 
  3.  
  4. data_train = torch.utils.data.DataLoader( 
  5.     MNIST( 
  6.           '~/mnist_data', train=True, download=True,  
  7.           transform = transforms.Compose([ 
  8.               transforms.ToTensor() 
  9.           ])), 
  10.           batch_size=64, 
  11.           shuffle=True 
  12.           ) 
  13.  
  14. for batch_idx, samples in enumerate(data_train): 
  15.       print(batch_idx, samples) 

CUDA加载

我们可以启用 GPU 来更快地训练我们的模型。现在让我们使用CUDA加载数据时可以使用的(GPU 支持 PyTorch)的配置。

  1. device = "cuda" if torch.cuda.is_available() else "cpu" 
  2. kwargs = {'num_workers': 1, 'pin_memory'True} if device=='cuda' else {} 
  3.  
  4. train_loader = torch.utils.data.DataLoader( 
  5.   torchvision.datasets.MNIST('/files/', train=True, download=True), 
  6.   batch_size=batch_size_train, **kwargs) 
  7.  
  8. test_loader = torch.utils.data.DataLoader( 
  9.   torchvision.datasets.MNIST('files/', train=False, download=True), 
  10.   batch_size=batch_size, **kwargs) 

ImageFolder

ImageFolder是一个通用数据加载器类torchvision,可帮助加载自己的图像数据集。处理一个分类问题并构建一个神经网络来识别给定的图像是apple还是orange。要在 PyTorch 中执行此操作,第一步是在默认文件夹结构中排列图像,如下所示:

  1. root 
  2. ├── orange 
  3. │   ├── orange_image1.png 
  4. │   └── orange_image1.png 
  5. ├── apple 
  6. │   └── apple_image1.png 
  7. │   └── apple_image2.png 
  8. │   └── apple_image3.png 

可以使用ImageLoader该类加载所有这些图像。

  1. torchvision.datasets.ImageFolder(root, transform) 

transforms

PyTorch 转换定义了简单的图像转换技术,可将整个数据集转换为独特的格式。

如果是一个包含不同分辨率的不同汽车图片的数据集,在训练时,我们训练数据集中的所有图像都应该具有相同的分辨率大小。如果我们手动将所有图像转换为所需的输入大小,则很耗时,因此我们可以使用transforms;使用几行 PyTorch 代码,我们数据集中的所有图像都可以转换为所需的输入大小和分辨率。

现在让我们加载 CIFAR10torchvision.datasets并应用以下转换:

  • 将所有图像调整为 32×32
  • 对图像应用中心裁剪变换
  • 将裁剪后的图像转换为张量
  • 标准化图像
  1. import torch 
  2. import torchvision 
  3. import torchvision.transforms as transforms 
  4. import matplotlib.pyplot as plt 
  5. import numpy as np 
  6.  
  7. transform = transforms.Compose([ 
  8.     # resize 32×32 
  9.     transforms.Resize(32), 
  10.     # center-crop裁剪变换 
  11.     transforms.CenterCrop(32), 
  12.     # to-tensor 
  13.     transforms.ToTensor(), 
  14.     # normalize 标准化 
  15.     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 
  16. ]) 
  17.  
  18. trainset = torchvision.datasets.CIFAR10(root='./data', train=True
  19.                                         download=True, transform=transform) 
  20. trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, 
  21.                                           shuffle=False

在 PyTorch 中创建自定义数据集

下面将创建一个由数字和文本组成的简单自定义数据集。需要封装Dataset 类中的__getitem__()和__len__()方法。

  • __getitem__()方法通过索引返回数据集中的选定样本。
  • __len__()方法返回数据集的总大小。

下面是曾经封装FruitImagesDataset数据集的代码,基本是比较好的 PyTorch 中创建自定义数据集的模板。

  1. import os 
  2. import numpy as np 
  3. import cv2 
  4. import torch 
  5. import matplotlib.patches as patches 
  6. import albumentations as A 
  7. from albumentations.pytorch.transforms import ToTensorV2 
  8. from matplotlib import pyplot as plt 
  9. from torch.utils.data import Dataset 
  10. from xml.etree import ElementTree as et 
  11. from torchvision import transforms as torchtrans 
  12.  
  13. class FruitImagesDataset(torch.utils.data.Dataset): 
  14.     def __init__(self, files_dir, width, height, transforms=None): 
  15.         self.transforms = transforms 
  16.         self.files_dir = files_dir 
  17.         self.height = height 
  18.         self.width = width 
  19.  
  20.  
  21.         self.imgs = [image for image in sorted(os.listdir(files_dir)) 
  22.                      if image[-4:] == '.jpg'
  23.  
  24.         self.classes = ['_','apple''banana''orange'
  25.  
  26.     def __getitem__(self, idx): 
  27.  
  28.         img_name = self.imgs[idx] 
  29.         image_path = os.path.join(self.files_dir, img_name) 
  30.  
  31.         # reading the images and converting them to correct size and color 
  32.         img = cv2.imread(image_path) 
  33.         img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) 
  34.         img_res = cv2.resize(img_rgb, (self.width, self.height), cv2.INTER_AREA) 
  35.         # diving by 255 
  36.         img_res /= 255.0 
  37.  
  38.         # annotation file 
  39.         annot_filename = img_name[:-4] + '.xml' 
  40.         annot_file_path = os.path.join(self.files_dir, annot_filename) 
  41.  
  42.         boxes = [] 
  43.         labels = [] 
  44.         tree = et.parse(annot_file_path) 
  45.         root = tree.getroot() 
  46.  
  47.         # cv2 image gives size as height x width 
  48.         wt = img.shape[1] 
  49.         ht = img.shape[0] 
  50.  
  51.         # box coordinates for xml files are extracted and corrected for image size given 
  52.         for member in root.findall('object'): 
  53.             labels.append(self.classes.index(member.find('name').text)) 
  54.  
  55.             # bounding box 
  56.             xmin = int(member.find('bndbox').find('xmin').text) 
  57.             xmax = int(member.find('bndbox').find('xmax').text) 
  58.  
  59.             ymin = int(member.find('bndbox').find('ymin').text) 
  60.             ymax = int(member.find('bndbox').find('ymax').text) 
  61.  
  62.             xmin_corr = (xmin / wt) * self.width 
  63.             xmax_corr = (xmax / wt) * self.width 
  64.             ymin_corr = (ymin / ht) * self.height 
  65.             ymax_corr = (ymax / ht) * self.height 
  66.  
  67.             boxes.append([xmin_corr, ymin_corr, xmax_corr, ymax_corr]) 
  68.  
  69.         # convert boxes into a torch.Tensor 
  70.         boxes = torch.as_tensor(boxes, dtype=torch.float32) 
  71.  
  72.         # getting the areas of the boxes 
  73.         area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) 
  74.  
  75.         # suppose all instances are not crowd 
  76.         iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64) 
  77.  
  78.         labels = torch.as_tensor(labels, dtype=torch.int64) 
  79.  
  80.         target = {} 
  81.         target["boxes"] = boxes 
  82.         target["labels"] = labels 
  83.         target["area"] = area 
  84.         target["iscrowd"] = iscrowd 
  85.         # image_id 
  86.         image_id = torch.tensor([idx]) 
  87.         target["image_id"] = image_id 
  88.  
  89.         if self.transforms: 
  90.             sample = self.transforms(image=img_res, 
  91.                                      bboxes=target['boxes'], 
  92.                                      labels=labels) 
  93.  
  94.             img_res = sample['image'
  95.             target['boxes'] = torch.Tensor(sample['bboxes']) 
  96.         return img_res, target 
  97.     def __len__(self): 
  98.         return len(self.imgs) 
  99.  
  100. def get_transform(train): 
  101.     if train: 
  102.         return A.Compose([ 
  103.             A.HorizontalFlip(0.5), 
  104.             ToTensorV2(p=1.0) 
  105.         ], bbox_params={'format''pascal_voc''label_fields': ['labels']}) 
  106.     else
  107.         return A.Compose([ 
  108.             ToTensorV2(p=1.0) 
  109.         ], bbox_params={'format''pascal_voc''label_fields': ['labels']}) 
  110.  
  111. files_dir = '../input/fruit-images-for-object-detection/train_zip/train' 
  112. test_dir = '../input/fruit-images-for-object-detection/test_zip/test' 
  113.  
  114. dataset = FruitImagesDataset(train_dir, 480, 480) 

 

责任编辑:姜华 来源: Python之王
相关推荐

2023-04-07 07:29:54

Torchvisio计算机视觉

2023-12-18 10:41:28

深度学习NumPyPyTorch

2021-12-13 09:14:06

清单管理数据集

2019-06-19 09:13:29

机器学习中数据集深度学习

2021-09-10 10:26:45

PyTorch数据集S3 Plugin

2010-04-27 13:21:58

Oracle数据字符集

2020-07-15 13:51:48

TensorFlow数据机器学习

2023-12-01 16:23:52

大数据人工智能

2021-09-03 06:46:34

SQL分组集功能

2023-07-28 09:54:14

SQL数据Excel

2021-11-09 08:48:48

Python开源项目

2020-10-27 09:37:43

PyTorchTensorFlow机器学习

2020-06-24 07:53:03

机器学习技术人工智能

2022-09-16 00:11:45

PyTorch神经网络存储

2020-10-05 21:57:17

GitHub 开源开发

2020-04-29 13:40:32

数据集数据科学冠状病毒

2009-08-03 14:39:25

Asp.Net函数集

2010-06-17 09:29:32

SQLServer 2

2020-10-15 11:22:34

PyTorchTensorFlow机器学习

2021-01-14 21:40:40

机器学习计算机视觉图像数据集
点赞
收藏

51CTO技术栈公众号