当前位置: 首页 > news >正文

无脑入门pytorch系列(三)—— nn.Linear

本系列教程适用于没有任何pytorch的同学(简单的python语法还是要的),从代码的表层出发挖掘代码的深层含义,理解具体的意思和内涵。pytorch的很多函数看着非常简单,但是其中包含了很多内容,不了解其中的意思就只能【看懂代码】,无法【理解代码】。

目录

  • 官方定义
  • demo1
  • demo2

官方定义

nn.Linear 是 PyTorch 中用于创建线性层的类。线性层也被称为全连接层,它将输入与权重矩阵相乘并加上偏置,然后通过激活函数进行非线性变换。

官方的文档如下,torch.nn.Linear:

image-20230811162434032

demo1

下面是一个官方文档给出的例子:

m = nn.Linear(20, 30)
input = torch.randn(128, 20)
output = m(input)
print(output.size())

输出的结果:

image-20230811162456557

首先,输出[128, 20]的张量,经过一个[20, 30]的线性层,变成[128, 30]的张量。
可以理解为矩阵的乘法,也就是矩阵的"外积",矩阵的叉乘,第一个矩阵的行数与第二个矩阵的列数相同。

demo2

input_data = torch.Tensor([[1, 2, 3], [4, 5, 6]]) # [2, 3] 
m = nn.Linear(3, 2)
output = m(input_data)
print(output) # [2, 2]

输出:

image-20230811162529522

可以看看nn.Linear(3, 2)的参数:

for param in m.parameters():print(param)

输出:

image-20230811162606663

结合参数,其实本身它们的计算就是矩阵的乘法:

image-20230811162640687

输入X为[n, i]的矩阵,经过W为[i,0]的矩阵,加上b的偏置得到Y为[n,o]的矩阵。

计算的思路也比较简单:
output[0][0] = [1, 2, 3] * [0.2888, -0.4596, -0,4896] + 0.3740 = -1.7253
output[0][1] = [1, 2, 3] * [0.4730, -0.4033, -0.4739] + 0.3182 = -1.4370
output[1][0] = [4, 5, 6] * [0.2888, -0.4596, -0,4896] + 0.3740 = -3.7066
output[1][1] = [4, 5, 6] * [0.4730, -0.4033, -0.4739] + 0.3182 = -2.6495

通过input和param的对比,我们可以很轻松地理解实际上就是矩阵的乘法操作。而模型在训练过程中就是不断调整param的参数使得输出的张量符合训练集的需求。

http://www.lryc.cn/news/125833.html

相关文章:

  • SQL Server用sql语句添加列,添加列注释
  • springBoot中service层查询使用多线程CompletableFuture(有返回值)
  • 畜牧虚拟仿真 | 鱼授精过程VR模拟演练系统
  • 第一百一十四回 局部动态列表
  • 多尺度目标检测【动手学深度学习】
  • elasticsearch 基础
  • 【BUG】docker安装nacos,浏览器却无法访问到页面
  • C#引用Web Service 类型方法,添加搜索本地服务器Web Service 接口调用方法
  • yolov8训练进阶:新增配置参数
  • 轻量级自动化测试框架WebZ
  • 如何实现安全上网
  • Redis心跳检测
  • 【数据库】Sql Server可视化工具SSMS条件和SQL窗格以及版本信息
  • Python SFTP 详细使用
  • MyBatis的XML映射文件
  • UML-类图和对象图
  • 升级指定版本Node.js或npm
  • UE4/5 GAS技能系统入门3 - GameplayEffect
  • Linux交叉编译opencv并移植ARM端
  • TypeScript教程(一)简介与安装
  • 做视频_Style
  • vue3使用pinia和pinia-plugin-persist做持久化存储
  • 数据结构入门指南:二叉树
  • 大数据课程J2——Scala的基础语法和函数
  • 03-基础入门-搭建安全拓展
  • 穿越未来:探索虚拟现实科技的未来前景
  • SQL- 每日一题【1327. 列出指定时间段内所有的下单产品】
  • [xgb] plot tree
  • 【云原生】Kubernetes 概述
  • 9.2.2Socket(TCP)