详解PyTorch数据读取机制Dataloader与Dataset 您所在的位置:网站首页 硝化池是好氧池吗为什么 详解PyTorch数据读取机制Dataloader与Dataset

详解PyTorch数据读取机制Dataloader与Dataset

2023-05-07 03:23| 来源: 网络整理| 查看: 265

来源:投稿 作者:阿克西

编辑:学姐

本章主要讲述数据模块,如何从硬盘中读取数据,对数据进行预处理、数据增强,转换为张量的形式输入到模型之中。

1.模块简介

本节主要学习数据模块当中的数据读取,数据模块通常还会分为四个子模块,数据收集、数据划分、数据读取、数据预处理。

● 数据收集:收集原始样本和标签,如Img和Label。

● 数据划分:划分成训练集train,用来训练模型;验证集valid,验证模型是否过拟合,挑选还没有过拟合的时候的模型;测试集test,测试挑选出来的模型的性能。

● 数据读取:PyTorch中数据读取的核心是Dataloader。Dataloader分为Sampler和DataSet两个子模块。Sampler的功能是生成索引,即样本序号;DataSet的功能是根据索引读取样本和对应的标签。

● 数据预处理:数据的中心和,标准化,旋转,翻转等,在PyTorch中是通过transforms实现的。

这里主要学习第三个子模块中的Dataloader和Dataset。

2.DataLoader与Dataset

DataLoader和Dataset是pytorch中数据读取的核心。

2.1 DataLoader

功能:构建可迭代的数据装载器,每一次for循环就是从DataLoader中加载一个batchsize数据。

● dataset:Dataset类,决定数据从哪读取及如何读取

● batchsize:批大小

● num_works:是否多进程读取数据

● shuffle:每个epoch是否乱序

● drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据

概念辨析:

● epoch:所有训练样本都已输入到模型中,称为一个epoch

● iteration:一批样本输入到模型中,称之为一个iteration

● batchsize:批大小,决定一个epoch中有多少个iteration

样本总数:80,batchsize:8 (样本能被batchsize整除)

● 1(epoch) = 10(iteration)

样本总数:87,batchsize=8 (样本不能被batchsize整除)

● drop_last = True:1(epoch) = 10(iteration)

● drop_last = False:1(epoch)= 11(iteration)

2.2 Dataset

功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且必须复写__getitem__()。

● Dataset:用来定义数据从哪里读取,以及如何读取的问题

● getitem:接收一个索引,返回一个样本

3.人民币二分类

要求:对第四套人民币1元和10元进行二分类,将人民币看作自变量x,类别看作因变量y,模型就是将自变量x映射到因变量y。

下面对人民币二分类的数据进行读取,从三个方面了解pytorch的读取机制,分别为读哪些数据、从哪读数据、怎么读数据。

读哪些数据:在每一个iteration的时候应该读取哪些数据,每一个iteration读取一个batch大小的数据,假如有80个样本,那么从80个样本中读取8个样本,那么应该读取哪8个样本。

从哪读数据 :在硬盘当中,我们应该怎么找到对应的数据,在哪里设置参数。

怎么读数据 :从代码中学习。

3.1 数据集划分

划分好的数据集:

3.2 人民币分类模型训练3.2.1 导入包与参数设置3.2.2 🔥Dataset类参数一:从哪读数据,设置硬盘中的路径3.2.3 🔥Dataset类参数二:数据预处理transform

transforms.Compose将一系列数据增强方法进行有序的组合,依次按照顺序对图像进行处理。

3.2.4 🔥构建Dataset实例

Dataset必须是用户自己构建的,在Dataset中会传入两个主要参数,一个是data_dir,表示数据集的路径,即从哪读数据;第二个参数是transform,表示数据预处理。代码中构建了两个Dataset实例,一个用于训练,一个用于验证。

3.2.5 🔥构建DataLoader实例

有了Dataset就可以构建数据迭代器DataLoader,DataLoader传入的第一个参数是Dataset,也就是RMBDataset实例;第二个参数是batch_size;在训练集中的多了一个参数shuffle=True,作用是每一个epoch中样本都是乱序的。

3.2.6 模型、损失函数、优化器3.2.7 开始训练

设置好数据、模型、损失函数和优化器之后,就可以进行模型的训练。

模型训练以epoch为周期,代码中先进行epoch的主循环,在每一个epoch当中会有多个iteration的训练,在每一个iteration当中去训练模型,每一次读取一个batch_size大小的数据,然后输入到模型中,进行前向传播,反向传播获取梯度,更新权值,接着统计分类准确率,打印训练信息。在每一个epoch会进行验证集的测试,通过验证集来观察模型是否过拟合。

