深度学习pytorch分割数据集的方法(将大数据集改小更加易于训练) 您所在的位置:网站首页 如何把大药丸分成小药丸 深度学习pytorch分割数据集的方法(将大数据集改小更加易于训练)

深度学习pytorch分割数据集的方法(将大数据集改小更加易于训练)

2024-07-16 01:00| 来源: 网络整理| 查看: 265

问题一:划分训练集和验证集: import os from shutil import copy, rmtree import random def mk_file(file_path: str): if os.path.exists(file_path): # 如果文件夹存在,则先删除原文件夹在重新创建 rmtree(file_path) os.makedirs(file_path) def main(): # 保证随机可复现 random.seed(0) # 将数据集中10%的数据划分到验证集中 split_rate = 0.1 # 指向你解压后的flower_photos文件夹 cwd = os.getcwd() data_root = os.path.join(cwd, "my_data3") origin_flower_path = os.path.join(data_root, "images") assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path) flower_class = [cla for cla in os.listdir(origin_flower_path) if os.path.isdir(os.path.join(origin_flower_path, cla))] # 建立保存训练集的文件夹 train_root = os.path.join(data_root, "train") mk_file(train_root) for cla in flower_class: # 建立每个类别对应的文件夹 mk_file(os.path.join(train_root, cla)) # 建立保存验证集的文件夹 val_root = os.path.join(data_root, "val") mk_file(val_root) for cla in flower_class: # 建立每个类别对应的文件夹 mk_file(os.path.join(val_root, cla)) for cla in flower_class: cla_path = os.path.join(origin_flower_path, cla) images = os.listdir(cla_path) num = len(images) # 随机采样验证集的索引 eval_index = random.sample(images, k=int(num*split_rate)) for index, image in enumerate(images): if image in eval_index: # 将分配至验证集中的文件复制到相应目录 image_path = os.path.join(cla_path, image) new_path = os.path.join(val_root, cla) copy(image_path, new_path) else: # 将分配至训练集中的文件复制到相应目录 image_path = os.path.join(cla_path, image) new_path = os.path.join(train_root, cla) copy(image_path, new_path) print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="") # processing bar print() print("processing done!") if __name__ == '__main__': main()

当由于硬件设备的影响或者计算量参数的庞大可以使用以下方法,适当的减少数据量以方便进行训练: 方法1.如果是标准的数据集:(文件每组分类都有规定的图片)

可以随机删除一些数据

import os import random def delete_images(dir_path, num_to_delete): # 获取目录中的所有文件 files = os.listdir(dir_path) # 过滤出图像文件 image_files = [f for f in files if f.endswith('.jpg') or f.endswith('.png')] # 如果要删除的图片数量大于图片总数,则只删除所有图片 num_to_delete = min(num_to_delete, len(image_files)) # 随机选择要删除的图片 images_to_delete = random.sample(image_files, num_to_delete) # 删除图片 for image in images_to_delete: os.remove(os.path.join(dir_path, image)) # 使用示例:从当前目录中删除 5 张图片 delete_images(r'F:\Python\BCNN\鸟类细粒度分类\my_data\images\010.Red_winged_Blackbird', 20)

方法二:如果是torchvision里面的数据集,比如

 trainset = datasets.CIFAR10(root=image_path,train=True,download=False,                                 transform=data_transform['train'])

 分割数据集可以采用:

class torch.utils.data.Subset(dataset, indices): 获取指定一个索引序列对应的子数据集。

 代码实操:

trainset1 = datasets.CIFAR10(root=image_path,train=True,download=False, transform=data_transform['train']) # 定义索引列表,选择前5000张图像进行训练 subset_indices1 = list(range(10000)) trainset = Subset(trainset1, subset_indices1) trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0) # 将class_to_idx属性添加到Subset对象中 trainset.class_to_idx = trainset1.class_to_idx train_steps = len(trainloader) print(train_steps) testset1 = datasets.CIFAR10(root=image_path,train=False,download=False, transform=data_transform['val']) subset_indices2 = list(range(2000)) testset = Subset(testset1, subset_indices2) testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=0) # val_steps = len(testloader) bird_list = trainset.class_to_idx

这里要注意的是,选择要划分的数据集后,要使用trainset.class_to_idx = trainset1.class_to_idx将class_to_idx添加到Subset对象中。

这样就可以实现指定数据集数量的选取和训练啦

参考文章:

【1】Pytorch划分数据集的方法:torch.utils.data.Subset - 爱学英语的程序媛 - 博客园 (cnblogs.com)

【2】 chatGPT 



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有