当前位置: 首页 > news >正文

详细分析Pytorch中的transpose基本知识(附Demo)| 对比 permute

目录

  • 前言
  • 1. 基本知识
  • 2. Demo

前言

原先的permute推荐阅读:详细分析Pytorch中的permute基本知识(附Demo)

1. 基本知识

transpose 是 PyTorch 中用于交换张量维度的函数,特别是用于二维张量(矩阵)的转置操作,常用于线性代数运算、深度学习模型的输入和输出处理等

基本知识如下

  • 功能:交换张量的两个维度
  • 输入:一个张量和两个要交换的维度的索引
  • 输出:具有新维度顺序的张量

原理分析如下:
transpose 的核心原理是通过交换指定维度的方式改变张量的形状
例如,对于一个二维张量 (m, n),调用 transpose(0, 1) 会返回一个形状为 (n, m) 的新张量,其元素顺序经过了调整

  • 高维张量: 对于高维张量,transpose 只会影响指定的两个维度,而其他维度保持不变
  • 内存视图:与 permute 类似,transpose 返回的是原始张量的一个视图,不会进行数据复制

2. Demo

示例 1: 基本用法

import torch# 创建一个 3x4 的矩阵
matrix = torch.randn(3, 4)
print("原始矩阵形状:", matrix.shape)# 使用 transpose 交换维度
# 将矩阵的维度从 (3, 4) 变为 (4, 3)
transposed_matrix = matrix.transpose(0, 1)
print("转置后矩阵形状:", transposed_matrix.shape)

截图如下:

在这里插入图片描述

示例 2: 高维张量的转置

import torch# 创建一个 2x3x4 的张量
tensor = torch.randn(2, 3, 4)
print("原始张量形状:", tensor.shape)# 使用 transpose 交换第二和第三维
# 将张量的维度从 (2, 3, 4) 变为 (2, 4, 3)
transposed_tensor = tensor.transpose(1, 2)
print("转置后张量形状:", transposed_tensor.shape)

截图如下:

在这里插入图片描述

示例 3: 在深度学习中的应用

import torch# 创建一个假设的批量数据 (批量, 高度, 宽度, 通道)
batch_tensor = torch.randn(5, 256, 256, 3)
print("原始批量形状:", batch_tensor.shape)# 将通道和宽度维度交换
# 适用于某些模型的输入
batch_transposed = batch_tensor.transpose(2, 3)
print("转置后批量形状:", batch_transposed.shape)

截图如下:

在这里插入图片描述

基本的注意事项如下:

  • 只支持交换两个维度: transpose 只能同时交换两个维度,而无法一次性处理多个维度
  • 数据不复制:返回的是原始张量的视图,因此内存开销较小
  • 维度索引:确保指定的维度索引在张量的维度范围内,否则会引发错误
http://www.lryc.cn/news/472191.html

相关文章:

  • 初识WebGL
  • 【力扣】Go语言回溯算法详细实现与方法论提炼
  • 「C/C++」C/C++ 之 第三方库使用规范
  • 六、元素应用CSS的习题
  • 正式入驻!上海斯歌BPM PaaS管理软件等产品入选华为云联营商品
  • 使用 Axios 上传大文件分片上传
  • Nginx+Lua脚本+Redis 实现自动封禁访问频率过高IP
  • PART 1 数据挖掘概论 — 数据挖掘方法论
  • Centos安装ffmpeg的方法
  • 理解SQL中通配符的使用
  • SpringBoot篇(简化操作的原理)
  • Cesium的模型(ModelVS)顶点着色器浅析
  • 机器人领域中的scaling law:通过复现斯坦福机器人UMI——探讨数据规模化定律(含UMI的复现关键)
  • C++之多态的深度剖析
  • Microsoft Office PowerPoint制作科研论文用图
  • go语言进阶之并发基础
  • po、dto、vo的使用场景
  • 聊一聊Elasticsearch的一些基本信息
  • Unity 两篇文章熟悉所有编辑器拓展关键类 (上)
  • Spring SPI、Solon SPI 有点儿像(Maven 与 Gradle)
  • 合并排序算法(C语言版)
  • C++——输入一行文字,找出其中的大写字母、小写字母、空格数字以及其他字符各有多少。用指针或引用方法处理。
  • 【skywalking】maximum query complexity exceeded 3336 > 3000
  • 开源一个开发的聊天应用与AI开发框架,集成 ChatGPT,支持私有部署的源码
  • 开发了一个成人学位英语助考微信小程序
  • LeetCode16:最接近的三数之和
  • VisualStudio2022配置2D图形库SFML
  • 「Mac畅玩鸿蒙与硬件4」鸿蒙开发环境配置篇4 - DevEco Studio 高效使用技巧
  • 构建生产级的 RAG 系统
  • 完全透彻了解一个asp.net core MVC项目模板2