首先上代码:
- import torch
- from torch.utils.data import Dataset, DataLoader
- # 自定义数据集类
- class MyDataset(Dataset):
- def __init__(self, data):
- self.data = data
- def __getitem__(self, index):
- # 返回数据和对应的标签
- return self.data[index], index
- def __len__(self):
- # 返回数据集的大小
- return len(self.data)
- # 创建数据集
- data = [1, 2, 3, 4, 5]
- dataset = MyDataset(data)
- # 创建数据加载器
- batch_size = 2
- dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
- # 遍历数据加载器
- for batch_data, batch_labels in dataloader:
- print("Batch Data:", batch_data)
- print("Batch Labels:", batch_labels)
- print("---")
在上述示例中,我们首先定义了一个自定义的数据集类MyDataset
,它继承自PyTorch的Dataset
类。在MyDataset
类中,我们实现了__getitem__
方法来获取指定索引的数据和标签,以及__len__
方法来获取数据集的大小。
然后,我们创建了一个包含数据[1, 2, 3, 4, 5]
的数据集实例dataset
。
接下来,我们使用DataLoader
来创建数据加载器dataloader
。我们指定了batch_size
参数,表示每个批次中的样本数量,并设置shuffle=True
来打乱数据顺序。
最后,我们通过迭代数据加载器来遍历数据。每次迭代时,dataloader
会返回一个批次的数据和对应的标签。在这个示例中,我们简单地打印了批次数据和标签。
你可以根据自己的需求自定义数据集类,并使用DataLoader
来加载数据,以便在训练深度学习模型时进行批量处理。