当前位置: 首页 > article >正文

PyTorch入门-torchvision

torchvision

torchvision 是 PyTorch 的一个重要扩展库,专门针对计算机视觉任务设计。它提供了丰富的预训练模型、常用数据集、图像变换工具和计算机视觉组件,大大简化了视觉相关深度学习项目的开发流程。

我们可以在Pytorch的官网找到torchvision的文档

在这里插入图片描述

文档中提供了很多数据集

在这里插入图片描述

这里以CIFAR10为例,它是图像分类常用的数据集

CIFAR-10 数据集由 60,000 张 32x32 像素的彩色图像组成,分为 10 个类别,每个类别有 6,000 张图像。其中 50,000 张是训练图像,10,000 张是测试图像。

数据集分为五个训练批次和一个测试批次,每个批次包含 10,000 张图像。测试批次包含每个类别中随机选择的 1,000 张图像。训练批次包含剩余的图像,顺序随机,但某些训练批次可能包含一个类别的更多图像。所有训练批次加起来正好包含每个类别的 5,000 张图像。

在这里插入图片描述
在这里插入图片描述

除了数据集之外,还提供了模型torchvision.models 模块包含了一系列预训练的深度学习模型,广泛应用于图像分类、目标检测、语义分割等任务。

我们可以通过代码下载数据集

import torchvisiontrans_set = torchvision.datasets.CIFAR10(root = "./dataset",train= True,download= True)
test_set = torchvision.datasets.CIFAR10(root = "./dataset",train= False,download= True)

参数列表

  1. root (str):
    • 数据集存储的路径,数据将下载到此目录下。
  2. train (bool, optional):
    • 如果为 True,则加载训练集;如果为 False,则加载测试集。默认值为 True
  3. transform (callable, optional):
    • 一个函数/转换,用于对图像进行预处理,比如数据增强、归一化等。
  4. target_transform (callable, optional):
    • 一个函数/转换,用于对目标(标签)进行处理。
  5. download (bool, optional):
    • 如果为 True,则从网上下载数据集(如果在指定路径中不存在)。默认值为 False

下载完成后可以看到项目目录中的数据集
在这里插入图片描述

我们可以打印一下print("训练集数量:", len(trans_set)) 查看训练集数量

在这里插入图片描述

完整代码如下,可以看到我们的第一个图片是cat

import torchvision# 下载并加载CIFAR10训练数据集
trans_set = torchvision.datasets.CIFAR10(root = "./dataset", train= True, download= True)# 下载并加载CIFAR10测试数据集
test_set = torchvision.datasets.CIFAR10(root = "./dataset", train= False, download= True)# 获取测试集的第一个样本和对应的标签
img, target = test_set[0]
# 显示测试集中的类别标签
print(test_set.classes) # ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# 显示样本的图像数据
print(img) # <PIL.Image.Image image mode=RGB size=32x32 at 0x1BF002C7710>
# 显示样本的标签
print(target) # 3
# 根据标签索引对应的类别名称
print(test_set.classes[target]) # cat
# 显示图像
# 在这里使用PIL库的Image模块的show方法,直接在屏幕上展示图像
img.show()

这个数据集的图片都比较小(32x32 像素),放大以后虽然这个看起来并不像猫,反而像老鼠,但是它就是cat

在这里插入图片描述

上面我们得到的数据类型是PIL,我们需要转为tensor类型,我们只需要新增一个Compose然后修改dataset代码

# 定义数据集转换
dataset_transform = torchvision.transforms.Compose([# 将图像数据转换为 Tensortorchvision.transforms.ToTensor()    # 还可以对 Tensor 进行归一化,参数分别表示均值和标准差#torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 下载并加载CIFAR10训练数据集
# 参数:
#   root: 指定数据集的保存路径
#   train: 指示是训练数据集(True)还是测试数据集(False)
#   transform: 对数据集中的每个图像应用的转换操作
#   download: 如果数据集不存在于指定路径且设置为True,则会自动下载数据集
trans_set = torchvision.datasets.CIFAR10(root = "./dataset", train= True, transform= dataset_transform,download= True)
# 下载并加载CIFAR10测试数据集,参数同上
test_set = torchvision.datasets.CIFAR10(root = "./dataset", train= False,transform= dataset_transform, download= True)

然后我们执行之后,控制台会打印图片,此时是我们想要的tensor数据类型(tensor类型图片不能使用show()

在这里插入图片描述

我们就可以显示在tensorBoard中

writer = SummaryWriter("pics")
# 获取测试集的10个样本和对应的标签
for i in range(10):img, target = test_set[i]writer.add_image("test_set", img, i)writer.close()

仔细看,能够依稀辨认出第十张图片是车
在这里插入图片描述

在这里插入图片描述

http://www.lryc.cn/news/2387243.html

相关文章:

  • LVS负载均衡群集技术深度解析
  • 18、Python字符串全解析:Unicode支持、三种创建方式与长度计算实战
  • 5月27日复盘-Transformer介绍
  • CSV数据处理全指南:从基础到实战
  • MyBatis-Plus一站式增强组件MyBatis-Plus-kit(更新2.0版本):零Controller也能生成API?
  • 实时数仓flick+clickhouse启动命令
  • 【Git】Commit Hash vs Change-Id
  • Netty学习专栏(六):深度解析Netty核心参数——从参数配置到生产级优化
  • 服务器磁盘按阵列划分为哪几类
  • 在WPF中添加动画背景
  • 【KWDB创作者计划】_KWDB分布式多模数据库智能交通应用——高并发时序处理与多模数据融合实践
  • Android 中的 ViewModel详解
  • Java集合框架与三层架构实战指南:从基础到企业级应用
  • 6个月Python学习计划 Day 2 - 条件判断、用户输入、格式化输出
  • 使用docker容器部署Elasticsearch和Kibana
  • 批量处理合并拆分pdf功能 OCR 准确率高 免费开源
  • Unity—lua基础语法
  • 目标检测 TaskAlignedAssigner 原理
  • Qt popup窗口半透明背景
  • 游戏:元梦之星游戏开发代码(谢苏)
  • TCP协议原理与Java编程实战:从连接建立到断开的完整解析
  • Linux的top命令使用
  • Spring Cloud Gateway 限流实践:基于 Redis 令牌桶算法的网关层流量治理
  • 可视化大屏实现全屏或非全屏
  • java8函数式接口(函数式接口的匿名实现类作为某些方法的入参)
  • linux自有服务
  • UniApp网页版集成海康视频播放器
  • Filter和Interceptor详解(一文了解执行阶段及其流程)
  • 鸿蒙仓颉开发语言实战教程:实现商城应用详情页
  • GitAny - 無需登入的 GitHub 最新倉庫檢索工具