当前位置: 首页 > news >正文

PyTorch常用Tensor形状变换函数详解

PyTorch常用Tensor形状变换函数详解

在PyTorch中,对张量(Tensor)进行形状变换是深度学习模型构建中不可或缺的一环。无论是为了匹配网络层的输入要求,还是为了进行数据预处理和维度调整,都需要灵活运用各种形状变换函数。本文将系统介绍几个核心的形状变换函数,并深入剖析它们的用法区别与关键点。

一、改变形状与元素数量:view()reshape()

view()reshape() 是最常用的两个用于重塑张量形状的函数。它们都可以改变张量的维度,但前提是新旧张量的元素总数必须保持一致。尽管功能相似,但它们在工作机制上存在关键差异。

view()

view() 函数返回一个具有新形状的张量,这个新张量与原始张量 共享底层数据。这意味着修改其中一个张量的数据,另一个也会随之改变。

关键点:

  • 内存共享view() 保证返回的张量与原张量共享数据,不会创建新的内存副本,因此效率很高。
  • 连续性要求view() 只能作用于在内存中 连续 (contiguous) 的张量。对于一个非连续的张量(例如通过 transpose 操作后得到的张量),直接使用 view() 会引发错误。 在这种情况下,需要先调用 .contiguous() 方法将其变为连续的,然后再使用 view()
reshape()

reshape() 函数同样用于改变张量的形状,但它更加灵活和安全。

关键点:

  • 智能处理reshape() 可以处理连续和非连续的张量。
  • 视图或副本:当作用于连续张量时,reshape() 的行为类似于 view(),返回一个共享数据的视图。 然而,当作用于非连续张量时,reshape() 会创建一个新的、具有所需形状的连续张量,并复制原始数据,此时返回的是一个副本,与原张量不再共享内存。
  • 不确定性reshape() 的语义是它 可能 会也 可能不会 共享存储空间,事先无法确定。
用法区别与选择
特性view()reshape()
内存共享总是共享(返回视图)可能共享(视图),也可能不共享(副本)
对非连续张量抛出错误自动创建副本
推荐使用场景当确定张量是连续的,并且需要保证内存共享以提升效率时。当不确定张量的连续性,或希望代码更健干时,reshape() 是更安全的选择。
使用示例
import torch# 创建一个连续的张量
x = torch.arange(12)  # tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])# 使用 view() 和 reshape()
x_view = x.view(3, 4)
x_reshape = x.reshape(3, 4)print("Original is contiguous:", x.is_contiguous()) # True
print("x_view:\n", x_view)
print("x_reshape:\n", x_reshape)# 创建一个非连续的张量
y = torch.arange(12).reshape(3, 4).t() # .t() 是 transpose(0, 1) 的简写
print("\nOriginal y is contiguous:", y.is_contiguous()) # False# 对非连续张量使用 reshape() - 成功
y_reshape = y.reshape(3, 4)
print("y_reshape:\n", y_reshape)
print("y_reshape is contiguous:", y_reshape.is_contiguous()) # True# 对非连续张量使用 view() - 报错
try:y_view = y.view(3, 4)
except RuntimeError as e:print("\nError with view():", e)# 先用 .contiguous() 再用 view() - 成功
y_view_contiguous = y.contiguous().view(3, 4)
print("y_view_contiguous:\n", y_view_contiguous)

二、交换维度:transpose()permute()

reshapeview 不同,transposepermute 用于重新排列张量的维度,而不是像“拉伸”或“压缩”数据那样改变形状。

transpose()

transpose() 函数专门用于 交换 张量的两个指定维度。

用法: tensor.transpose(dim0, dim1)

关键点:

  • 两两交换:每次只能交换两个维度。
  • 共享数据:返回的张量与原张量共享底层数据,但通常会导致张量在内存中变为非连续。
permute()

permute() 函数则提供了更强大的维度重排能力,可以一次性对所有维度进行任意顺序的重新排列。

用法: tensor.permute(dims),其中 dims 是一个包含所有原始维度索引的新顺序。

