当前位置: 首页 > news >正文

PyTorch入门学习(九):神经网络-最大池化使用

目录

一、数据准备

二、创建神经网络模型

三、可视化最大池化效果


一、数据准备

首先,需要准备一个数据集来演示最大池化层的应用。在本例中,使用了CIFAR-10数据集,这是一个包含10个不同类别图像的数据集,用于分类任务。我们使用PyTorch的torchvision库来加载CIFAR-10数据集并进行必要的数据转换。

import torch
import torchvision
from torch.utils.data import DataLoader# 数据集准备
dataset = torchvision.datasets.CIFAR10("D:\\Python_Project\\pytorch\\dataset2", train=False, transform=torchvision.transforms.ToTensor(), download=True)# 使用DataLoader加载数据集,每批次包含64张图像
dataLoader = DataLoader(dataset, batch_size=64)

二、创建神经网络模型

接下来,创建一个简单的神经网络模型,其中包含一个卷积层和一个最大池化层。这个模型将帮助演示最大池化层的效果。首先定义一个Tudui类,该类继承了nn.Module,并在初始化方法中创建了一个卷积层和一个最大池化层。

import torch.nn as nn
from torch.nn import Conv2d
from torch.nn.functional import max_pool2dclass Tudui(nn.Module):def __init__(self):super(Tudui, self).__init()# 卷积层self.conv1 = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)# 最大池化层self.pool = nn.MaxPool2d(kernel_size=2, stride=2)def forward(self, x):x = self.conv1(x)x = self.pool(x)return xtudui = Tudui()
print(tudui)

上述代码中,定义了Tudui类,包括了一个卷积层和一个最大池化层。在forward方法中,数据首先经过卷积层,然后通过最大池化层,以减小图像的维度。

三、可视化最大池化效果

最大池化层有助于减小图像的维度,提取图像中的主要特征。接下来将使用TensorBoard来可视化最大池化的效果,以更好地理解它。首先,导入SummaryWriter类并创建一个SummaryWriter对象。

from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter("logs")

然后,遍历数据集,对每个批次的图像应用卷积和最大池化操作,并将卷积前后的图像写入TensorBoard。

step = 0
for data in dataLoader:imgs, targets = data# 卷积和最大池化操作output = tudui(imgs)# 将输入图像写入TensorBoardwriter.add_images("input", imgs, step)# 由于TensorBoard不能直接显示多通道图像,我们需要重定义输出图像的大小output = torch.reshape(output, (-1, 6, 15, 15))# 将卷积和最大池化后的图像写入TensorBoardwriter.add_images("output", output, step)step += 1writer.close()

在上述代码中,使用writer.add_images将输入和输出的图像写入TensorBoard,并使用torch.reshape来重定义输出图像的大小,以适应TensorBoard的显示要求。

运行上述代码后,将在TensorBoard中看到卷积和最大池化的效果。最大池化层有助于提取图像中的关键信息,减小图像维度,并提高模型的计算效率。

完整代码如下:

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
#数据集准备
dataset = torchvision.datasets.CIFAR10("D:\\Python_Project\\pytorch\\dataset2",train=False,transform=torchvision.transforms.ToTensor(),download=True)
#使用dataloader加载数据集,批次数为64
dataLoader = DataLoader(dataset,batch_size=64)class Tudui(nn.Module):def __init__(self):super(Tudui,self).__init__()# 该神经网络调用conv2d进行一层卷积,输入通道为3层(彩色图像为3通道),卷积核大小为3*3,输出通道为6,设置步长为1,padding为0,不进行填充。self.conv1 = Conv2d(in_channels=3,out_channels=6,kernel_size=3,stride=1,padding=0)def forward(self,x):x = self.conv1(x)return xtudui = Tudui()
print(tudui)# 生成日志
writer = SummaryWriter("logs")step = 0
# 输出卷积前的图片大小和卷积后的图片大小
for data in dataLoader:imgs,targets = data# 卷积操作output = tudui(imgs)print(imgs.shape)print(output.shape)writer.add_images("input",imgs,step)"""注意:使用tensorboard输出时需要重新定义图片大小对于输入的图片集imgs来说,tensor.size([64,3,32,32]),即一批次为64张,一张图片为三个通道,大小为32*32对于经过卷积后输出的图片集output来说,tensor.size([64,6,30,30]),通道数变成了6,tensorboard不知道怎么显示通道数为6的图片,所以如果直接输出会报错解决方案:使用reshape方法对outputs进行重定义,把通道数改成3,如果不知道批次数大小,可以使用-1代替,程序会自动匹配批次大小。"""#重定义输出图片的大小output = torch.reshape(output,(-1,3,30,30))# 显示输出的图片writer.add_images("output",output,step)step = step + 1
writer.close()

参考资料:

视频教程:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

http://www.lryc.cn/news/213378.html

相关文章:

  • 0基础学习PyFlink——用户自定义函数之UDF
  • 英语小作文模板(06求助+描述;07描述+建议)
  • 为什么感觉假期有时候比上班还累?
  • 推理还是背诵?通过反事实任务探索语言模型的能力和局限性
  • 《利息理论》指导 TCP 拥塞控制
  • Bsdiff,Bspatch 的差分增量升级(基于Win和Linux)
  • 【3妹教我学历史-秦朝史】2 秦穆公-韩原之战
  • 车载控制器
  • 回归预测 | Matlab实现RIME-CNN-SVM霜冰优化算法优化卷积神经网络-支持向量机的多变量回归预测
  • 使用Jaeger进行分布式跟踪:学习如何在服务网格中使用Jaeger来监控和分析请求的跟踪信息
  • 添加多个单元对象
  • 十八、模型构建器(ModelBuilder)快速提取城市建成区——批量掩膜提取夜光数据、夜光数据转面、面数据融合、要素转Excel(基于参考比较法)
  • HarmonyOS开发:基于http开源一个网络请求库
  • 【杂记】Ubuntu20.04装系统,安装CUDA等
  • 040-第三代软件开发-全新波形抓取算法
  • 分享一个基于asp.net的供销社农产品商品销售系统的设计与实现(源码调试 lw开题报告ppt)
  • Java基于SpringBoot的线上考试系统
  • flask socketio 实时传值至html上【需补充实例】
  • C# Onnx P2PNet 人群检测和计数
  • idea提交代码一直提示 log into gitee
  • ATECLOUD如何进行电源模块各项性能指标的测试?
  • Mysql查询训练——50道题
  • 学习笔记|正态分布|图形法|偏度和峰度|非参数检验法|《小白爱上SPSS》课程:SPSS第三讲 | 正态分布怎么检验?看这篇文章就够了
  • Android NDK开发详解之ndk-build 脚本
  • 应用于智慧矿山的皮带跑偏视频分析AI算法
  • vue3 UI组件优化之element-plus按需导入
  • 如何创建 Spring Boot 项目
  • 【经验分享】openGauss容灾集群搭建
  • 互联网应用架构的演进(八大架构的演进过程)
  • ROS自学笔记二十六:导航中激光雷达消息