[ Pytorch ] 基本使用丨1. 数据集准备与导入 + 图片预处理丨 |
您所在的位置:网站首页 › 阿里巴巴数据包怎么导出来的图片 › [ Pytorch ] 基本使用丨1. 数据集准备与导入 + 图片预处理丨 |
目录 一、数据集准备与导入方法。 1、建立文件夹方式。(原始图片数据分类任务首选) 2、建立tensorDataset方式(文件格式数据) 3、建立dataset类 4、多个数据集作为训练数据的导入方法。 二、预处理 1、使用torchvision中的transforms对自建数据进行预处理。 三、DataLoader使用方法 四、Sampler使用方法 一、数据集准备与导入方法。 1、建立文件夹方式。(原始图片数据分类任务首选)http://www.bubuko.com/infodetail-2304938.html PyTorch源码解读之torch.utils.data.DataLoader_AI之路-CSDN博客 2、建立tensorDataset方式(文件格式数据)[莫烦 PyTorch 系列教程] 3.5 – 数据读取 (Data Loader)-PyTorch 中文网 1)、不像上面1中不用管标签的命名,用2方法需要对标签进行标准化,不然在后面训练时会出现下面错误:pytorch出现cannot get repr的错误_qq_27292549的博客-CSDN博客 def label_normal(train_labels): # 转换一下标签y_train中的标签名称:(2, 7, ... , 1500) ——> (0, 1, 2 , 3 ..., 750 ); train_labels_temp = [] for m1_index in train_labels: if not m1_index in train_labels_temp: train_labels_temp.append(m1_index) print('train_labels_temp', train_labels_temp) ki = 0 new_train_labels = [] for train_index in train_labels: if train_index == train_labels_temp[ki]: new_train_labels.append(ki) else: ki = ki + 1 new_train_labels.append(ki) print('new_y_train', new_train_labels) train_labels = new_train_labels print('\n') print('y_train_trans:', train_labels) print('\n') # 转换一下标签名称[结束] train_labels = np.array(train_labels) return train_labels result_feature = scipy.io.loadmat('./evaluation/features/market_ResNet/market_ResNet_result.mat') result_softout = scipy.io.loadmat('./evaluation/features/view_Branch_softmaxout/Market_softmaxout.mat') # target label --train and val train_labels = result_feature['train_label'][0] train_labels = label_normal(train_labels) # 关键 train_labels = torch.LongTensor(train_labels) train_labels = train_labels.view(-1,1) # reshape to fit the val_labels = result_feature['val_label'][0] val_labels = label_normal(val_labels) # 关键 val_labels = torch.LongTensor(val_labels) val_labels = val_labels.view(-1,1) # reshape to fit the image_datasets = {} image_datasets['train'] = torch.utils.data.TensorDataset(r_f_train, train_yaw_back, train_labels) image_datasets['val'] = torch.utils.data.TensorDataset(r_f_val, val_yaw_back, val_labels) dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize, shuffle=True, num_workers=16) for x in ['train', 'val']} dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']} 3、建立dataset类参考:https://blog.csdn.net/lqp888888/article/details/80481456 import torch.utils.data as data class MyDataset(data.Dataset): def __init__(self, data, labels): self.data= data self.labels = labels def __getitem__(self, index): img, target = self.data[index], self.labels[index] return img, target def __len__(self): return len(self.data) 4、多个数据集作为训练数据的导入方法。 import copy class DATASET(object): def __init__(self,data=[]): self.data = data def __getitem__(self, item): return self.data[item] ************************ * 关键函数 def __add__(self, other): """Adds two datasets together (only the train set).""" train = copy.deepcopy(self.data) for num in other.data: train.append(num) return DATASET(data=train) def __radd__(self, other): """Supports sum([dataset1, dataset2, dataset3]).""" if other == 0: return self else: return self.__add__(other) data1 = [1,2,3,4,5] data2 = [6,7,8,9,10] dset1 = DATASET(data=data1) dset2 = DATASET(data=data2) dlist = [] dlist.append(dset1) dlist.append(dset2) dsetsum = sum(dlist) for num in dsetsum: print(num)上面的关键函数是python内置函数def __add__(self, other):和def __radd__(self, other):。 二、预处理 1、使用torchvision中的transforms对自建数据进行预处理。参考:torchvision.datasets.folder — PyTorch master documentation pytorch加载数据与预处理数据 - pytorch中文网 import torch.utils.data as data class MyDataset(data.Dataset): def __init__(self, data, labels, transforms = None): self.data= data self.labels = labels self.transforms = transforms def __getitem__(self, index): img, target = self.data[index], self.labels[index] img = self.transforms(img) # 关键位置 return img, target def __len__(self): return len(self.data)注意:tranforms操作之后,每一次epoch的训练中,原始数据量不会变,每个epoch中的数据仅仅是通过一个transform操作。 2、transform.function下的图像变换工具:link。 -- 几点备忘 -------- 1、transform的输入输出格式。 import numpy as np import torch from PIL import Image import torchvision.transforms as transforms img = Image.open('/disk2/yu.jpg') transforms_ = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) img = transforms_(img) print(img.__class__) img = np.array(img) img = Image.fromarray(img,mode='RGB') # 去掉这个会报错 transforms_2 = transforms.Compose([ transforms.Resize([97, 245], Iage.BICUBIC), transforms.ToTensor()]) img = transforms_2(img) 三、DataLoader使用方法torch.utils.data.DataLoader源代码、文档。 1、collate_fn方法的使用。collate_fn可以设置采用数据集数据成为mini_batch的采样方式。 四、Sampler使用方法pytorch源码阅读(三)Sampler类与4种采样方式 - 知乎 1. 关键点一: Sampler的 __iter__(self) 与 Dataset __getitem__(self, index) 中的 index 相联系。 import torch class MySampler(Sampler): r"""Samples elements sequentially, always in the same order. Arguments: data_source (Dataset): dataset to sample from """ def __init__(self, data_source): self.data_source = data_source def __iter__(self): ################################################################ ### 注意这里,这里是Sampler的关键。是Sampler与DataSet的联系点。 ### ################################################################ ### 这个__iter__(self)方法决定了“一、Dataset”中的__getitem__(self, index)的“index”的调用顺序 ### ### 1. 比如下面这个是按顺序迭代index,也就是说,index是 1,2,3,...这么按顺序迭代出去的。 return iter(range(len(self.data_source))) ### 2. 下面这个是随机吐出index。 return iter(torch.randperm(n).tolist()) def __len__(self): return len(self.data_source)
|
今日新闻 |
点击排行 |
|
推荐新闻 |
图片新闻 |
|
专题文章 |
CopyRight 2018-2019 实验室设备网 版权所有 win10的实时保护怎么永久关闭 |