当前位置: 首页 > news >正文

Python:torch.nn.Conv1d(), torch.nn.Conv2d()和torch.nn.Conv3d()函数理解

Python:torch.nn.Conv1d(), torch.nn.Conv2d()和torch.nn.Conv3d()函数理解

1. 函数参数

在torch中的卷积操作有三个,torch.nn.Conv1d(),torch.nn.Conv2d()还有torch.nn.Conv3d(),这是搭建网络过程中常用的网络层,为了用好卷积层,需要知道这些参数代表的含义。

这三种不同的卷积的输入参数是相同的,所以只看一个就可以。

def __init__(self,in_channels: int,out_channels: int,kernel_size: _size_2_t,stride: _size_2_t = 1,padding: Union[str, _size_2_t] = 0,dilation: _size_2_t = 1,groups: int = 1,bias: bool = True,padding_mode: str = 'zeros',  # TODO: refine this typedevice=None,dtype=None

这里面的参数网上有很多说明,重点是怎么理解和使用。

2. 参数理解

这里面重点是in_channels参数,这个是代表数据输入的通道,很多说明这个通道是利用torch.nn.Conv2d处理图片数据来进行说明的,代表的是图片的通道数,然后面的两个参数对应着图片的长度和宽度。

下面是本人对这参数的理解过程:

  • 首先对于torch.nn.Conv函数,所接受的数据是可以带有batch维度的,也可以不带有batch维度,这就表示对于torch.nn.Conv2d可以接受的数据包括3维数据或者4维数据,

如:

conv2 = torch.nn.Conv2d(16, 120, 3, stride=2)
input2_3 = torch.randn(16, 5, 5)
output2_3 = conv2(input2_3)
print(output2_3.shape)input2_4 = torch.randn(20, 16, 5, 5)
output2_4 = conv2(input2_4)
print(output2_4.shape)

该段得到的输出为:

torch.Size([120, 2, 2])
torch.Size([20, 120, 2, 2])

这是因为input2_4只是多了一个维度batch在第一个维度上,如果输入的数据是2维的或者5维的,就会提示如下的错误:指明只能接受3维的数据或者4维的数据.

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [20, 20, 16, 5, 5]

这其实就说明了根据自己数据维度选择合适的torch.nn.Conv, 例如,如果数据是2维的,那么就选择torch.nn.Conv1d,这个可以接收传入的数据维度可以是2维,或者是带有batch维度的3维数据。

之后需要注意的是in_channels参数其实对应的就是传入数据的第一个维度(不带有batch)或者带有batch的第二个维度,这个要和in_channels参数相同。

可以理解成这个in_channels就是表示了有多个卷积核在参与计算,那么剩下的维度正好就是卷积核的维度,

如对于torch.nn.Conv3d,传入的数据最少是4维数据,(不带有batch),那么第一维的数据应该等于in_channels,然后剩下三维正好的是卷积核的维度。
如:

conv3 = torch.nn.Conv3d(16, 120, 3, stride=2)
input3 = torch.randn(16, 5, 5, 5)
output3 = conv3(input3)
print(output3.shape)

会得到

torch.Size([120, 2, 2, 2])

这个卷积核是333,相当于有16个卷积核,每个卷积核在16维的数据上依次计算。

其他的作为输出影响的是数据的维度大小,但是out_channels又决定了输出数据的第一个维度,(不带有batch),就可以依然用这个方式思考。

针对后面几维数据的大小,由其他的参数决定,这个有公式可以计算,懒得算也可以直接打印输出看一下维度。

http://www.lryc.cn/news/183231.html

相关文章:

  • scala 连接 MySQL 数据库案例
  • guava工具类常用方法
  • CSShas伪类选择器案例附注释
  • nodejs+vue中医体质的社区居民健康管理系统elementui
  • Kotlin中reified 关键字
  • Linux命令(95)之alias
  • DHCPsnooping 配置实验(2)
  • Qt 综合练习小项目--反金币(2/2)
  • 安装matplotlib__pygame,以pycharm调入模块
  • 编写可扩展的软件:架构和设计原则
  • 算法-排序算法
  • Android_Monkey_测试执行策略及标准
  • windows安装nginx
  • Java日期的学习篇
  • spark on hive
  • Linux Vi编辑器基础操作指南
  • WEB3 创建React前端Dapp环境并整合solidity项目,融合项目结构便捷前端拿取合约 Abi
  • rust运算
  • 游戏引擎,脚本管理模块
  • 2023年7月工作经历三
  • 1801_codesys产品主样本了解
  • flink的计时器
  • @SpringBootApplication剖析
  • 浅谈wor2vec,RNN,LSTM,Transfermer之间的关系
  • 【11】c++设计模式——>单例模式
  • 深度学习-卷积神经网络-AlexNET
  • 人机关系不是物理关系也不是数理关系
  • <html dir=ltr>是什么意思?
  • 工厂模式:简化对象创建的设计思想 (设计模式 四)
  • 【2023最新】微信小程序中微信授权登录功能和退出登录功能实现讲解