pytorch中使用Dataset和DataLoader创建自定义数据集 入门 您所在的位置:网站首页 创建和使用类 pytorch中使用Dataset和DataLoader创建自定义数据集 入门

pytorch中使用Dataset和DataLoader创建自定义数据集 入门

2024-07-11 21:15| 来源: 网络整理| 查看: 265

介绍

pytorch中,我们可以使用torch.utils.data.DataLoader和torch.utils.data.Dataset加载数据集,具体来说,可以简单理解为Dataset是数据集,他提供数据与索引之间的映射,同时也要有标签。而DataLoader是将Dataset中的数据迭代提取出来,从而能够提供给模型。 所以,具体流程是,我们应该先按照要求先建立一个Dataset,之后再建立一个DataLoader,然后就可以用了。 pytorch中有很多现成的数据集,我们下载就可以使用。但是更多时候我们要建立自己的数据集,我也是入门,所以先建立一个带标签的图像数据集。

参考 DATASETS & DATALOADERS两文读懂PyTorch中Dataset与DataLoader(一)打造自己的数据集从0开始撸代码–手把手教你搭建AlexNet网络模型训练自己的数据集(猫狗分类 建立Dataset

我们可以继承torch.utils.data.Dataset类,必须要重写__init__, __len__, 和 __getitem__这三个函数。其中 __len__能够返回我们数据集中的数据个数,__getitem__能够根据索引返回数据。

前提

我们有一个文件夹,里面有很多猫、狗和汽车的照片,此外有一个csv文件,里面是每张照片对应的类别,也就是标签。我们根据这个照片文件夹和csv文件,来建立我们的带标签数据集。

对于图片文件夹:0——29张图片为猫,30——59张图片为狗,其他为汽车。 图片文件夹对于标签csv文件,每一行中首先是图片名,然后是类别。其中0代表猫,1代表狗,2代表汽车。如下图: 标签文件 具体代码 import os from torchvision.io import read_image import pandas as pd from torch.utils.data import Dataset import matplotlib.pyplot as plt import numpy as np class myImageDataset(Dataset): def __init__(self, img_dir, img_label_dir, transform=None): super().__init__() self.img_dir = img_dir self.img_labels = pd.read_csv(img_label_dir) # 这是一个dataframe,0是文件名,1是类别 self.transform = transform def __len__(self): return len(self.img_labels) # 数据集长度 def __getitem__(self, index): # 拼接得到图片文件路径 # 例如img_dir为'D:/curriculum/2022learning/learnning_dataset/data/' # img_labels.iloc[index, 0]为5.jpg # 那么img_path为'D:/curriculum/2022learning/learnning_dataset/data/5.jpg' img_path = os.path.join(self.img_dir + self.img_labels.iloc[index, 0]) image = read_image(img_path) # tensor类型 label = self.img_labels.iloc[index, 1] if self.transform is not None: image = self.transform(image) # 对图片进行某些变换 return image, label

代码中都有注释。

__init__()

类的初始化函数,其中img_dir为图片文件夹的根目录,img_label_dir为标签文件路径,transform为对数据项进行的变换。

__len__()

返回数据集长度。

__getitem__()

根据index,返回其在数据集中对应的数据和标签。

验证

通过如下代码,我们具体输出一张图片:

# 把图片对应的tensor调整维度,并显示 def tensorToimg(img_tensor): img = img_tensor.numpy() img = np.transpose(img_tensor, [1, 2, 0]) plt.imshow(img) label_dic = {0: 'cat', 1: 'dog', 2: 'car'} label_path = 'D:/curriculum/2022learning/learnning_dataset/labels.csv' img_root_path = 'D:/curriculum/2022learning/learnning_dataset/data/' dataset = myImageDataset(img_root_path, label_path) image, label = dataset.__getitem__(33) print(image.shape) print(label_dic[label]) tensorToimg(image)

结果 可以看到,数据集中,图片变为tensor,维度为[通道数,长,宽]。

DataLoader

之后就可以使用DataLoader对刚刚创建的数据集不断取出样本了。不再赘述。

dataloader = DataLoader(dataset, batch_size=5, shuffle=True)

这样,我们就建立了一个dataLoader。接下来我们输出一下看看:

for imgs, labels in dataloader: print(imgs.shape) print(labels) break

但是这里报错:stack expects each tensor to be equal size, but got [3, 268, 320] at entry 0 and [3, 480, 370] at ...,查询得知是数据集中图片大小不一,而这时Dataset中定义的参数transfom就派上了用场。我们让每张图片的大小都是224*224。

from torch.utils.data import DataLoader from torchvision import transforms transform = transforms.Resize((224, 224)) dataset = myImageDataset(img_root_path, label_path, transform) dataloader = DataLoader(dataset, batch_size=5, shuffle=True) for imgs, labels in dataloader: print(imgs.shape) print(labels) break

结果为:

torch.Size([5, 3, 224, 224]) tensor([0, 2, 2, 2, 1])

由于batch_size是5,而每个图片的形状为[3, 224, 224],因此一个batch的数据形状为:[5, 3, 224, 224]。

其他使用DataLoader的方法 for index, (imgs, labels) in enumerate(dataloader): print(index) print(imgs.shape) print(labels) break

结果为:

0 torch.Size([5, 3, 224, 224]) tensor([1, 2, 0, 1, 1]) imgs, label = next(iter(dataloader)) print(imgs.shape) print(labels)

结果为:

torch.Size([5, 3, 224, 224]) tensor([1, 2, 0, 1, 1])

得到了一批的图片和对应的标签,我们就能将其输入到模型中,并使用标签和预测结果计算损失。



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有