输出结果

3.3 RMBDataset类3.4 断点调试

现在了解一下上面代码中RMBDataset中的具体实现。

pycharm小技巧:按住Ctrl,然后单击函数名或者类名就可以跳转到具体函数实现的位置。

在训练模型时,数据的获取是通过for循环获取的,从DataLoader迭代器中不停地去获取一个batchsize大小的数据。

1、下面通过代码的调试观察pytorch是如何读取数据的,在该处设置断点,然后执行debug。

点击step into功能键,跳转到对应的函数中,发现是跳到了dataloader.py文件中的__iter__()函数;具体如下所示:

这段代码是一个if的判断语句,其功能是判断是否采用多进程;如果采用多进程,有多进程的读取机制;如果是单进程,有单进程的读取机制;这里以单进程进行演示。

2、单击两次step into功能键

单进程当中,最主要的是__next__()函数,在next中会获取index和data,回想一下数据读取中的三个问题,第一个问题是读哪些数据;__next__()函数就告诉我们,在每一个iteration当中读取哪些数据。

现在将光标对准_next_data函数中的第一行index=self._next_index(),点击功能区中的run to cursor,然后程序就会运行到这一行,点击功能区中的step into,进入到_next_index()函数中了解是怎么获得数据的index的;之后代码会跳到下面的代码中:

再点击一下step into就进入了sampler.py文件中,sampler是一个采样器,其功能是告诉我们每一个batch_size应该读取哪些数据;

点击两次step out功能键

点击step over功能键,执行上面这段代码中的:

就可以挑选出一个Iteration中的index,batch_size的值是16,则index列表长度为16:

有了index之后,将index输入到Dataset当中去获取data,代码中会进入dataset_fetcher.fetch()函数。

3、点击功能区中的step_into,进入到fetch.py文件的_MapDatasetFetcher()类当中,在这个类里面实现了具体的数据读取,具体代码如下。代码中调用了dataset,通过输入一个索引idx返回一个data,将一系列的data拼接成一个list。

点击step into查看一下这个过程,代码跳转到自定义dataset类RMBdataset()中的__getitem__()函数中,所以dataset最重要最核心的就是__getitem__()函数;

这里已经实现了data_info()函数,对数据进行初步的读取,可以得到图片的路径和标签;然后通过Image.open来读取数据,这就实现了一个数据的读取,标签的获取。

之后点击step_out跳出该函数,会返回fetch()函数中;

在fetch()函数return的时候会进入一个collate_fn(),它是数据的整理器,会将我们读取到的16个数据整理出一个batch的形式;得到数据和标签。

将光标放在return self.collate_fn(data) 处,点击run to cursor执行到当前位置,之后点击step over返回到单进程,点击step over,执行到下述代码,发现data已被打包,第一个元素是图像,第二个元素是标签。

点击多次step out返回到最初训练模型读取数据的位置,执行step over可以发现循环中的data已被打包,第一个元素是图像,第二个元素是标签。

3.5 总结

通过以上的分析,可以回答一开始提出的数据读取的三个问题:

1、读哪些数据?

答:从代码中可以发现,index是从sampler.py中输出的,所以读哪些数据是由sampler得到的;

2、从哪读数据?

答:从代码中看,是从Dataset中的参数data_dir告诉我们pytorch是从硬盘中的哪一个文件夹获取数据。

3、怎么读数据?

答:从代码中可以发现,pytorch是从Dataset的getitem()中具体实现的,根据索引去读取数据。

Dataloader读取数据很复杂,需要经过四五个函数的跳转才能最终读取数据

为了简单,将整个跳转过程以流程图进行表示,通过流程图对数据读取机制有一个简单的认识。

简单描述一下流程图:

首先在for循环中去使用DataLoader;

进入DataLoader之后是否采用多进程进入单进程或者多进程的DataLoaderlter;

进入DataLoaderIter之后会使用sampler去获取Index;

拿到索引之后传输到DatasetFetcher;

在DatasetFetcher中会调用Dataset,Dataset根据给定的Index,在getitem中从硬盘里面去读取实际的Img和Label;

读取了一个batch_size的数据之后,通过一个collate_fn将数据进行整理;

整理成batch_Data的形式,接着就可以输入到模型中训练。

读哪些是由Sampler决定的index,从哪读是由Dataset决定的,怎么读是由getitem决定的。

关注【学姐带你玩AI】公众号

回复“500”免费领取200多篇精选AI必读论文!



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有