当前位置: 首页 > news >正文

Pytorch基础:Tensor的squeeze和unsqueeze方法

相关阅读

Pytorch基础icon-default.png?t=N7T8https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482


        在Pytorch中,squeeze和unsqueeze是Tensor的一个重要方法,同时它们也是torch模块中的一个函数,它们的语法如下所示。 

Tensor.squeeze(dim=None) → Tensor
torch.squeeze(input, dim=None) → Tensorinput (Tensor) – the input tensor.
dim (int or tuple of ints, optional) – if given, the input will be squeezed only in the specified dimensions.Tensor.unsqueeze(dim) → Tensor
torch.unsqueeze(input, dim) → Tensorinput (Tensor) – the input tensor.
dim (int) – the index at which to insert the singleton dimension

一、squeeze

        squeeze函数(或方法)返回一个新的张量,该张量移除了原张量中大小为1的维度,例如:输入张量的形状是(A×1×B×C×1×D),使用了squeeze函数(或方法)后,输出张量的形状是(A×B×C×D)。请注意:输出张量将与输入张量共享底层存储,因此改变一个张量的内容将改变另一个张量的内容。默认情况下,squeeze将移除所有尺寸为1的维度,如果传递了dim参数,则会将dim中的维度展开。dim的范围可以是[-input.dim()-1, input.dim()],其中负数索引表示从后往前数的位置,例如-1代表最后一个维度。

        可以看下面的例子以更好的理解:

import torch# 创建一个形状为 (2, 1, 2, 1, 2) 的张量
x = torch.zeros(2, 1, 2, 1, 2)
print(x, x.size(), id(x))# 移除所有大小为1的维度
a = torch.squeeze(x)  # 等价于 a = x.squeeze()
print(a, a.size(), id(a))# 尝试移除第0维度(由于第0维度大小不为1,因此不改变形状)
b = torch.squeeze(x, 0)  # 等价于 b = x.squeeze(0)
print(b, b.size(), id(b))# 移除第1维度(第1维度大小为1)
c = torch.squeeze(x, 1)  # 等价于 c = x.squeeze(1)
print(c, c.size(), id(c))# 移除第1、第2和第3维度(第1和第3维度大小为1,第2维度不变)
d = torch.squeeze(x, (1, 2, 3))  # 等价于 d = x.squeeze((1, 2, 3))
print(d, d.size(), id(d))# 验证所有张量共享底层存储空间
print(x.storage().data_ptr() == a.storage().data_ptr() == b.storage().data_ptr() == c.storage().data_ptr() == d.storage().data_ptr()) # 共享底层存储空间输出:
tensor([[[[[0., 0.]],[[0., 0.]]]],[[[[0., 0.]],[[0., 0.]]]]]) torch.Size([2, 1, 2, 1, 2]) 1899057117680tensor([[[0., 0.],[0., 0.]],[[0., 0.],[0., 0.]]]) torch.Size([2, 2, 2]) 1899057158240tensor([[[[[0., 0.]],[[0., 0.]]]],[[[[0., 0.]],[[0., 0.]]]]]) torch.Size([2, 1, 2, 1, 2]) 1899737467296tensor([[[[0., 0.]],[[0., 0.]]],[[[0., 0.]],[[0., 0.]]]]) torch.Size([2, 2, 1, 2]) 1899737467376tensor([[[0., 0.],[0., 0.]],[[0., 0.],[0., 0.]]]) torch.Size([2, 2, 2]) 1899737467216
True

二、 unsqueeze

        unsqueeze函数(或方法)函数返回一个新的张量,该张量在指定维度(dim)插入一个大小为1的维度。使用unsqueeze函数(或方法)后,输入张量的形状会相应增加一个维度。例如,输入张量的形状是(A×B×C),在第1维度使用unsqueeze后,输出张量的形状将变为(A×1×B×C)。请注意,输出张量将与输入张量共享底层存储,因此改变一个张量的内容将改变另一个张量的内容。dim的范围可以是[-input.dim(), input.dim()-1],其中负数索引表示从后往前数的位置,例如-1代表最后一个维度。

         可以看下面的例子以更好的理解:

import torch# 创建一个形状为 (2, 2, 2) 的张量
x = torch.zeros(2, 2, 2)
print(x, x.size(), id(x))# 在第0维度插入单维度
a = torch.unsqueeze(x, 0)  # 等价于 a = x.unsqueeze(0)
print(a, a.size(), id(a))# 在第1维度插入单维度
b = torch.unsqueeze(x, 1)  # 等价于 b = x.unsqueeze(1)
print(b, b.size(), id(b))# 在第2维度插入单维度
c = torch.unsqueeze(x, 2)  # 等价于 c = x.unsqueeze(2)
print(c, c.size(), id(c))# 在第3维度插入单维度
d = torch.unsqueeze(x, 3)  # 等价于 d = x.unsqueeze(3)
print(d, d.size(), id(d))# 验证所有张量共享底层存储空间
print(x.storage().data_ptr() == a.storage().data_ptr() == b.storage().data_ptr() == c.storage().data_ptr() == d.storage().data_ptr())  # 共享底层存储空间输出:
tensor([[[0., 0.],[0., 0.]],[[0., 0.],[0., 0.]]]) torch.Size([2, 2, 2]) 1509028592032tensor([[[[0., 0.],[0., 0.]],[[0., 0.],[0., 0.]]]]) torch.Size([1, 2, 2, 2]) 1509028632592tensor([[[[0., 0.],[0., 0.]]],[[[0., 0.],[0., 0.]]]]) torch.Size([2, 1, 2, 2]) 1507561225888tensor([[[[0., 0.]],[[0., 0.]]],[[[0., 0.]],[[0., 0.]]]]) torch.Size([2, 2, 1, 2]) 1507561391824tensor([[[[0.],[0.]],[[0.],[0.]]],[[[0.],[0.]],[[0.],[0.]]]]) torch.Size([2, 2, 2, 1]) 1507561391904
True
http://www.lryc.cn/news/407501.html

相关文章:

  • PHP压缩打包,下载目录或者文件,解压zip文件
  • 后端面试题日常练-day08 【Java基础】
  • Linux:core文件无法生成排查步骤
  • 大模型学习资源
  • 约定(模拟赛2 T3)
  • Java推送xml数据进行http请求
  • Docker安装 OpenResty详细教程
  • 前端位运算运用场景小知识(权限相关)
  • 【云原生】Kubernetes中的DaemonSet介绍、原理、用法及实战应用案例分析
  • 使用框架构建React Native应用程序的最佳实践
  • Godot入门 02玩家1.0版
  • Docker-Compose配置zookeeper+KaFka+CMAK简单集群
  • Python中,集合几种基本运算
  • netsuite查询货品库存
  • Java 实现分页的几种方式详解
  • vite构建vue3项目hmr生效问题踩坑记录
  • 区块链赋能民生大数据
  • 10 Vue 特性要点
  • ESP32和mDNS学习
  • 学习SQL如何使用CASE语句查询分析设备状态
  • Gartner发布2024年零信任网络技术成熟度曲线:20项零信任相关的前沿和趋势性技术
  • React hook 之 useState
  • jenkins中shell脚本中使用构建参数化Groovy变量的四种方式
  • Robot Operating System——ParameterEventHandler监控Parameters的增删改行为
  • 计算机网络(Wrong Question)
  • Docker+consul容器服务的更新与发现
  • 全网最详细!! Linux 安装、配置教程
  • cocos creator 3学习记录01——如何替换图片
  • 【Android Compose】ListView效果
  • 【Pytorch实战教程】Pytorch中.detach()的详细介绍