pytorch读取文件夹内所有图像 pytorch 批量读取图片 |
您所在的位置:网站首页 › rolexAUTOMATIC的所有图片 › pytorch读取文件夹内所有图像 pytorch 批量读取图片 |
pytorch 在载入数据时用torchvision.datasets.ImageFolder 配合 torch.utils.data.DataLoader 很方便,但是只能遍历图片和图片的标签,无法灵活的获取图片的其他信息,比如图片的名字,本文介绍如何定义自己的 ImageFolder,在使用 Dataloader 时实现获取图片名字的功能! 文章目录1 ImageFolder and DataLoader2 OwnFolder and DataLoader3 transforms 1 ImageFolder and DataLoader 以分类为例,用 pytorch 的 torchvision.datasets.ImageFolder 配合 torch.utils.data.DataLoader 即可对数据按类别进行读取、预处理、分成 batch import torchvision import torch train_dataset = torchvision.datasets.ImageFolder( train_data_pth, transforms.Compose([ transforms.Resize(input_size,interpolation=2), # resize transforms.ToTensor(), # ToTensor normalize,])) # Normalization train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=batch_size, # set batchsize shuffle=False, num_workers=n_worker, pin_memory=True)参考 pytorch训练自己图像分类数据集ImageFolder 中 train_data_pth 是存放数据集的文件夹,文件结构应该如下 train_data_pth class1 xxx.jpg ... class2 xxx.jpg ... ... classn xxx.jpg ...Dataloader 的参数介绍如下 dataset:加载的数据集(Dataset对象)batch_size:batch sizeshuffle:是否将数据打乱sampler: 样本抽样,后续会详细介绍num_workers:使用多进程加载的进程数,0 代表不使用多进程collate_fn: 如何将多个样本数据拼接成一个 batch,一般使用默认的拼接方式即可’pin_memory:是否将数据保存在pin memory 区,pin memory 中的数据转到 GPU 会快一些drop_last:dataset中的数据个数可能不是 batch_size 的整数倍,drop_last 为 True 会将多出来不足一个batch的数据丢弃官网中 Dataloader 的介绍如下(https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) 强调下,num_workers 表示一次可以装载 num_workers 个 batch,而不是一次装载一个 batch。 在训练和测试时,可用如下循环来对数据进行操作 for batch_images, batch_labels in train_loader: pass把数据集的文件夹建立好,直接调用 ImageFolder 和 DataLoader 来进行数据的载入分批读取确实很方便,但是如果我们想知道哪些图片分类错误了, train_loader loader 中仅有 image(图片) 和 label 属性,没有 image name(图片名称) 属性,有些力不从心! 因此,我们可以自己写 ImageFolder 来实现读取 image、label、image name 的功能,当然熟悉这个流程后,以后可以进行更个性化的操作! 参考 从零开始深度学习Pytorch笔记(11)—— DataLoader类 补充 collate_fn参数用于是否需要以自定义的方式组织一个batch, 例子中将一个mini-batch的数据组织成numpy.ndarray的类型. 默认情况下collate_fn=None时,数据以元组的方式返回. 比如将多个numpy小数组组合成一个大的numpy数组: def ssd_dataset_collate(batch): # print('ssd_dataset_collate函数被执行...') images = [] bboxes = [] for img, box in batch: images.append(img) bboxes.append(box) images = np.array(images) bboxes = np.array(bboxes) return images, bboxes gen = DataLoader(train_dataset, \ batch_size=Batch_size, \ num_workers=8, \ pin_memory=True,\ drop_last=True, \ collate_fn=ssd_dataset_collate) # collate_fn=None)2 OwnFolder and DataLoader自己写数据读取和预处理,来替代 torchvision.datasets.ImageFolder 的功能,具体实现如下 class Own_Dataset 所示 class Own_Dataset(Dataset): def __init__(self, image_label_list, transform=None): super().__init__() self.samples_list = image_label_list # xxx.jpg class1 self.transform = transform # pre-processing of data def __getitem__(self, index): img_name = self.samples_list[index][0] # absolute path of image name with open(img_name,"rb") as f: img = Image.open(f).convert("RGB") # load image label = self.samples_list[index][1] # image label if img is None: print(img_name) if self.transform is not None: img = self.transform(img) return img, label, img_name def __len__(self): return len(self.samples_list)其中 image_label_list 为列表,存放着图片的绝对路径以及标签信息,格式如下 [(/train_data_pth/calss1/1.jpg,class1), (/train_data_pth/calss1/2.jpg,class1), ..., (/train_data_pth/calssn/m.jpg,classn)]想实现更多功能,在 def __getitem__(self, index): 中定义即可, __getitem__:实例[idx] 时触发配合 DataLoader 使用 train_loader = torch.utils.data.DataLoader( Own_Dataset(image_label_list=val_list, transform=transforms.Compose([ transforms.Resize(input_size,interpolation=2), # resize transforms.ToTensor(), normalize,])), batch_size=test_batch_size, shuffle=False, num_workers=n_worker, pin_memory=True)训练测试时,就可以访问图片,类别以及图片名信息了,如下所示 for batch_images, batch_labels,batch_names in train_loader: pass3 transforms下面介绍部分 torchvision.transforms 方法 更多的 torchvision.transforms 方法可以参考官网介绍 https://pytorch.org/docs/stable/torchvision/transforms.html train_dataset = datasets.ImageFolder( train_data_pth, transforms.Compose([ transforms.Resize(scale_size,interpolation=2), transforms.RandomRotation(5), transforms.ColorJitter(brightness=0.1,contrast=0.1, saturation=0.1,hue=0.1), transforms.FiveCrop(input_size), transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])), transforms.Lambda(lambda crops: torch.stack([transforms.Normalize( mean = [0.5,0.5,0.5], std = [0.5,0.5,0.5])(crop) for crop in crops])) ]))input_size 和 scale_size 写成元组的形式,eg,(224,224) 和 (256,256) Normalize 时注意 mean 和 std 一定要除以 255,值介于 0~1 之间 FiveCrop 或者 TenCrop 时,测试代码也需要进行相应的调整,如下 原来 out = net(batch_images)现在 bs, ncrops, c, h, w, = batch_images.size() result = net(batch_images.view(-1,c,h,w)) out = result.view(bs,ncrops,-1).mean(1)
|
今日新闻 |
点击排行 |
|
推荐新闻 |
图片新闻 |
|
专题文章 |
CopyRight 2018-2019 实验室设备网 版权所有 win10的实时保护怎么永久关闭 |