Pytorch Dataset和DataLoader
先聊一聊:Dataset和DataLoader是Pytorch提供的两个用于读取数据的类。我们要新建一个Dataset类继承Dataset,重写__init__、__getitem__和__len__三个方法,分别用于构造对象、获取每个数据和获取数据总数,本质就是将数据读取到Dataset中,通过Datadet[0](等价于Dataset.__getitem__(0))可以直接访问数据元素。上述这种访问方式属于列表,这意味着需要在__getitem__中将数据送入列表中,并处理数据,让其变成tensor形式,这样当我们直接用Dataset[idx]时,就会返回一个tensor类型的数据。DataLoader的使用比较简单,我们将新建的Dataset对象作为参数送入,并给定batch等其他参数,就会返回一个DataLoader对象,但与Dataset不同的是,DataLoader无法通过索引直接访问,因为它是Iterable式数据集,只能通过for data in DataLoader的形式访问。
一、Dataset
torch.utils.data.Dataset 是一个表示数据集的抽象类。任何自定义的数据集都需要继承这个类并覆写相关方法。
所谓数据集,其实就是一个负责处理索引(index)到样本(sample)映射的一个类(class)。
Pytorch提供两种数据集: Map式数据集 Iterable式数据集
1.Map式数据集
一个Map式的数据集必须要重写getitem(self, index),len(self) 两个内建方法,用来表示从索引到样本的映射(Map)。
这样一个数据集dataset,举个例子,当使用dataset[idx]命令时,可以在你的硬盘中读取你的数据集中第idx张图片以及其标签(如果有的话);len(dataset)则会返回这个数据集的容量。
自定义类大致是这样的:
class CustomDataset(data.Dataset):#需要继承data.Dataset |
载入图像的例子:
class CIFAR10(Dataset): |
2.Iterable式数据集
一个Iterable(迭代)式数据集是抽象类data.IterableDataset的子类,并且覆写了iter方法成为一个迭代器。这种数据集主要用于数据大小未知,或者以流的形式的输入,本地文件不固定的情况,需要以迭代的方式来获取样本索引。
这一块先mark着,因为还没有使用过。
二、DataLoader
一般来说PyTorch中深度学习训练的流程是这样的: 1. 创建Dateset 2. Dataset传递给DataLoader 3. DataLoader迭代产生训练数据提供给模型
对应的一般都会有这三部分代码
# 创建Dateset(可以自定义) |
到这里应该就PyTorch的数据集和数据传递机制应该就比较清晰明了了。Dataset负责建立索引到样本的映射,DataLoader负责以特定的方式从数据集中迭代的产生 一个个batch的样本集合。在enumerate过程中实际上是dataloader按照其参数sampler规定的策略调用了其dataset的getitem方法。
1.参数介绍
先看一下实例化一个DataLoader所需的参数。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, |
dataset
(Dataset) – 定义好的Map式或者Iterable式数据集。batch_size
(python:int, optional) – 一个batch含有多少样本 (default: 1)。shuffle
(bool, optional) – 每一个epoch的batch样本是相同还是随机 (default: False)。sampler
(Sampler, optional) – 决定数据集中采样的方法. 如果有,则shuffle参数必须为False。batch_sampler
(Sampler, optional) – 和 sampler 类似,但是一次返回的是一个batch内所有样本的index。和 batch_size, shuffle, sampler, and drop_last 三个参数互斥。num_workers
(python:int, optional) – 多少个子程序同时工作来获取数据,多线程。 (default: 0)collate_fn
(callable, optional) – 合并样本列表以形成小批量。pin_memory
(bool, optional) – 如果为True,数据加载器在返回前将张量复制到CUDA固定内存中。drop_last
(bool, optional) – 如果数据集大小不能被batch_size整除,设置为True可删除最后一个不完整的批处理。如果设为False并且数据集的大小不能被batch_size整除,则最后一个batch将更小。(default: False)timeout
(numeric, optional) – 如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。 (default: 0)- *
worker_init_fn
(callable, optional) – 每个worker初始化函数 (default: None)
train_data_loader = DataLoader(train_data, batch_size=64, shuffle=True, drop_last=True) |
三、总结
也就是说Dataset需要和DataLoader一起使用,从Dataset中获取的数据必须是已经处理过的tensor类型数据,Dataset使用Map式可以使用索引直接访问。DataLoader需要以Dataset和一些参数作为形参得到对象,然后使用迭代时for循环访问每组数据,如果是图像一般是(batch, 3, 32, 32)的数据形式。