pytorch定义datase多次重复采样
有的时候训练需要对样本重复抽样为一个batch,可以按如下格式定义:
class TrainLoader(Dataset):def __init__(self, fns, repeat=1):super(TrainLoader, self).__init__()self.length = len(fns) # 数据数量self.repeat = repeat # 数据重复次数def __getitem__(self, idx):idx = idx % self.length def __len__(self):return self.length * self.repeat