当前位置: 首页 > news >正文

python/pytorch读取数据集

MNIST数据集

MNIST数据集包含了6万张手写数字([1,28,28]尺寸),以特殊格式存储。本文首先将MNIST数据集另存为png格式,然后再读取png格式图片,开展后续训练

另存为png格式

import torch
from torch.utils.data import Dataset
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision import models, transforms
from torchvision.utils import save_image
from PIL import Image#将MNIST数据集转换为图片
tf = transforms.Compose([transforms.ToTensor()]) # mnist is already normalised 0 to 1
datasetMNIST = MNIST("./data", train=True, download=True, transform=tf)
pbar = tqdm(datasetMNIST)
for index, (img,cl) in enumerate(pbar):save_image(img, f"./data/MNIST_PNG/x/{index}.png")# 以写入模式打开文件with open(f"./data/MNIST_PNG/c/{index}.txt", "w", encoding="utf-8") as file:# 将字符串写入文件file.write(f"{cl}")

注意:MNIST源数据存放在./data文件下,如果没有数据也没关系,代码会自动从网上下载。另存为png的数据放在了./data/MNIST_PNG/文件下。子文件夹x存放6万张图片,子文件夹c存放6万个文本文件,每个文本文件内有一行字符串,说明该对应的手写数字是几(标签)。

读取png格式数据集

class MyMNISTDataset(Dataset):def __init__(self, data):self.data = datadef __len__(self):return len(self.data)def __getitem__(self, idx):x = self.data[idx][0] #图像y = self.data[idx][1] #标签return x, ydef load_data(dataNum=60000):data = []pbar = tqdm(range(dataNum))for i in pbar:# 指定图片路径image_path = f'./data/MNIST_PNG/x/{i}.png'cond_path=f'./data/MNIST_PNG/c/{i}.txt'# 定义图像预处理preprocess = transforms.Compose([transforms.Grayscale(num_output_channels=1),  # 将图像转换为灰度图像(单通道)transforms.ToTensor()])# 使用预处理加载图像image_tensor = preprocess(Image.open(image_path))# 加载条件文档(tag)with open(cond_path, 'r') as file:line = file.readline()number = int(line)  # 将字符串转换为整数,图像的类别data.append((image_tensor, number))return datadata=load_data(60000)
# 创建数据集实例
dataset = MyMNISTDataset(data)# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
pbar = tqdm(dataloader)for index, (img,cond) in enumerate(pbar):#这里对每一批进行训练...print(f"Batch {index}: img = {img.shape}, cond = {cond}")

load_data函数用于读取数据文件,返回一个data张量。data张量又被用于构造MyMNISTDataset类的对象datasetdataset对象又被DataLoader函数转换为dataloader

dataloader事实上按照batch将数据集进行了分割,4张图片一组进行训练。上述代码的输出如下:

......
Batch 7847: img = torch.Size([4, 1, 28, 28]), cond = tensor([0, 1, 5, 2])
Batch 7848: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 2, 6, 0])
Batch 7849: img = torch.Size([4, 1, 28, 28]), cond = tensor([4, 3, 0, 9])
Batch 7850: img = torch.Size([4, 1, 28, 28]), cond = tensor([6, 2, 9, 5])
Batch 7851: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 2, 4, 4])
Batch 7852: img = torch.Size([4, 1, 28, 28]), cond = tensor([1, 4, 2, 6])
Batch 7853: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 5, 3, 5])
Batch 7854: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 1, 0, 1])
Batch 7855: img = torch.Size([4, 1, 28, 28]), cond = tensor([9, 8, 9, 7])
Batch 7856: img = torch.Size([4, 1, 28, 28]), cond = tensor([4, 6, 6, 7])
Batch 7857: img = torch.Size([4, 1, 28, 28]), cond = tensor([7, 4, 1, 6])
Batch 7858: img = torch.Size([4, 1, 28, 28]), cond = tensor([5, 4, 6, 5])
Batch 7859: img = torch.Size([4, 1, 28, 28]), cond = tensor([6, 3, 1, 9])
Batch 7860: img = torch.Size([4, 1, 28, 28]), cond = tensor([5, 5, 8, 6])
Batch 7861: img = torch.Size([4, 1, 28, 28]), cond = tensor([0, 4, 8, 9])
Batch 7862: img = torch.Size([4, 1, 28, 28]), cond = tensor([2, 3, 5, 8])
Batch 7863: img = torch.Size([4, 1, 28, 28]), cond = tensor([8, 0, 0, 6])
......
http://www.lryc.cn/news/268535.html

相关文章:

  • IT安全:实时网络安全监控
  • SQL server使用profiler工具跟踪语句
  • python实现一维傅里叶变换——冈萨雷斯数字图像处理
  • 表单(HTML)
  • spripng 三级缓存,三级缓存的作用是什么? Spring 中哪些情况下,不能解决循环依赖问题有哪些
  • elasticsearch系列六:索引重建
  • GitOps实践指南:GitOps能为我们带来什么?
  • D3485国产芯片+5V工作电压, 内置失效保护电路采用SOP8封装
  • devops使用
  • AI训练师常用的ChatGPT通用提示词模板
  • Java加密算法工具类(AES、DES、MD5、RSA)
  • 探索Go语言的魅力:一门简洁高效的编程语言
  • 【用unity实现100个游戏之19】制作一个3D传送门游戏,实现类似鬼打墙,迷宫,镜子,任意门效果
  • DRF(Django Rest Framework)框架基于restAPI协议规范的知识点总结
  • Linux磁盘与文件系统管理
  • 数字魔法AI绘画的艺术奇迹-用Stable Diffusion挑战无限可能【文末送书-12】
  • 【docker实战】02 用docker安装mysql
  • 循环渲染ForEach
  • 纷享销客华为云:如何让企业多一个选择?
  • 前端实现断点续传文件
  • 复试 || 就业day01(2023.12.27)算法篇
  • JavaWeb——JQuery
  • Python教程:查询Py模块的版本号,有哪些方法?
  • 第一节 初始化项目
  • idea提示unable to import maven project
  • 【Spring】SpringBoot日志
  • HTML+CSS制作动漫绿巨人
  • AGV智能搬运机器人-替代人工工位让物流行业降本增效
  • 【办公技巧】怎么批量提取文件名到excel
  • uniapp实现前端银行卡隐藏中间的数字,及隐藏姓名后两位