Dataloader中的collate_fn参数
政在学习 人气:0以MNIST为例
from torchvision import datasets mnist = datasets.MNIST(root='./data/', train=True, download=True) print(mnist[0])
结果
(<PIL.Image.Image image mode=L size=28x28 at 0x196E3F1D898>, 5)
MINIST数据集的dataset是由一张图片和一个label组成的元组
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True,collate_fn=lambda x:x) for each in dataloader: print(each) break
结果
[(<PIL.Image.Image image mode=L size=28x28 at 0x2CB3B105630>, 0), (<PIL.Image.Image image mode=L size=28x28 at 0x2CB3B105668>, 2)]
collate_fn为lamda x:x时表示对传入进来的数据不做处理
下面自定义collate_fn看看什么效果
def collate(data): img = [] label = [] for each in data: img.append(each[0]) label.append(each[1]) return img,label dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True,collate_fn=lambda x:collate(x)) for each in dataloader: print(each) break
结果
([<PIL.Image.Image image mode=L size=28x28 at 0x241433A36D8>, <PIL.Image.Image image mode=L size=28x28 at 0x241433A3710>], [9, 3])
说明:若不设置collate_fn参数则会使用默认处理函数
但必须保证传进来的数据都是tensor格式否则会报错
附:DataLoader完整的参数表如下:
class torch.utils.data.DataLoader( dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
DataLoader在数据集上提供单进程或多进程的迭代器
几个关键的参数意思:
- shuffle:设置为True的时候,每个世代都会打乱数据集
- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能
- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留
总结
加载全部内容