使用PyTorch的DataLoader的简单示例

首先上代码:

import torch
from torch.utils.data import Dataset, DataLoader

# 自定义数据集类
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __getitem__(self, index):
        # 返回数据和对应的标签
        return self.data[index], index
    
    def __len__(self):
        # 返回数据集的大小
        return len(self.data)

# 创建数据集
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)

# 创建数据加载器
batch_size = 2
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 遍历数据加载器
for batch_data, batch_labels in dataloader:
    print("Batch Data:", batch_data)
    print("Batch Labels:", batch_labels)
    print("---")

在上述示例中,我们首先定义了一个自定义的数据集类MyDataset,它继承自PyTorch的Dataset类。在MyDataset类中,我们实现了__getitem__方法来获取指定索引的数据和标签,以及__len__方法来获取数据集的大小。

然后,我们创建了一个包含数据[1, 2, 3, 4, 5]的数据集实例dataset

接下来,我们使用DataLoader来创建数据加载器dataloader。我们指定了batch_size参数,表示每个批次中的样本数量,并设置shuffle=True来打乱数据顺序。

最后,我们通过迭代数据加载器来遍历数据。每次迭代时,dataloader会返回一个批次的数据和对应的标签。在这个示例中,我们简单地打印了批次数据和标签。

你可以根据自己的需求自定义数据集类,并使用DataLoader来加载数据,以便在训练深度学习模型时进行批量处理。

发表评论

匿名网友