关键点:

  • 任意重排:必须为所有维度提供新的顺序。
  • 通用性transpose(dim0, dim1) 可以看作是 permute 的一个特例。
  • 共享数据与非连续性:同样地,permute 返回的也是一个共享数据的视图,并且通常会使张量变为非连续。
用法区别与选择
  • 当只需要交换两个维度时,使用 transpose() 更直观。
  • 当需要进行更复杂的维度重排,例如将 (B, C, H, W) 变为 (B, H, W, C) 时,必须使用 permute()

重要提示:由于 transpose()permute() 经常产生非连续的张量,如果后续需要使用 view(),必须先调用 .contiguous() 方法。

使用示例
import torch# 假设张量形状为 (batch, channel, height, width)
x = torch.randn(2, 3, 4, 5) # Shape: [2, 3, 4, 5]# 使用 transpose() 交换 height 和 width 维度
# 原始维度: 0, 1, 2, 3 -> 交换维度 2 和 3
x_transposed = x.transpose(2, 3)
print("Original shape:", x.shape) # torch.Size([2, 3, 4, 5])
print("Transposed shape:", x_transposed.shape) # torch.Size([2, 3, 5, 4])# 使用 permute() 将 (B, C, H, W) 变为 (B, H, W, C)
# 原始维度: 0, 1, 2, 3 -> 新维度顺序: 0, 2, 3, 1
x_permuted = x.permute(0, 2, 3, 1)
print("Permuted shape:", x_permuted.shape) # torch.Size([2, 4, 5, 3])

三、增减维度:unsqueeze()squeeze()

这两个函数用于添加或移除长度为 1 的维度,这在处理批处理数据或需要广播时非常有用。

unsqueeze()

unsqueeze() 用于在指定位置 添加 一个长度为 1 的维度。

用法: tensor.unsqueeze(dim)

关键点:

  • 它会在 dim 参数指定的位置插入一个新维度。
  • 常用于为单个样本数据添加 batch 维度,或为二维张量添加 channel 维度,以符合模型的输入格式。
squeeze()

squeeze() 用于 移除 所有长度为 1 的维度。

用法:

  • tensor.squeeze(): 移除所有长度为 1 的维度。
  • tensor.squeeze(dim): 只在指定 dim 位置移除长度为 1 的维度,如果该维度长度不为 1,则张量不变。

关键点:

  • 这是一个降维操作,可以方便地去除多余的、长度为1的维度。
  • unsqueeze()squeeze() 互为逆操作。
  • 返回的张量同样与原张量共享数据。
使用示例
import torch# 创建一个形状为 (3, 4) 的张量
x = torch.randn(3, 4)
print("Original shape:", x.shape) # torch.Size([3, 4])# 使用 unsqueeze() 在第 0 维添加 batch 维度
x_unsqueezed_0 = x.unsqueeze(0)
print("Unsqueeze at dim 0:", x_unsqueezed_0.shape) # torch.Size([1, 3, 4])# 使用 unsqueeze() 在第 1 维添加 channel 维度
x_unsqueezed_1 = x.unsqueeze(1)
print("Unsqueeze at dim 1:", x_unsqueezed_1.shape) # torch.Size([3, 1, 4])# --- squeeze ---
y = torch.randn(1, 3, 1, 4)
print("\nOriginal y shape:", y.shape) # torch.Size([1, 3, 1, 4])# 使用 squeeze() 移除所有长度为 1 的维度
y_squeezed_all = y.squeeze()
print("Squeeze all ones:", y_squeezed_all.shape) # torch.Size([3, 4])# 使用 squeeze(dim) 只移除指定位置的维度
y_squeezed_dim = y.squeeze(0) # 移除第 0 维
print("Squeeze at dim 0:", y_squeezed_dim.shape) # torch.Size([3, 1, 4])

四、view/reshapesqueeze/unsqueeze 的关系:可替代性讨论

从最终的形状结果来看,viewreshape 在很多情况下确实可以实现与 squeezeunsqueeze 相同的效果。然而,它们的设计理念和使用场景存在显著差异,一般不建议混用。

