当前位置: 首页 > news >正文

PyTorch 中 reshape 函数用法示例

PyTorch 中 reshape 函数用法示例

在 PyTorch 中,reshape 函数用于改变张量的形状,而不改变其中的数据。下面是一些关于 reshape 函数的常见用法示例。

基本语法

torch.reshape(input, shape)  
# input: 要重塑的张量。
# shape: 目标形状,可以是一个整数元组或列表。

示例1:将一维张量转为二维张量(重要)

import torch  # 创建一个一维张量  
tensor_1d = torch.tensor([1, 2, 3, 4, 5, 6])  # 使用 reshape 将其转为形状为 (2, 3) 的二维张量  
tensor_2d = tensor_1d.reshape(2, 3)  print(tensor_2d)

输出:

tensor([[1, 2, 3],  [4, 5, 6]])

示例 2:使用负数维度自动推导形状(重要)

在 reshape 中可以使用 -1 表示自动推导该维度的大小。

# 创建一个一维张量  
tensor_1d = torch.tensor([1, 2, 3, 4, 5, 6])  # 使用 -1 自动推导维度  
tensor_2d = tensor_1d.reshape(3, -1)  print(tensor_2d)

输出:

tensor([[1, 2],  [3, 4],  [5, 6]])

在这里,-1 的意思是由其他维度的大小推导出来的。

示例 3:将三维张量展平为二维张量

假设有一个形状为 (2, 3, 4) 的三维张量,可以将其展平为形状为 (2, 12) 的二维张量。

# 创建一个三维张量  
tensor_3d = torch.randn(2, 3, 4)  # 随机生成一个张量  
print(tensor_3d)
# 重塑为二维张量  
tensor_2d = tensor_3d.reshape(2, -1)  
print(tensor_2d)
print(tensor_2d.shape)  # 输出应该为 torch.Size([2, 12])

输出:

tensor([[[-2.0344, -0.0268,  1.4198,  0.5537],[ 2.1429, -0.8317, -1.6704,  0.3521],[ 0.4205,  0.0552,  1.8191,  0.4051]],[[-0.5695,  0.2553, -0.8192, -1.3156],[ 0.8952, -0.6411,  1.0547,  0.7071],[-0.1367, -2.2702,  0.6299, -0.7946]]])tensor([[-2.0344, -0.0268,  1.4198,  0.5537,  2.1429, -0.8317, -1.6704,  0.3521,0.4205,  0.0552,  1.8191,  0.4051],[-0.5695,  0.2553, -0.8192, -1.3156,  0.8952, -0.6411,  1.0547,  0.7071,-0.1367, -2.2702,  0.6299, -0.7946]])torch.Size([2, 12])

示例4:调换维度

如果你想把一个矩阵的行和列互换,可以先使用 reshape 将张量改变形状,再使用 .t() 方法进行转置(若适用)。

# 创建一个二维张量  
tensor_2d = torch.tensor([[1, 2, 3], [4, 5, 6]])  # 使用 reshape 先改变形状后,再用 .t() 转置  
tensor_transposed = tensor_2d.reshape(3, 2).t()  # 先变成 3x2 然后转置  print(tensor_transposed)

输出:

tensor([[1, 4],  [2, 5],  [3, 6]])

总结

  • reshape 是用于改变张量形状的工具,数据不变。
  • 可以使用 -1 进行自动推导。
  • 适用于多维张量的重塑,便于后续的数据处理和建模。
http://www.lryc.cn/news/461064.html

相关文章:

  • 安全光幕的工作原理及应用场景
  • 《深度学习》OpenCV LBPH算法人脸识别 原理及案例解析
  • 数据结构之顺序表——动态顺序表(C语言版)
  • Python 网络爬虫入门与实战
  • 成都睿明智科技有限公司电商服务可靠不?
  • fmql之Linux Uart
  • 【火山引擎】调用火山大模型的方法 | SDK安装 | 配置 | 客户端初始化 | 设置
  • 前端实现下载功能汇总(下载二进制流文件、数组下载成csv、将十六进制下载成pcap、将文件下载成zip)
  • iLogtail 开源两周年:UC 工程师分享日志查询服务建设实践案例
  • 【MySQL】入门篇—基本数据类型:NULL值的概念
  • Java设计模式10 - 观察者模式
  • LabVIEW示波器通信及应用
  • 西门子PLC中Modbus通讯DATA_ADDR通讯起始地址设置以及RTU轮询程序设计。
  • 趋势(一)利用python绘制折线图
  • 【含文档】基于Springboot+Vue的采购管理系统(含源码+数据库+lw)
  • 【C++11入门基础】
  • Pytest中fixture的scope详解
  • Springboot 接入 WebSocket 实战
  • 数据结构之红黑树的实现
  • 智能工厂的设计软件 中的AI操作系统的“三维时间”(历时/共时/等时)构建的“能力成熟度-时间规模”平面
  • Spring Boot常见错误与解决方法
  • Mac中安装以及配置adb环境
  • WebGL着色器语言中各个变量的作用
  • Canmv k230 C++案例1——image classify学习笔记 初版
  • vs2022 dump调试
  • OpenCV高级图形用户界面(11)检查是否有键盘事件发生而不阻塞当前线程函数pollKey()的使用
  • nvm安装,node多版本管理
  • ThingsBoard规则链节点:Assign To Customer节点详解
  • 自监督行为识别-时空线索解耦(论文复现)
  • MyBatisPlus:自定义SQL