使用PyTorch的DataLoader的简单示例

首先上代码:

  1. import torch
  2. from torch.utils.data import Dataset, DataLoader
  3.  
  4. # 自定义数据集类
  5. class MyDataset(Dataset):
  6. def __init__(self, data):
  7. self.data = data
  8. def __getitem__(self, index):
  9. # 返回数据和对应的标签
  10. return self.data[index], index
  11. def __len__(self):
  12. # 返回数据集的大小
  13. return len(self.data)
  14.  
  15. # 创建数据集
  16. data = [1, 2, 3, 4, 5]
  17. dataset = MyDataset(data)
  18.  
  19. # 创建数据加载器
  20. batch_size = 2
  21. dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
  22.  
  23. # 遍历数据加载器
  24. for batch_data, batch_labels in dataloader:
  25. print("Batch Data:", batch_data)
  26. print("Batch Labels:", batch_labels)
  27. print("---")

在上述示例中,我们首先定义了一个自定义的数据集类MyDataset,它继承自PyTorch的Dataset类。在MyDataset类中,我们实现了__getitem__方法来获取指定索引的数据和标签,以及__len__方法来获取数据集的大小。

然后,我们创建了一个包含数据[1, 2, 3, 4, 5]的数据集实例dataset

接下来,我们使用DataLoader来创建数据加载器dataloader。我们指定了batch_size参数,表示每个批次中的样本数量,并设置shuffle=True来打乱数据顺序。

最后,我们通过迭代数据加载器来遍历数据。每次迭代时,dataloader会返回一个批次的数据和对应的标签。在这个示例中,我们简单地打印了批次数据和标签。

你可以根据自己的需求自定义数据集类,并使用DataLoader来加载数据,以便在训练深度学习模型时进行批量处理。

发表评论

匿名网友

拖动滑块以完成验证
加载失败