在 PyTorch 中使用 Datasets 和 DataLoader 自定义数据

人工智能 深度学习
有时候,在处理大数据集时,一次将整个数据加载到内存中变得非常难。唯一的方法是将数据分批加载到内存中进行处理,这需要编写额外的代码来执行此操作。对此,PyTorch 已经提供了 Dataloader 功能。

有时候,在处理大数据集时,一次将整个数据加载到内存中变得非常难。

因此,唯一的方法是将数据分批加载到内存中进行处理,这需要编写额外的代码来执行此操作。对此,PyTorch 已经提供了 Dataloader 功能。

DataLoader

下面显示了 PyTorch 库中DataLoader函数的语法及其参数信息。

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, 
           batch_sampler=None, num_workers=0, collate_fn=None, 
           pin_memory=False, drop_last=False, timeout=0, 
           worker_init_fn=None, *, prefetch_factor=2, 
           persistent_workers=False
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.

几个重要参数

  • dataset:必须首先使用数据集构造 DataLoader 类。
  • Shuffle :是否重新整理数据。
  • Sampler :指的是可选的 torch.utils.data.Sampler 类实例。采样器定义了检索样本的策略,顺序或随机或任何其他方式。使用采样器时应将 Shuffle 设置为 false。
  • Batch_Sampler :批处理级别。
  • num_workers :加载数据所需的子进程数。
  • collate_fn :将样本整理成批次。Torch 中可以进行自定义整理。

加载内置 MNIST 数据集

MNIST 是一个著名的包含手写数字的数据集。下面介绍如何使用DataLoader功能处理 PyTorch 的内置 MNIST 数据集。

import torch 
import matplotlib.pyplot as plt 
from torchvision import datasets, transforms 
  • 1.
  • 2.
  • 3.

上面代码,导入了 torchvision 的torch计算机视觉模块。通常在处理图像数据集时使用,并且可以帮助对图像进行规范化、调整大小和裁剪。

对于 MNIST 数据集,下面使用了归一化技术。

ToTensor()能够把灰度范围从0-255变换到0-1之间。

transform = transforms.Compose([transforms.ToTensor()]) 
  • 1.

下面代码用于加载所需的数据集。使用 PyTorchDataLoader通过给定 batch_size = 64来加载数据。shuffle=True打乱数据。

trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform) 
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True
  • 1.
  • 2.

为了获取数据集的所有图像,一般使用iter函数和数据加载器DataLoader。

dataiter = iter(trainloader) 
images, labels = dataiter.next() 
print(images.shape) 
print(labels.shape) 
plt.imshow(images[1].numpy().squeeze(), cmap='Greys_r'
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.

自定义数据集

下面的代码创建一个包含 1000 个随机数的自定义数据集。

from torch.utils.data import Dataset 
import random 
  
class SampleDataset(Dataset): 
  def __init__(self,r1,r2): 
      randomlist=[] 
      for i in range(120): 
          n = random.randint(r1,r2) 
          randomlist.append(n) 
      self.samples=randomlist  
  
  def __len__(self): 
      return len(self.samples) 
  
  def __getitem__(self,idx): 
      return(self.samples[idx]) 
  
dataset=SampleDataset(1,100) 
dataset[100:120] 
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.

 

在这里插入图片描述

最后,将在自定义数据集上使用 dataloader 函数。将 batch_size 设为 12,并且还启用了num_workers =2 的并行多进程数据加载。

from torch.utils.data import DataLoader 
loader = DataLoader(dataset,batch_size=12, shuffle=True, num_workers=2 ) 
for i, batch in enumerate(loader): 
    print(i, batch) 
  • 1.
  • 2.
  • 3.
  • 4.

 写在后面通过几个示例了解了 PyTorch Dataloader 在将大量数据批量加载到内存中的作用。

 

责任编辑:姜华 来源: Python之王
相关推荐

2022-11-23 15:26:25

Ubuntu程序坞

2009-06-23 11:35:44

JSF的Naviati

2009-11-10 17:12:22

VB.NET自定义类型

2023-11-14 10:05:52

Java开发工具

2023-09-04 15:06:18

Pytorch静态量化动态量化

2023-09-12 13:59:41

OpenAI数据集

2011-06-15 09:24:36

Qt Widget Model

2022-01-14 09:17:13

PythonAPISIX插件

2019-12-25 11:47:27

LinuxFVWM

2010-10-25 16:05:07

oracle自定义函数

2011-06-20 16:54:40

Qt Widget model

2023-12-29 08:01:52

自定义指标模板

2020-07-25 16:33:02

tmuxGitLinux终端

2022-11-29 08:07:23

CSSJavaScript自定义

2017-01-11 10:27:36

Linux终端自定义Bash

2021-12-02 18:05:21

Android Interpolato动画

2022-09-13 15:44:52

VSLook插件

2021-10-28 08:39:22

Node Export自定义 监控

2015-02-12 15:33:43

微信SDK

2020-10-05 21:57:17

GitHub 开源开发
点赞
收藏

51CTO技术栈公众号