pytorch中使用Dataset和DataLoader创建自定义数据集 入门 | 您所在的位置:网站首页 › 创建和使用类 › pytorch中使用Dataset和DataLoader创建自定义数据集 入门 |
介绍
pytorch中,我们可以使用torch.utils.data.DataLoader和torch.utils.data.Dataset加载数据集,具体来说,可以简单理解为Dataset是数据集,他提供数据与索引之间的映射,同时也要有标签。而DataLoader是将Dataset中的数据迭代提取出来,从而能够提供给模型。 所以,具体流程是,我们应该先按照要求先建立一个Dataset,之后再建立一个DataLoader,然后就可以用了。 pytorch中有很多现成的数据集,我们下载就可以使用。但是更多时候我们要建立自己的数据集,我也是入门,所以先建立一个带标签的图像数据集。 参考 DATASETS & DATALOADERS两文读懂PyTorch中Dataset与DataLoader(一)打造自己的数据集从0开始撸代码–手把手教你搭建AlexNet网络模型训练自己的数据集(猫狗分类 建立Dataset我们可以继承torch.utils.data.Dataset类,必须要重写__init__, __len__, 和 __getitem__这三个函数。其中 __len__能够返回我们数据集中的数据个数,__getitem__能够根据索引返回数据。 前提我们有一个文件夹,里面有很多猫、狗和汽车的照片,此外有一个csv文件,里面是每张照片对应的类别,也就是标签。我们根据这个照片文件夹和csv文件,来建立我们的带标签数据集。 对于图片文件夹:0——29张图片为猫,30——59张图片为狗,其他为汽车。![]() ![]() 代码中都有注释。 __init__()类的初始化函数,其中img_dir为图片文件夹的根目录,img_label_dir为标签文件路径,transform为对数据项进行的变换。 __len__()返回数据集长度。 __getitem__()根据index,返回其在数据集中对应的数据和标签。 验证通过如下代码,我们具体输出一张图片: # 把图片对应的tensor调整维度,并显示 def tensorToimg(img_tensor): img = img_tensor.numpy() img = np.transpose(img_tensor, [1, 2, 0]) plt.imshow(img) label_dic = {0: 'cat', 1: 'dog', 2: 'car'} label_path = 'D:/curriculum/2022learning/learnning_dataset/labels.csv' img_root_path = 'D:/curriculum/2022learning/learnning_dataset/data/' dataset = myImageDataset(img_root_path, label_path) image, label = dataset.__getitem__(33) print(image.shape) print(label_dic[label]) tensorToimg(image)
之后就可以使用DataLoader对刚刚创建的数据集不断取出样本了。不再赘述。 dataloader = DataLoader(dataset, batch_size=5, shuffle=True)这样,我们就建立了一个dataLoader。接下来我们输出一下看看: for imgs, labels in dataloader: print(imgs.shape) print(labels) break但是这里报错:stack expects each tensor to be equal size, but got [3, 268, 320] at entry 0 and [3, 480, 370] at ...,查询得知是数据集中图片大小不一,而这时Dataset中定义的参数transfom就派上了用场。我们让每张图片的大小都是224*224。 from torch.utils.data import DataLoader from torchvision import transforms transform = transforms.Resize((224, 224)) dataset = myImageDataset(img_root_path, label_path, transform) dataloader = DataLoader(dataset, batch_size=5, shuffle=True) for imgs, labels in dataloader: print(imgs.shape) print(labels) break结果为: torch.Size([5, 3, 224, 224]) tensor([0, 2, 2, 2, 1])由于batch_size是5,而每个图片的形状为[3, 224, 224],因此一个batch的数据形状为:[5, 3, 224, 224]。 其他使用DataLoader的方法 for index, (imgs, labels) in enumerate(dataloader): print(index) print(imgs.shape) print(labels) break结果为: 0 torch.Size([5, 3, 224, 224]) tensor([1, 2, 0, 1, 1]) imgs, label = next(iter(dataloader)) print(imgs.shape) print(labels)结果为: torch.Size([5, 3, 224, 224]) tensor([1, 2, 0, 1, 1])得到了一批的图片和对应的标签,我们就能将其输入到模型中,并使用标签和预测结果计算损失。 |
CopyRight 2018-2019 实验室设备网 版权所有 |