使用 reshape 替代 unsqueeze

将一个形状为 (A, B) 的张量通过 unsqueeze(0) 变为 (1, A, B),可以等价地使用 reshape(1, A, B) 来实现。

关键区别:

  • 可读性与意图unsqueeze(dim) 的意图非常明确——在指定位置插入一个新维度。这使得代码更易于理解。而 reshape() 需要提供完整的最终形状,阅读者需要通过对比新旧形状才能理解其操作意图。
  • 便利性:使用 unsqueeze 时,无需知道张量的其他维度尺寸。而使用 reshape,则必须知道所有维度的大小才能构建新的形状参数。
使用 reshape 替代 squeeze

将一个形状为 (1, A, B, 1) 的张量通过 squeeze() 变为 (A, B),可以等价地使用 reshape(A, B) 实现。

关键区别:

  • 自动化与便利性squeeze() 的核心优势在于其自动化。它会自动移除所有大小为 1 的维度,使用者无需预先知道哪些维度是 1。如果使用 reshape,则必须手动计算出目标形状,这在处理动态或未知的输入形状时会非常繁琐。
  • 条件性操作squeeze(dim) 只在指定维度大小为 1 时才执行操作,否则张量保持不变。reshape 不具备这种条件判断能力,它会强制改变形状,如果元素总数不匹配则会报错。

虽然 reshape 功能更强大,理论上可以模拟 squeezeunsqueeze 的操作,但强烈建议使用专用的函数

  • 当你的意图是添加或移除单个维度时,请使用 unsqueezesqueeze。这不仅使代码更清晰、更具可读性,还能利用 squeeze 的自动检测和条件操作特性,让代码更健壮。
  • 只有当你需要进行更复杂的、非增减单一维度的形状重塑时,才应使用 reshapeview
使用示例
import torch# --- unsqueeze vs reshape ---
x = torch.randn(3, 4) # Shape: [3, 4]# 目标: 添加 batch 维度 -> (1, 3, 4)
x_unsqueezed = x.unsqueeze(0)
x_reshaped = x.reshape(1, 3, 4)print("Unsqueeze result:", x_unsqueezed.shape) # torch.Size([1, 3, 4])
print("Reshape result:", x_reshaped.shape)   # torch.Size([1, 3, 4])
# 结果相同,但 unsqueeze(0) 意图更明确# --- squeeze vs reshape ---
y = torch.randn(1, 3, 1, 4) # Shape: [1, 3, 1, 4]# 目标: 移除所有大小为1的维度 -> (3, 4)
y_squeezed = y.squeeze()
y_reshaped = y.reshape(3, 4) # 需要手动知道结果是 (3, 4)print("\nSqueeze result:", y_squeezed.shape) # torch.Size([3, 4])
print("Reshape result:", y_reshaped.shape)   # torch.Size([3, 4])
# squeeze() 自动完成,reshape() 需要手动计算

五、复制与扩展数据:repeat()expand()

除了改变张量的形状,有时还需要沿着某些维度复制数据,以生成一个更大的张量。repeat()expand() 函数都可以实现这一目的,但它们在实现方式和内存使用上有本质的区别。

expand()

expand() 函数通过扩展长度为 1 的维度来创建一个新的、更高维度的张量 视图。它并不会实际分配新的内存来存储复制的数据,因此非常高效。

关键点:

  • 内存高效expand() 返回的是一个视图,与原张量共享底层数据,不产生数据拷贝。
  • 参数含义expand() 的参数指定的是张量的 最终目标形状
  • 使用限制expand() 只能用于扩展大小为 1 的维度(也称为“单例维度”)。对于大小不为 1 的维度,新尺寸必须与原尺寸相同。如果想保持某个维度不变,可以传入 -1 作为该维度的尺寸。
repeat()

repeat() 函数通过在物理内存中 真正地复制 数据来构造一个新张量。它会沿着指定的维度将张量重复指定的次数。

