当前位置: 首页 > news >正文

pytorch 线性层Linear详解

线性层就是全连接层,以一个输入特征数为2,输出特征数为3的线性层为例,其网络结构如下图所示:
在这里插入图片描述
输入输出数据的关系如下:
在这里插入图片描述
写成矩阵的形式就是:
在这里插入图片描述
下面通过代码进行验证:

import torch.nn as nn
linear_layer = nn.Linear(2,3)
print('weight shape',linear_layer.weight.shape,'bias shape',linear_layer.bias.shape)
linear_layer.weight.data = torch.tensor([[1.,1.],[2.,2.],[3.,3.]])
linear_layer.bias.data = torch.tensor([1.,2.,3.])
x=torch.tensor([[1.,2.]]) # 输入张量应该是浮点数,否则报错
y=linear_layer(x)
print(y)

输出结果:
在这里插入图片描述
与期望的结果相符
torch.nn中的Linear类包含的变量weight就是上面定义的权重矩阵A,包含的变量bias是一个一维的张量。
线性层输入输出关系的矩阵形式中之所以将输入数据和输出数据写成行向量的形式,是因为Linear层可以自动处理多条输入数据的情况。如果输入的张量是一个二维张量,有N行,每一行代表一条数据,不同列代表一条数据的不同特征,那么Linear层会自动将每一条数据分别作为全连接层的输入,得到N条输出数据,Linear的返回值是一个shape为(N,out_features)的张量。例如:

x2=torch.tensor([[1.,2],[2,4]])
y2=linear_layer(x2)
print(y2)

这里有两条输入数据(1,2)和(2,4),输出数据也是两条。输出结果如下:
在这里插入图片描述

http://www.lryc.cn/news/132280.html

相关文章:

  • LeetCode 833. 字符串中的查找与替换
  • Oracle故障案例之-19C时区补丁DSTV38更新
  • 设计模式之组合模式(Composite)的C++实现
  • mongo的include方法踩坑
  • 阿里云无影云电脑/云桌面收费价格表_使用申请方法
  • jvm内存溢出排查(使用idea自带的内存泄漏分析工具)
  • JS内存泄漏
  • 线程和进程同步互斥你真的掌握了吗?(同步互斥机制保姆级讲解与应用)
  • Android 9.0 Vold挂载流程解析(上)
  • 界面组件Telerik UI for WinForms R2 2023——拥有VS2022暗黑主题
  • vue+elementui 实现文本超出长度显示省略号,鼠标移上悬浮展示全部内容
  • 【STM32RT-Thread零基础入门】 5. 线程创建应用(线程创建、删除、初始化、脱离、启动、睡眠)
  • 计算机竞赛 python+深度学习+opencv实现植物识别算法系统
  • 深度探索ChatGPT:如何进行专业提问以获取精确答案
  • 1.vue3+vite开发中axios使用及跨域问题解决
  • 【LangChain】P1 LangChain 应用程序的核心构建模块 LLMChain 以及其三大部分
  • 关于查看处理端口号和进程[linux]
  • C 语言的 strcat() 函数和 strncat() 函数
  • C++ string 的用法
  • MyBatis-Flex学习记录1---请各位大神指教
  • 二分查找旋转数组
  • 关于3D位姿旋转
  • 解锁项目成功的关键:项目经理的结构化思维之道
  • 力扣974被K整除的子数组
  • 简单认识Docker数据管理
  • UDP数据报结构分析(面试重点)
  • 【Java 动态数据统计图】动态数据统计思路案例(动态,排序,数组)二(113)
  • C++进阶 类型转换
  • Idea中隐藏指定文件或指定类型文件
  • 第2步---MySQL卸载和图形化工具展示