首先上代码:
import torch from torch.utils.data import Dataset, DataLoader # 自定义数据集类 class MyDataset(Dataset): def __init__(self, data): self.data = data def __getitem__(self, index): # 返回数据和对应的标签 return self.data[index], index def __len__(self): # 返回数据集的大小 return len(self.data) # 创建数据集 data = [1, 2, 3, 4, 5] dataset = MyDataset(data) # 创建数据加载器 batch_size = 2 dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) # 遍历数据加载器 for batch_data, batch_labels in dataloader: print("Batch Data:", batch_data) print("Batch Labels:", batch_labels) print("---")
在上述示例中,我们首先定义了一个自定义的数据集类MyDataset
,它继承自PyTorch的Dataset
类。在MyDataset
类中,我们实现了__getitem__
方法来获取指定索引的数据和标签,以及__len__
方法来获取数据集的大小。
然后,我们创建了一个包含数据[1, 2, 3, 4, 5]
的数据集实例dataset
。
接下来,我们使用DataLoader
来创建数据加载器dataloader
。我们指定了batch_size
参数,表示每个批次中的样本数量,并设置shuffle=True
来打乱数据顺序。
最后,我们通过迭代数据加载器来遍历数据。每次迭代时,dataloader
会返回一个批次的数据和对应的标签。在这个示例中,我们简单地打印了批次数据和标签。
你可以根据自己的需求自定义数据集类,并使用DataLoader
来加载数据,以便在训练深度学习模型时进行批量处理。