关键点:

  • 数据拷贝repeat() 会创建一个全新的张量,其内容是原张量数据的重复,因此内存占用会相应增加。
  • 参数含义repeat() 的参数指定的是每个维度需要 重复的次数
  • 无限制repeat() 可以对任意维度的张量进行重复,无论其原始大小是否为 1。
用法区别与选择
特性expand()repeat()
内存使用高效,不分配新内存(返回视图)内存消耗大,会创建数据的完整副本
参数含义扩展后的 目标尺寸每个维度的 重复次数
使用限制只能扩展大小为 1 的维度可以重复任意大小的维度
推荐使用场景当需要进行广播(Broadcasting)操作且注重内存效率时,例如将一个偏置向量扩展以匹配一个批次的数据。当需要一个独立的、数据重复的张量副本,并且后续可能需要就地修改其中的部分数据时。
使用示例
import torch# 创建一个包含单例维度的张量
x = torch.tensor([[1], [2], [3]]) # Shape: [3, 1]
print("Original tensor x:\n", x)
print("Original shape:", x.shape)# 使用 expand() 将大小为 1 的维度扩展到 4
# 目标形状是 (3, 4),-1 表示该维度大小不变
expanded_x = x.expand(-1, 4)
print("\nExpanded x shape:", expanded_x.shape) # torch.Size([3, 4])
print("Expanded x:\n", expanded_x)# 使用 repeat() 实现类似效果
# 维度0重复1次(不变),维度1重复4次
repeated_x = x.repeat(1, 4)
print("\nRepeated x (1, 4) shape:", repeated_x.shape) # torch.Size([3, 4])
print("Repeated x (1, 4):\n", repeated_x)# 使用 repeat() 进行更复杂的复制
# 维度0重复2次,维度1重复3次
complex_repeated_x = x.repeat(2, 3)
print("\nRepeated x (2, 3) shape:", complex_repeated_x.shape) # torch.Size([6, 3])
print("Repeated x (2, 3):\n", complex_repeated_x)# expand() 无法作用于大小不为1的维度
try:# 尝试将大小为3的维度扩展到6,会报错x.expand(6, 4)
except RuntimeError as e:print("\nError with expand():", e)

六、组合技巧:先升维再扩展

在实际应用中,一个常见的需求是将一个一维向量复制多次,以构建一个二维矩阵。例如,将一个权重向量应用到批次中的每一个样本上。这个操作可以通过组合“升维”和“扩展/复制”函数来高效实现。这里介绍两种主流的组合方法:view/expandreshape/repeat

方法一:view/reshape + expand (内存高效)

这个组合利用了 expand 函数不复制数据、只创建视图的特性,是实现广播操作的首选。

  1. 升维:首先,使用 view(1, -1)reshape(1, -1) (或者更直观的 unsqueeze(0)) 将一维向量 (N) 变为二维的行向量 (1, N)
  2. 扩展:然后,调用 expand(M, -1)。这会将大小为 1 的第 0 维“扩展” M 次,得到一个 (M, N) 的张量。-1 表示该维度的大小保持不变。

关键点:

  • 整个过程没有发生数据拷贝,返回的是一个共享原始数据的视图,内存效率极高。
  • 由于返回的是视图,并且多个位置共享同一块内存,因此不适合对结果进行就地修改(in-place modification)。
方法二:reshape/view + repeat (数据独立)

这个组合会创建数据的完整物理副本,适用于需要一个独立的、可修改的新张量的场景。

  1. 升维:与方法一相同,先将一维向量 (N) 变为 (1, N)
  2. 复制:然后,调用 repeat(M, 1)。这会将张量在第 0 维上复制 M 次,在第 1 维上复制 1 次(即不复制),最终得到一个 (M, N) 的张量。

关键点:

  • repeat 会实际分配新的内存并复制数据,生成的新张量与原始张量完全独立。
  • 内存开销是 M * N,但好处是你可以自由地修改新张量中的任何元素,而不会影响原始数据。
