Python day38
@浙大疏锦行 python day38.
内容:
Dataset 和 Dataloader类
- Dataset相当于针对单个数据进行处理,它指定了数据的目录、针对每个数据进行的处理(样本级预处理)等,Dataset必须实现getitem()方法和len()方法,一个用于获取单个样本,一个用于返回样本总数,而Dataloader就可以通过提供的方法来读取数据
- getitem和len方法都是特殊方法(__fun__),可以支持如下访问方式:class_name[idx],len(class_name)
- Dataloader则是直接针对Dataset的数据进行使用,包括读取batch以及打乱(shuffle)等,不需要自定义方法,只需要指定batch、shuffle等参数即可
# 1. 数据预处理,该写法非常类似于管道pipeline
# transforms 模块提供了一系列常用的图像预处理操作# 先归一化,再标准化
transform = transforms.Compose([transforms.ToTensor(), # 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,), (0.3081,)) # MNIST数据集的均值和标准差,这个值很出名,所以直接使用
])# 2. 加载MNIST数据集,如果没有会自动下载
train_dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,transform=transform
)# 创建数据加载器
train_loader = DataLoader(train_dataset,batch_size=64, # 每个批次64张图片,一般是2的幂次方,这与GPU的计算效率有关shuffle=True # 随机打乱数据
)test_loader = DataLoader(test_dataset,batch_size=1000 # 每个批次1000张图片# shuffle=False # 测试时不需要打乱数据
)