当前位置: 首页 > news >正文

PyTorch入门学习(八):神经网络-卷积层

目录

一、数据准备

二、创建卷积神经网络模型

三、可视化卷积前后的图像


一、数据准备

首先,需要准备一个数据集来演示卷积层的应用。在这个示例中,使用了CIFAR-10数据集,该数据集包含了10个不同类别的图像数据,用于分类任务。使用PyTorch的torchvision库来加载CIFAR-10数据集,并进行必要的数据转换。

import torch
import torchvision
from torch.utils.data import DataLoader# 数据集准备
dataset = torchvision.datasets.CIFAR10(root="D:\\Python_Project\\pytorch\\dataset2", train=False, transform=torchvision.transforms.ToTensor(), download=True)# 使用DataLoader加载数据集,每批次包含64张图像
dataLoader = DataLoader(dataset, batch_size=64)

二、创建卷积神经网络模型

接下来,创建一个简单的卷积神经网络模型,以演示卷积层的使用。这个模型包含一个卷积层,其中设置了输入通道数为3(因为CIFAR-10中的图像是彩色的,有3个通道),卷积核大小为3x3,输出通道数为6,步长为1,填充为0。

import torch.nn as nn
from torch.nn import Conv2dclass Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()# 卷积层self.conv1 = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)def forward(self, x):x = self.conv1(x)return xtudui = Tudui()
print(tudui)

上述代码定义了一个Tudui类,该类继承了nn.Module,并在初始化方法中创建了一个卷积层。forward方法定义了数据在模型中的前向传播过程。

三、可视化卷积前后的图像

卷积层通常会改变图像的维度和特征。使用TensorBoard来可视化卷积前后的图像以更好地理解卷积操作。首先,导入SummaryWriter类,并创建一个SummaryWriter对象用于记录日志。

from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter("logs")

然后,使用DataLoader遍历数据集,对每个批次的图像应用卷积操作,并将卷积前后的图像以及输入的图像写入TensorBoard。

step = 0
for data in dataLoader:imgs, targets = data# 卷积操作output = tudui(imgs)# 将输入图像写入TensorBoardwriter.add_images("input", imgs, step)# 由于TensorBoard不能直接显示具有多个通道的图像,我们需要重定义输出图像的大小output = torch.reshape(output, (-1, 3, 30, 30))# 将卷积后的图像写入TensorBoardwriter.add_images("output", output, step)step += 1writer.close()

在上述代码中,使用writer.add_images将输入和输出的图像写入TensorBoard,并使用torch.reshape来重定义输出图像的大小,以满足TensorBoard的显示要求。

运行上述代码后,将在TensorBoard中看到卷积前后的图像,有助于理解卷积操作对图像的影响。

完整代码如下:

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
#数据集准备
dataset = torchvision.datasets.CIFAR10("D:\\Python_Project\\pytorch\\dataset2",train=False,transform=torchvision.transforms.ToTensor(),download=True)
#使用dataloader加载数据集,批次数为64
dataLoader = DataLoader(dataset,batch_size=64)class Tudui(nn.Module):def __init__(self):super(Tudui,self).__init__()# 该神经网络调用conv2d进行一层卷积,输入通道为3层(彩色图像为3通道),卷积核大小为3*3,输出通道为6,设置步长为1,padding为0,不进行填充。self.conv1 = Conv2d(in_channels=3,out_channels=6,kernel_size=3,stride=1,padding=0)def forward(self,x):x = self.conv1(x)return xtudui = Tudui()
print(tudui)# 生成日志
writer = SummaryWriter("logs")step = 0
# 输出卷积前的图片大小和卷积后的图片大小
for data in dataLoader:imgs,targets = data# 卷积操作output = tudui(imgs)print(imgs.shape)print(output.shape)writer.add_images("input",imgs,step)"""注意:使用tensorboard输出时需要重新定义图片大小对于输入的图片集imgs来说,tensor.size([64,3,32,32]),即一批次为64张,一张图片为三个通道,大小为32*32对于经过卷积后输出的图片集output来说,tensor.size([64,6,30,30]),通道数变成了6,tensorboard不知道怎么显示通道数为6的图片,所以如果直接输出会报错解决方案:使用reshape方法对outputs进行重定义,把通道数改成3,如果不知道批次数大小,可以使用-1代替,程序会自动匹配批次大小。"""#重定义输出图片的大小output = torch.reshape(output,(-1,3,30,30))# 显示输出的图片writer.add_images("output",output,step)step = step + 1
writer.close()

参考资料:

视频教程:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

http://www.lryc.cn/news/209992.html

相关文章:

  • 【EI会议征稿】 2024年遥感、测绘与图像处理国际学术会议(RSMIP2024)
  • MySQL 8 - 处理 NULL 值 - is null、=null、is not null、<> null 、!= null
  • 高教社杯数模竞赛特辑论文篇-2018年C题:大型百货商场会员画像描述(附获奖论文及MATLAB代码实现)
  • #力扣:2315. 统计星号@FDDLC
  • 设计模式——单例模式详解
  • 一、W5100S/W5500+RP2040树莓派Pico<静态配置网络信息>
  • 【C++的OpenCV】第十四课-OpenCV基础强化(二):访问单通道Mat中的值
  • elementUI el-collapse 自定义折叠面板icon 和 样式 或文字展开收起
  • 如何用个人数据Milvus Cloud知识库构建 RAG 聊天机器人?(上)
  • 2023年江西省“振兴杯”工业互联网安全技术技能大赛暨全国大赛江西选拔赛 Write UP
  • PostMan 之 Mock 接口测试
  • LuatOS-SOC接口文档(air780E)--libgnss - NMEA数据处理
  • 基于华为云 IoT 物联网平台实现家居环境实时监控
  • 【开源框架】Glide的图片加载流程
  • win10下Mariadb绿色版安装步骤
  • wiresharak捕获DNS
  • vue源码分析(一)——源码目录说明
  • 【深度学习】吴恩达课程笔记(二)——浅层神经网络、深层神经网络
  • UI自动化概念 + Web自动化测试框架介绍
  • 在 macOS 上的多个 PHP 版本之间切换
  • 地址解析协议ARP
  • Go学习第十三章——Gin入门与路由
  • [减脂期食谱] 自制千岛酱
  • Android 系统架构
  • 【Docker】Python Flask + Redis 练习
  • shell_52.Linux测试与其他网络主机的连通性脚本
  • OpenCV C++ 图像处理实战 ——《缺陷检测》
  • Python操作MySQL基础使用
  • 【pytorch】pytorch中的高级索引
  • 基于图像识别的自动驾驶汽车障碍物检测与避障算法研究