对比与选择
特性view/reshape + expandreshape/view + repeat
内存使用高效,共享数据,不创建副本消耗大,创建完整的数据副本
数据独立性不独立,是原始数据的视图完全独立,是全新的张量
修改数据通常不应修改,可能会导致错误可以自由、安全地修改
推荐场景广播、只读操作、对内存敏感的场景需要一个可修改的、数据独立的副本时
使用示例
import torch# 1. 创建一个初始的一维向量
x = torch.arange(4) # tensor([0, 1, 2, 3]), Shape: [4]
num_repeats = 3
print(f"Original 1D tensor: {x}\n")# --- 方法一: reshape + expand (内存高效) ---
# 先升维 (4) -> (1, 4),再扩展 (1, 4) -> (3, 4)
x_expanded = x.reshape(1, -1).expand(num_repeats, -1)
print("--- reshape + expand ---")
print("Expanded shape:", x_expanded.shape)
print("Expanded tensor:\n", x_expanded)
# 注意:x_expanded 与 x 共享内存# --- 方法二: reshape + repeat (数据独立) ---
# 先升维 (4) -> (1, 4),再复制 (1, 4) -> (3, 4)
x_repeated = x.reshape(1, -1).repeat(num_repeats, 1)
print("\n--- reshape + repeat ---")
print("Repeated shape:", x_repeated.shape)
print("Repeated tensor:\n", x_repeated)# 验证数据独立性
# 修改 repeated tensor 的一个元素
x_repeated[0, 1] = 99
print("\nModified repeated tensor:\n", x_repeated)
print("Original tensor after modifying repeated:", x) # 原始张量不受影响# 尝试修改 expanded tensor 会引发问题,因为它是一个视图
try:x_expanded[0, 1] = 99
except RuntimeError as e:print(f"\nError modifying expanded tensor: {e}")
http://www.lryc.cn/news/598614.html

相关文章:

  • 如何恢复mysql,避免被研发删库跑路
  • 多模态数据处理系统:用AI读PDF的智能助手系统分析
  • 六、Element-快速入门
  • K8s WebUI 选型:国外 Rancher vs 国内 KubeSphere vs 原生 Dashboard,从部署到使用心得谁更适合企业级场景?
  • 从零用java实现 小红书 springboot vue uniapp(14) 集成阿里云短信验证码
  • Android安全存储:加密文件与SharedPreferences最佳实践
  • 【C++】使用箱线图算法剔除数据样本中的异常值
  • 进程通信----匿名管道
  • 【redis其它面试问题】
  • PHP 与 Vue.js 结合的前后端分离架构
  • 工具分享02 | Python批量文件重命名工具
  • 电商接口什么意思?
  • 数据所有权与用益权分离:数字经济时代的权利博弈与“商业机遇”
  • Claude Code是如何做上下文工程的?
  • Maven Scope标签:解锁Java项目依赖管理的秘密武器
  • [嵌入式embed]ST官网-根据指定固件名下载固件库-STSWSTM32054[STM32F10x_StdPeriph_Lib_V3.5.0]
  • 使用maven-shade-plugin解决依赖版本冲突
  • RCLAMP0504S.TCT 升特半导体TVS二极管 无损传输+军工防护+纳米护甲 ESD防护芯片
  • 陕西地区特种作业操作证考试题库及答案(登高架设作业)
  • Product Hunt 每日热榜 | 2025-07-24
  • 2025年人形机器人动捕技术研讨会于7月31日在京召开
  • 火语言 RPA 在日常运维中的实践
  • ESP32使用 vscode IDF 创建项目到烧录运行全过程
  • 优选算法:移动零
  • 使用ffmpeg转码h265后mac默认播放器不支持问题
  • Mac电脑使用IDEA启动服务后,报service异常
  • 从零构建 Node20+pnpm+pm2 环境镜像:基于 Dockerfile 的两种方案及持久化配置指南
  • 开源Qwen凌晨暴击闭源Claude!刷新AI编程SOTA,支持1M上下文
  • Vue3实现视频播放弹窗组件,支持全屏播放,音量控制,进度条自定义样式,适配浏览器小窗播放,视频大小自适配,缓冲loading,代码复制即用
  • 合泰单片机怎么样