当前位置: 首页 > news >正文

神经网络的基本骨架—nn.Module使用

一、pytorch官网中torch.nn的相关简介

可以看到torch.nn中有许多模块:

二、Containers模块

1、MODULE(CLASS : torch.nn.Module)

import torch.nn as nn
import torch.nn.functional as Fclass Model(nn.Module):#nn.Module---所有神经网络模块的基类。def __init__(self): #初始化super(Model, self).__init__()self.conv1 = nn.Conv2d(1, 20, 5)self.conv2 = nn.Conv2d(20, 20, 5)def forward(self, x): #前向计算x = F.relu(self.conv1(x))return F.relu(self.conv2(x))

forward(*input)

Defines the computation performed at every call. Should be overridden by all subclasses.

2、搭建神经网络模型

import torch
import torch.nn as nn
import torch.nn.functional as F
# 定义自己的神经网络模板
class Lemon(nn.Module):def __init__(self) -> None:super().__init__()def forward(self,input):output = input + 1return output
# 创建神经网络
lemon = Lemon()
x = torch.tensor(1.0)
output = lemon(x)
print(output)

三、Convolution Layers

  1. nn.Conv1d/nnCon2d

  • input – input tensor of shape (minibatch,in_channels,iH,iW)输入

  • weight – filters of shape (out_channels,groupsin_channels,kH,kW)权重/卷积核

  • bias – optional bias tensor of shape (out_channels). Default: None偏置

  • stride – the stride of the convolving kernel. Can be a single number or a tuple (sH, sW). Default: 1步进/长 SH和SW分别控制横向的步进和纵向的步进

  • padding – implicit paddings on both sides of the input. Can be a single number or a tuple (padH, padW). Default: 0

  • dilation – the spacing between kernel elements. Can be a single number or a tuple (dH, dW). Default: 1

  • groups – split input into groups, in_channelsin_channels should be divisible by the number of groups. Default: 1

import torch
import torch.nn.functional as F
# 输入
input = torch.tensor([[1,2,0,3,1],[0,1,2,3,1],[1,2,1,0,0],[5,2,3,1,1],[2,1,0,1,1]])
# 卷积核
kernel = torch.tensor([[1,2,1],[0,1,0],[2,1,0]])
print(input.shape) #torch.Size([5, 5])
print(kernel.shape) #torch.Size([3, 3])
#官方文档中输入input和卷积核weight需要四个参数——>input tensor of shape (minibatch,in_channels,iH,iW)
#所以可以使用reshape二参变四参
input = torch.reshape(input,(1,1,5,5)) #torch.Size([1, 1, 5, 5])
kernel = torch.reshape(kernel,(1,1,3,3)) #torch.Size([1, 1, 3, 3])
print(input.shape) #torch.Size([5, 5])
print(kernel.shape) #torch.Size([3, 3])output = F.conv2d(input,kernel,stride=1)
print(output)

一般来讲,输出的维度 = 输入的维度 - 卷积核大小/stride + 1

padding =1,为上下左右各填充一行,空的地方默认为0

http://www.lryc.cn/news/14023.html

相关文章:

  • 面试官:你是怎样进行react组件代码复用的
  • arxiv2017 | 用于分子神经网络建模的数据增强 SMILES Enumeration
  • 倒计时2天!TO B人的传统节日,2023年22客户节(22DAY)
  • java版工程管理系统Spring Cloud+Spring Boot+Mybatis实现工程管理系统源码
  • 数据结构刷题(六):142环形链表II、242有效的字母异位词、383赎金信、349两个数组的交集
  • OpenGL学习日记之光照计算
  • 七大排序经典排序算法
  • 设计模式—“对象性能”
  • 基于Spring Boot的零食商店
  • Python语言的优缺点
  • 3款强大到离谱的电脑软件,个个提效神器,从此远离加班
  • vue3 使用typescript小结
  • PYTHON爬虫基础
  • JavaScript刷LeetCode模板技巧篇(一)
  • ros-sensor_msgs/PointCloud2消息内容解释
  • LeetCode 每日一题2347. 最好的扑克手牌
  • MMPBSA计算--基于李继存老师gmx_mmpbsa脚本
  • Kafka优化篇-压测和性能调优
  • MinIo-SDK
  • 系统分析师真题2018试卷相关概念一
  • 身为大学生,你不会还不知道有这些学生福利吧!!!!
  • 试题 算法训练 藏匿的刺客
  • JavaWab开发的总括以及HTML知识
  • Oracle数据库文件(*.dbf)迁移【图文教程】
  • Java中如何创建和使用对象?
  • Spring Cloud Alibaba--ActiveMQ微服务详解之消息队列(四)
  • 32岁,薪水被应届生倒挂,裸辞了
  • 蓝桥杯训练day1
  • Unity毛发系统TressFX Exporter
  • 《爆肝整理》保姆级系列教程python接口自动化(十九)--Json 数据处理---实战(详解)