提速 PyTorch 中的 DataLoader

PyTorch 中原版的 DataLoader,如果使用多进程加速数据读取,那么它在每个 Epoch 结束后都会销毁、重建所有 Workers。当一个 Epoch 的迭代时间较短、但迭代次数较多时,训练会比较费时。本文提出了一个解决方案,在原版 DataLoader 的基础上封装了 FastDataLoader,解决上述问题。

以原版 DataLoader 为基类,定义 FastDataLoader

import torch
import torch.utils.data


class _RepeatSampler(object):
def __init__(self, sampler):
self.sampler = sampler

def __iter__(self):
while True:
yield from iter(self.sampler)


class FastDataLoader(torch.utils.data.dataloader.DataLoader):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
self.iterator = super().__iter__()

def __len__(self):
return len(self.batch_sampler.sampler)

def __iter__(self):
for i in range(len(self)):
yield next(self.iterator)

接着,使用 FastDataLoader 替换原来的 DataLoader,即可解决问题。