当前位置: 首页 > news >正文

PyTorch如何通过 torch.unbind 和torch.stack动态调整张量的维度顺序

笔者一篇博客PyTorch 的 torch.unbind 函数详解与进阶应用:中英双语中有一个例子如下:

# 创建一个 3x2x2 的三维张量
x = torch.tensor([[[1, 2], [3, 4]],[[5, 6], [7, 8]],[[9, 10], [11, 12]]])# 第一步:沿第 0 维分解为 32x2 张量
unbind_result = torch.unbind(x, dim=0)# 第二步:沿第 2 维重新堆叠
stack_result = torch.stack(unbind_result, dim=2)
print("最终结果:", stack_result)

结果

最终结果:
tensor([[[ 1,  5,  9],[ 3,  7, 11]],[[ 2,  6, 10],[ 4,  8, 12]]])
  • 使用 torch.unbind 沿第 0 维分解。
  • 使用 torch.stack 沿第 2 维重新组合,从而完成了维度转换。

张量的形状在每一步的变化如下:

  • 原始张量形状为 [3, 2, 2]。
  • 分解后,得到 3 个形状为 [2, 2] 的张量。
  • 堆叠时,将这些张量沿新的维度 dim=2 组合,最终形状变为 [2, 2, 3]。

通过这种分解和堆叠方式,我们可以灵活地操作张量的维度和数据布局。

具体是怎么变的,这里记录一下。

这个例子展示了如何通过 torch.unbindtorch.stack 动态调整张量的维度顺序。以下是对这个例子的详细解释,包括每一步的操作和张量形状变化:


1. 初始张量

我们先创建一个形状为 [3, 2, 2] 的张量 x

x = torch.tensor([[[1, 2], [3, 4]],[[5, 6], [7, 8]],[[9, 10], [11, 12]]])

张量的内容

x = [[[1, 2],  [3, 4]],    # 第一个“平面”[[5, 6],  [7, 8]],    # 第二个“平面”[[9, 10], [11, 12]]   # 第三个“平面”]

形状[3, 2, 2]
这里的含义:

  • 第一维度(dim=0,大小为3):有3个“平面”(或者块)。
  • 第二维度(dim=1,大小为2):每个“平面”有两行。
  • 第三维度(dim=2,大小为2):每行有两个元素。

2. 使用 torch.unbind 沿 dim=0 分解

unbind_result = torch.unbind(x, dim=0)

torch.unbind 的作用是沿着指定的维度(这里是 dim=0)移除这一维度,并返回一个元组,元组中的每个元素都是输入张量在该维度上的切片。

对于我们的例子:

  • x 沿着 dim=0 分解,相当于把张量按“平面”切开。
  • 原始的 3×2×2 张量被分成了 3 个形状为 [2, 2] 的子张量。

unbind_result 的内容

unbind_result = (tensor([[1, 2],  [3, 4]]),  # 第一个平面tensor([[5, 6],  [7, 8]]),  # 第二个平面tensor([[9, 10], [11, 12]]) # 第三个平面
)

每个切片都是一个形状为 [2, 2] 的二维张量。
这里的维度变化:

  • 原始张量形状 [3, 2, 2] → 切片形状 [2, 2]

3. 使用 torch.stack 沿 dim=2 重新组合

stack_result = torch.stack(unbind_result, dim=2)

torch.stack 的作用是把一组张量沿着新的维度拼接起来。这里:

  • unbind_result 是一个包含 3 个 [2, 2] 张量的元组。
  • 我们指定 dim=2,意思是在原始张量的最后一维(第三维)增加一个新的维度来进行拼接。
拼接过程
  1. 第一个子张量的每个位置与第二个、第三个子张量的对应位置对齐,按列方向拼接。
  2. 拼接后,原来 [2, 2] 的子张量变成了 [2, 3] 的子张量。

举例说明:

  • 原始三个 [2, 2] 的张量:
    tensor([[1, 2], [3, 4]])
    tensor([[5, 6], [7, 8]])
    tensor([[9, 10], [11, 12]])
    
  • 沿 dim=2 进行拼接后:
    [[[1, 5, 9], [3, 7, 11]],  # 第一行拼接[[2, 6, 10], [4, 8, 12]]  # 第二行拼接
    ]
    

最终结果

stack_result = tensor([[[ 1,  5,  9], [ 3,  7, 11]],[[ 2,  6, 10], [ 4,  8, 12]]
])

形状变化

  • 原始张量 [3, 2, 2] → 分解后的切片 [2, 2] → 拼接后的结果 [2, 2, 3]

4. 形状变化总结

操作张量内容张量形状
初始张量x[3, 2, 2]
使用 torch.unbind(dim=0)3 个 [2, 2] 的子张量[2, 2]
使用 torch.stack(dim=2)拼接为一个新的张量[2, 2, 3]

5. 为什么维度顺序调整了?

通过 torch.unbindtorch.stack 的组合,实际上我们重新定义了张量的组织方式:

  1. torch.unbinddim=0 的维度移除,分解成多个子张量。
  2. torch.stack 指定新的维度(这里是 dim=2),将这些子张量拼接为一个新维度,从而实现了维度的重新排列。

最终,我们将原来的“平面”维度(dim=0)转移到了列方向(dim=2),实现了动态调整维度顺序的效果。


6. 总结

  • torch.unbind 用于移除一个维度并分解张量
  • torch.stack 用于沿指定的新维度拼接张量
  • 两者结合可以灵活调整张量的维度顺序。

这个例子展示了如何从 [3, 2, 2] 变换到 [2, 2, 3],过程中分解和拼接操作相辅相成,适用于需要动态调整张量维度的高级场景。

后记

2024年12月12日22点28分于上海,基于GPT4o大模型生成。

http://www.lryc.cn/news/504692.html

相关文章:

  • 【Unity3D】报错libil2cpp.so找不到问题
  • 事件冒泡机制详解
  • 红米Note 9 Pro5G刷LineageOS
  • 6.3.1 MR实战:计算总分与平均分
  • ARM循环程序和子程序设计
  • 静态路由、RIP、OSPF、BGP的区别
  • 知识分享第二十八天-数学篇一
  • BigDecimal在进行除法运算时需要注意四舍五入的位置
  • 第二部分:进阶主题 14 . 性能优化 --[MySQL轻松入门教程]
  • Mac电脑设置鼠标的滚轮方向
  • 【LDAP】LDAP概念和原理介绍
  • Android系统(android app和系统架构)
  • Android HandlerThread、Looper、MessageQueue 源码分析
  • HTML知识点详解教程
  • [数据结构#1] 并查集 | FindRoot | Union | 优化 | 应用
  • 科研绘图系列:R语言绘制网络图和密度分布图(network density plot)
  • Linux中输入和输出基本过程
  • 使用 acme.sh 签发和自动续期 ssl https 证书
  • spring重点面试题总结
  • 新的一章:codegeex
  • 游戏引擎学习第50天
  • 快速理解类的加载过程
  • 医院跌倒检测识别 使用YOLO,COCO ,VOC格式对4806张原始图片进行标注,可识别病人跌倒,病人的危险行为,病床等场景,预测准确率可达96.7%
  • [Unity Shader] 【游戏开发】【图形渲染】Unity Shader的种类2-顶点/片元着色器与固定函数着色器的选择与应用
  • 浏览器端的 js 包括哪几个部分
  • GoogLeNet网络:深度学习领域的创新之作
  • 深入C语言文件操作:从库函数到系统调用
  • Java序列化
  • 基坑表面位移沉降倾斜自动化监测 非接触式一体化解决机器视觉
  • 提升效率:精通Windows命令行的艺术