当前位置: 首页 > news >正文

Pytorch个人学习记录总结 07

目录

神经网络-非线性激活

神经网络-线形层及其他层介绍 


神经网络-非线性激活

官方文档地址:torch.nn — PyTorch 2.0 documentation 

常用的:Sigmoid、ReLU、LeakyReLU等。

 

作用:为模型引入非线性特征,这样才能在训练过程中训练出符合更多特征的模型。

其中有个参数是inplace,默认为False,表示是否就地改变输入值,True则表示直接改变了input不再有另外的返回值;False则没有直接改变input并有返回值(建议是inplace=False)。

import torch
from torch import nninput = torch.tensor([[3, -1],[-0.5, 1]])
input = torch.reshape(input, (1, 1, 2, 2))relu = nn.ReLU()
input_relu = relu(input)print('input={}\ninput_relu:{}'.format(input, input_relu))# input=tensor([[[[ 3.0000, -1.0000],
#           [-0.5000,  1.0000]]]])
# input_relu:tensor([[[[3., 0.],
#           [0., 1.]]]])

神经网络-线形层及其他层介绍 

Linear Layers中的torch.nn.Linear(in_features, out_features, bias=True)。默认bias=True。对传入数据应用线性变换

Parameters

  • in_features – size of each input sample(每个输入样本的大小)
  • out_features – size of each output sample(每个输出样本的大小)
  • bias – If set to False, the layer will not learn an additive bias. Default: True(如果为False,则该层不会学习加法偏置,默认为true)

Shape:分别关注输入、输出的最后一个维度的大小,在训练过程中,nn.Linear往往是当作的展平为一维后最后几步的全连接层,所以此时就只关注了通道数,即往往Input和Outputs是一维的)

“展平为一维”经常用到torch.nn.Flatten(start_dim=1, end_dim=- 1)

想说一下start_dim,它表示“从start_dim开始把后面的维度都展平到同一维度上”,默认是是1,在实际训练中从start_dim=1开始展平,因为在训练中的tensor是4维的,分别是[batch_size, C, H, W],而第0维的batch_size不能动它,所以是从1开始的。

还比较重要的有:torch.nn.BatchNorm2d、torch.nn.Dropout、Loss Functions(之后再讲)。其它的Transformer Layers、Recurrent Layers都不是很常用。

import torch# 对4维tensor展平,start_dim=1input = torch.arange(54)
input = torch.reshape(input, (2, 3, 3, 3))y_0 = torch.flatten(input)
y_1 = torch.flatten(input, start_dim=1)print(input.shape)
print(y_0.shape)
print(y_1.shape)# torch.Size([2, 3, 3, 3])
# torch.Size([54])
# torch.Size([2, 27])

http://www.lryc.cn/news/95646.html

相关文章:

  • vue3+ts+elementui-plus二次封装树形表格
  • 机器学习/深度学习常见算法实现(秋招版)
  • 京东技术专家首推:Spring 微服务架构设计,GitHub 星标 128K
  • R语言--森林图制作
  • Tomcat中利用war包部署
  • [JAVAee]线程安全
  • ELK环境搭建——概况
  • 面试知识点整理
  • 腾讯云服务器CVM计算型c6/c5实例CPU型号、处理器主频大全
  • vue3笔记-脚手架篇
  • 数字的补数
  • Taskfile demo
  • MyBatis学习笔记之高级映射及延迟加载
  • 小程序如何删除/上架/下架商品
  • Failed to load local font resource:微信小程序加载第三方字体
  • 使用fastjson错误
  • 【GitOps系列】使用Kustomize和Helm定义应用配置
  • Android kotlin高阶函数与Java lambda表达式介绍与实战
  • 自然语言处理实战项目13-基于GRU模型与NER的关键词抽取模型训练全流程
  • 7.26 Qt
  • 【MySQL】库和表的操作
  • (五)RabbitMQ-进阶 死信队列、延迟队列、防丢失机制
  • windows下面的python配置
  • vue3中 状态管理pinia得使用
  • 如何使用 After Effects 导出摄像机跟踪数据到 3ds Max
  • 【iOS】懒加载
  • 《脱离“一支笔、一双手、一道力扣”困境的秘诀》:突破LeetCode难题的五个关键步骤
  • 基于jeecg-boot的任务甘特图显示
  • docker export,import后无法运行,如java命令找不到,运行后容器内编码有问题
  • Web3教程| 什么是地址监控?如何使用地址监控追踪黑客地址?