Pytorch在dataloader类中设置shuffle的随机数种子方式 您所在的位置:网站首页 Python如何设置随机数种子 Pytorch在dataloader类中设置shuffle的随机数种子方式

Pytorch在dataloader类中设置shuffle的随机数种子方式

2024-07-10 03:45| 来源: 网络整理| 查看: 265

PyTorch的数据集DataLoader是十分常用的数据加载和预处理工具,通过将数据传输到GPU并在深度学习过程中进行抽样,而它的shuffle参数可以打乱数据集的顺序,使损失函数更加随机。但同时,我们也可能需要控制随机的行为,以获得可再现的实验结果。下面是两种设置shuffle随机数种子的方法:

方法一:使用torch.utils.data.DataLoader类的WorkerInitFn参数

我们可以使用WorkerInitFn来传递一个函数,来控制数据集加载器的每个工作进程的初始化过程。以下是一个示例的代码段:

import random import torch from torch.utils.data import DataLoader class MyDataset(Dataset): def __init__(self): super().__init__() self.data = list(range(10)) def __len__(self): return len(self.data) def __getitem__(self, index): return self.data[index] # 设置随机数种子,获得可再现的实验结果 def worker_init_fn(worker_id): random.seed(worker_id) dataset = MyDataset() dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2, worker_init_fn=worker_init_fn) for i, batch in enumerate(dataloader): print(batch)

在这个例子中,我们将worker_init_fn设置为一个函数,该函数会在每个工作进程初始化时调用,并使用其工作进程ID作为随机数种子,以控制每个进程数据加载顺序的随机性。这里,使用random.seed来设置随机种子。

当shuffle参数设置为True时,DataLoader会在每个工作进程中打乱数据,并将其放回主进程。 在每个工作进程初始化时,随机数种子被设置成与工作进程ID有关的值。这样,每个进程在打乱数据时使用不同的随机数种子,以确保打乱后的顺序是独立的,而不是互相关联的。

方法二:使用torch.Generator类

我们也可以使用PyTorch的Random模块来设置DataLoader类中的随机数种子。具体做法是将shuffle设置为True,然后使用PyTorch的工具包生成随机数种子。以下是一个示例的代码段:

import torch import torch.utils.data as data_utils torch.manual_seed(42) # 设置随机数种子 # 创建数据集 data = torch.Tensor([[1, 2], [3, 4], [5, 6], [7, 8]]) target = torch.Tensor([1, 1, 0, 0]) dataset = data_utils.TensorDataset(data, target) # 创建DataLoader类 batch_size = 2 dataloader = data_utils.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, generator=torch.Generator().manual_seed(42)) # 打印出来 for batch_idx, (data, target) in enumerate(dataloader): print("Batch index {}, data shape {}, target shape {}".format(batch_idx, data.shape, target.shape))

此例中,我们将DataLoader类的generator参数设置为为torch.Generator().manual_seed(42),shuffle参数设置为True,并使用torch.manual_seed(42)方法设置随机数种子来控制打乱数据的顺序。在这个例子中,generator是torch.Generator对象,我们设置它的随机数种子为42。这样每一次使用DataLoader类,我们都能得到相同的打乱数据顺序。

这两种设置shuffle随机数种子的方式,在控制随机性方面有其各自的优点和适用场景,读者可以根据情况选择更加适合自身需求的方法。

本站文章如无特殊说明,均为本站原创,如若转载,请注明出处:Pytorch在dataloader类中设置shuffle的随机数种子方式 - Python技术站



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有