当前位置: 首页 > news >正文

pytorch中的各种计算

  对tensor矩阵的维度变换,加减乘除等是深度学习中的常用操作,本文对一些常用方法进行总结

矩阵乘法

  混合矩阵相乘,官网

torch.matmul(input, other, *, out=None) → Tensor

  这个方法执行矩阵相乘操作,需要第一个矩阵的最后一个维度和第二个矩阵的第一个维度相同,即:假设我们有两个矩阵 A 和 B,它们的 size 分别为 (m, n) 和 (n, p),那么 A x B 的 size 为 (m, p)。
  矩阵点乘,官网

torch.mul(input, other, *, out=None) → Tensor

  这个方法对矩阵做点积运算(也可简写为*),这个方法要求第一个矩阵的第一个维度和第二个矩阵的第一个维度对应。torch.dot()类似于mul(),它是向量(即只能是一维的张量)的对应位相乘再求和,返回一个tensor。

矩阵维度变换

  tensor.view方法,用于调整矩阵的维度,这个方法要求矩阵在调整为度前后的元素个数必须是相同的,官网,例子:

>>> t = torch.rand(4, 4)
>>> b = t.view(2, 8)
>>> t.storage().data_ptr() == b.storage().data_ptr()  # `t` and `b` share the same underlying data.
True
# Modifying view tensor changes base tensor as well.
>>> b[0][0] = 3.14
>>> t[0][0]
tensor(3.14)

  torch中对矩阵的压缩和解压操作:torch.squeeze和torch.unsqueeze,这两种方法的作用是压缩矩阵中的某一个维度或者增加一个维度,官网,两种方法的详解可以参考我之前的笔记pytorch中的torch.squeeze和torch.unsqueeze。
  矩阵填充,官网torch.nn.functional.pad

torch.nn.functional.pad(input, pad, mode='constant', value=None) → Tensor
Args:"""input:四维或者五维的tensor Variabepad:不同Tensor的填充方式1.四维Tensor:传入四元素tuple(pad_l, pad_r, pad_t, pad_b),指的是(左填充,右填充,上填充,下填充),其数值代表填充次数2.六维Tensor:传入六元素tuple(pleft, pright, ptop, pbottom, pfront, pback),指的是(左填充,右填充,上填充,下填充,前填充,后填充),其数值代表填充次数mode: ’constant‘, ‘reflect’ or ‘replicate’三种模式,指的是常量,反射,复制三种模式value:填充的数值,在"contant"模式下默认填充0,mode="reflect" or "replicate"时没有			

  如果给入的填充次数是负数,该函数可以实现从该方向对矩阵的裁剪操作。
  需要注意的是,本文中提到的所有方法都支持broadcast操作,也就是,除了参与操作的最后两个维度(矩阵),前面的所有维度都会被认为是batch,以torch,matmul为例,该方法使用两个tensor的后两个维度来计算,其他的维度都可以认为是batch。假设两个输入的维度分别是 i n p u t ( 1000 × 500 × 99 × 11 ) input(1000×500×99×11) input(1000×500×99×11), o t h e r ( 500 × 11 × 99 ) other(500×11×99) other(500×11×99),那么我们可以认为 t o r c h . m a t m u l ( i n p u t , o t h e r ) torch.matmul(input,other) torch.matmul(input,other) 首先是进行后两位矩阵乘法得到 ( 99 × 99 ) (99×99) (99×99) ,然后分析两个参数的batch size分别是 ( 1000 × 500 ) (1000×500) (1000×500) ( 500 ) (500) (500), 可以广播成为 ( 1000 × 500 ) (1000×500) (1000×500),因此最终输出的维度是 ( 1000 × 500 × 99 × 99 ) (1000×500×99×99) (1000×500×99×99)

http://www.lryc.cn/news/307106.html

相关文章:

  • 大数据技术之 Kafka
  • 【GB28181】wvp-GB28181-pro部署安装教程(Ubuntu平台)
  • CentOS删除除了最近5个JAR程序外的所有指定Java程序
  • 面试redis篇-13Redis为什么那么快
  • python Matplotlib Tkinter--pack 框架案例
  • 连接未来:嵌入式系统在物联网时代的应用
  • 自动驾驶中的障碍物时间对齐法
  • 介绍 PIL+IPython.display+mtcnn for 音视频读取、标注
  • C语言中strstr函数的使用!
  • Vue项目中,src目录下的vue.app文件介绍
  • 【Android】坐标系
  • OSCP靶场--Slort
  • 大数据职业技术培训包含哪些
  • 【Java程序设计】【C00313】基于Springboot的物业管理系统(有论文)
  • TensorFlow训练大模型做AI绘图,需要多少的GPU算力支撑
  • docker创建mongodb数据库容器
  • Python并发编程:多线程-线程理论
  • 自定义Chrome的浏览器开发者工具DevTools界面的字体和样式
  • 人事|人事管理系统|基于Springboot的人事管理系统设计与实现(源码+数据库+文档)
  • React18源码: Fiber树中的优先级与帧栈模型
  • Hive 最全面试题及答案(基础篇)
  • 【力扣】整数反转,判断是否溢出的数学解法
  • Jmeter之内置函数__property和__P的区别
  • GPT润色指令
  • Ubuntu中matplotlib显示中文的方法
  • String类-equals和==的区别-遍历-SubString()-StringBuilder-StringJoiner-打乱字符串
  • IDEA的LeetCode插件的设置
  • 2024.2.29 模拟实现 RabbitMQ —— 项目展示
  • React htmlfor
  • 现代化数据架构升级:毫末智行自动驾驶如何应对年增20PB的数据规模挑战?