您现在的位置是:网站首页> 编程资料编程资料

PyTorch中torch.utils.data.Dataset的介绍与实战_python_

2023-05-26 840人已围观

简介 PyTorch中torch.utils.data.Dataset的介绍与实战_python_

一、前言

训练模型一般都是先处理 数据的输入问题 和 预处理问题 。Pytorch提供了几个有用的工具:torch.utils.data.Dataset 类和 torch.utils.data.DataLoader 类 。

流程是先把原始数据转变成 torch.utils.data.Dataset 类,随后再把得到的 torch.utils.data.Dataset 类当作一个参数传递给 torch.utils.data.DataLoader 类,得到一个数据加载器,这个数据加载器每次可以返回一个 Batch 的数据供模型训练使用。

在 pytorch 中,提供了一种十分方便的数据读取机制,即使用 torch.utils.data.Dataset 与 Dataloader 组合得到数据迭代器。在每次训练时,利用这个迭代器输出每一个 batch 数据,并能在输出时对数据进行相应的预处理或数据增广操作。

本文我们主要介绍对 torch.utils.data.Dataset 的理解,对 Dataloader 的介绍请参考我的另一篇文章:【PyTorch】torch.utils.data.DataLoader 简单介绍与使用

在本文的最后将给出 torch.utils.data.Dataset 与 Dataloader 结合使用处理数据的实战代码。

二、torch.utils.data.Dataset 是什么

1. 干什么用的?

  1. pytorch 提供了一个数据读取的方法,其由两个类构成:torch.utils.data.Dataset 和 DataLoader。
  2. 如果我们要自定义自己读取数据的方法,就需要继承类 torch.utils.data.Dataset ,并将其封装到DataLoader 中。
  3. torch.utils.data.Dataset 是一个 类 Dataset 。通过重写定义在该类上的方法,我们可以实现多种数据读取及数据预处理方式。

2. 长什么样子?

torch.utils.data.Dataset 的源码:

class Dataset(object): """An abstract class representing a Dataset. All other datasets should subclass it. All subclasses should override ``__len__``, that provides the size of the dataset, and ``__getitem__``, supporting integer indexing in range from 0 to len(self) exclusive. """ def __getitem__(self, index): raise NotImplementedError def __len__(self): raise NotImplementedError def __add__(self, other): return ConcatDataset([self, other]) 

注释翻译:

表示一个数据集的抽象类。

所有其他数据集都应该对其进行子类化。 所有子类都应该重写提供数据集大小的 __len__ 和 __getitem__ ,支持从 0 到 len(self) 独占的整数索引。

理解:

就是说,Dataset 是一个 数据集 抽象类,它是其他所有数据集类的父类(所有其他数据集类都应该继承它),继承时需要重写方法 __len__ 和 __getitem__ , __len__ 是提供数据集大小的方法, __getitem__ 是可以通过索引号找到数据的方法。

三、通过继承 torch.utils.data.Dataset 定义自己的数据集类

torch.utils.data.Dataset 是代表自定义数据集的抽象类,我们可以定义自己的数据类抽象这个类,只需要重写__len__和__getitem__这两个方法就可以。

要自定义自己的 Dataset 类,至少要重载两个方法:__len__, __getitem__

  1. __len__返回的是数据集的大小
  2. __getitem__实现索引数据集中的某一个数据

下面将简单实现一个返回 torch.Tensor 类型的数据集:

from torch.utils.data import Dataset import torch class TensorDataset(Dataset): # TensorDataset继承Dataset, 重载了__init__, __getitem__, __len__ # 实现将一组Tensor数据对封装成Tensor数据集 # 能够通过index得到数据集的数据,能够通过len,得到数据集大小 def __init__(self, data_tensor, target_tensor): self.data_tensor = data_tensor self.target_tensor = target_tensor def __getitem__(self, index): return self.data_tensor[index], self.target_tensor[index] def __len__(self): return self.data_tensor.size(0) # size(0) 返回当前张量维数的第一维 # 生成数据 data_tensor = torch.randn(4, 3) # 4 行 3 列,服从正态分布的张量 print(data_tensor) target_tensor = torch.rand(4) # 4 个元素,服从均匀分布的张量 print(target_tensor) # 将数据封装成 Dataset (用 TensorDataset 类) tensor_dataset = TensorDataset(data_tensor, target_tensor) # 可使用索引调用数据 print('tensor_data[0]: ', tensor_dataset[0]) # 可返回数据len print('len os tensor_dataset: ', len(tensor_dataset)) 

输出结果:

tensor([[ 0.8618,  0.4644, -0.5929],
        [ 0.9566, -0.9067,  1.5781],
        [ 0.3943, -0.7775,  2.0366],
        [-1.2570, -0.3859, -0.3542]])
tensor([0.1363, 0.6545, 0.4345, 0.9928])
tensor_data[0]:  (tensor([ 0.8618,  0.4644, -0.5929]), tensor(0.1363))
len os tensor_dataset:  4

四、为什么要定义自己的数据集类?

因为我们可以通过定义自己的数据集类并重写该类上的方法 实现多种多样的(自定义的)数据读取方式。

比如,我们重写 __init__ 实现用 pd.read_csv 读取 csv 文件:

from torch.utils.data import Dataset import pandas as pd # 这个包用来读取CSV数据 # 继承Dataset,定义自己的数据集类 mydataset class mydataset(Dataset): def __init__(self, csv_file): # self 参数必须,其他参数及其形式随程序需要而不同,比如(self,*inputs) self.csv_data = pd.read_csv(csv_file) def __len__(self): return len(self.csv_data) def __getitem__(self, idx): data = self.csv_data.values[idx] return data data = mydataset('spambase.csv') print(data[3]) print(len(data)) 

输出结果:

[0.000e+00 0.000e+00 0.000e+00 0.000e+00 6.300e-01 0.000e+00 3.100e-01
 6.300e-01 3.100e-01 6.300e-01 3.100e-01 3.100e-01 3.100e-01 0.000e+00
 0.000e+00 3.100e-01 0.000e+00 0.000e+00 3.180e+00 0.000e+00 3.100e-01
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 1.370e-01 0.000e+00 1.370e-01 0.000e+00 0.000e+00 3.537e+00 4.000e+01
 1.910e+02 1.000e+00]
4601

要点:

  1. 自己定义的 dataset 类需要继承 Dataset。
  2. 需要实现必要的魔法方法:

在 __init__ 方法里面进行 读取数据文件 。

在 __getitem__ 方法里支持通过下标访问数据。

在 __len__ 方法里返回自定义数据集的大小,方便后期遍历。

五、实战:torch.utils.data.Dataset + Dataloader 实现数据集读取和迭代

实例 1

数据集 spambase.csv 用的是 UCI 机器学习存储库里的垃圾邮件数据集,它一条数据有57个特征和1个标签。

import torch.utils.data as Data import pandas as pd # 这个包用来读取CSV数据 import torch # 继承Dataset,定义自己的数据集类 mydataset class mydataset(Data.Dataset): def __init__(self, csv_file): # self 参数必须,其他参数及其形式随程序需要而不同,比如(self,*inputs) data_csv = pd.DataFrame(pd.read_csv(csv_file)) # 读数据 self.csv_data = data_csv.drop(axis=1, columns='58', inplace=False) # 删除最后一列标签 def __len__(self): return len(self.csv_data) def __getitem__(self, idx): data = self.csv_data.values[idx] return data data = mydataset('spambase.csv') x = torch.tensor(data[:5]) # 前五个数据 y = torch.tensor([1, 1, 1, 1, 1]) # 标签 torch_dataset = Data.TensorDataset(x, y) # 对给定的 tensor 数据,将他们包装成 dataset loader = Data.DataLoader( # 从数据库中每次抽出batch size个样本 dataset = torch_dataset, # torch TensorDataset format batch_size = 2, # mini batch size shuffle=True, # 要不要打乱数据 (打乱比较好) num_workers=2, # 多线程来读数据 ) def show_batch(): for step, (batch_x, batch_y) in enumerate(loader): print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y)) show_batch() 

输出结果:

steop:0, batch_x:tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 6.3000e-01, 0.0000e+00,
         3.1000e-01, 6.3000e-01, 3.1000e-01, 6.3000e-01, 3.1000e-01, 3.1000e-01,
         3.1000e-01, 0.0000e+00, 0.0000e+00, 3.1000e-01, 0.0000e+00, 0.0000e+00,
         3.1800e+00, 0.0000e+00, 3.1000e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 1.3500e-01, 0.0000e+00, 1.3500e-01, 0.0000e+00, 0.0000e+00,
         3.5370e+00, 4.0000e+01, 1.9100e+02],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 6.3000e-01, 0.0000e+00,
         3.1000e-01, 6.3000e-01, 3.1000e-01, 6.3000e-01, 3.1000e-01, 3.1000e-01,
         3.1000e-01, 0.0000e+00, 0.0000e+00, 3.1000e-01, 0.0000e+00, 0.0000e+00,
         3.1800e+00, 0.0000e+00, 3.1000e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 1.3700e-01, 0.0000e+00, 1.3700e-01, 0.0000e+00, 0.0000e+00,
         3.5370e+00, 4.0000e+01, 1.9100e+02]], dtype=torch.float64), batch_y:tensor([1, 1])
steop:1, batch_x:tensor([[2.1000e-01, 2.8000e-01, 5.0000e-01, 0.0000e+00, 1.4000e-01, 2.8000e-01,
         2.1000e-01, 7.0000e-02, 0.0000e+00, 9.4000e-01, 2.1000e-01, 7.9000e-01,
         6.5000e-01, 2.1000e-01, 1.4000e-01, 1.4000e-01, 7.0000e-02, 2.8000e-01,
         3.4700e+00, 0.0000e+00, 1.5900e+00, 0.0000e+00, 4.3000e-01, 4.3000e-01,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         7.0000e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 1.3200e-01, 0.0000e+00, 3.7200e-01, 1.8000e-01, 4.8000e-02,
         5.1140e+00, 1.0100e+02, 1.0280e+03],
        [6.0000e-02, 0.0000e+00, 7.1000e-01, 0.0000e+00, 1.2300e+00, 1.9000e-01,
         1.9000e-01, 1.2000e-01, 6.4000e-01, 2.5000e-01, 3.8000e-01, 4.5000e-01,
         1.2000e-01, 0.0000e+00, 1.7500e+00, 6.0000e-02, 6.0000e-02,

-六神源码网