PyTorch常用Tensor形状变换函数详解
PyTorch常用Tensor形状变换函数详解
在PyTorch中,对张量(Tensor)进行形状变换是深度学习模型构建中不可或缺的一环。无论是为了匹配网络层的输入要求,还是为了进行数据预处理和维度调整,都需要灵活运用各种形状变换函数。本文将系统介绍几个核心的形状变换函数,并深入剖析它们的用法区别与关键点。
一、改变形状与元素数量:view()
与 reshape()
view()
和 reshape()
是最常用的两个用于重塑张量形状的函数。它们都可以改变张量的维度,但前提是新旧张量的元素总数必须保持一致。尽管功能相似,但它们在工作机制上存在关键差异。
view()
view()
函数返回一个具有新形状的张量,这个新张量与原始张量 共享底层数据。这意味着修改其中一个张量的数据,另一个也会随之改变。
关键点:
- 内存共享:
view()
保证返回的张量与原张量共享数据,不会创建新的内存副本,因此效率很高。 - 连续性要求:
view()
只能作用于在内存中 连续 (contiguous) 的张量。对于一个非连续的张量(例如通过transpose
操作后得到的张量),直接使用view()
会引发错误。 在这种情况下,需要先调用.contiguous()
方法将其变为连续的,然后再使用view()
。
reshape()
reshape()
函数同样用于改变张量的形状,但它更加灵活和安全。
关键点:
- 智能处理:
reshape()
可以处理连续和非连续的张量。 - 视图或副本:当作用于连续张量时,
reshape()
的行为类似于view()
,返回一个共享数据的视图。 然而,当作用于非连续张量时,reshape()
会创建一个新的、具有所需形状的连续张量,并复制原始数据,此时返回的是一个副本,与原张量不再共享内存。 - 不确定性:
reshape()
的语义是它 可能 会也 可能不会 共享存储空间,事先无法确定。
用法区别与选择
特性 | view() | reshape() |
---|---|---|
内存共享 | 总是共享(返回视图) | 可能共享(视图),也可能不共享(副本) |
对非连续张量 | 抛出错误 | 自动创建副本 |
推荐使用场景 | 当确定张量是连续的,并且需要保证内存共享以提升效率时。 | 当不确定张量的连续性,或希望代码更健干时,reshape() 是更安全的选择。 |
使用示例
import torch# 创建一个连续的张量
x = torch.arange(12) # tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])# 使用 view() 和 reshape()
x_view = x.view(3, 4)
x_reshape = x.reshape(3, 4)print("Original is contiguous:", x.is_contiguous()) # True
print("x_view:\n", x_view)
print("x_reshape:\n", x_reshape)# 创建一个非连续的张量
y = torch.arange(12).reshape(3, 4).t() # .t() 是 transpose(0, 1) 的简写
print("\nOriginal y is contiguous:", y.is_contiguous()) # False# 对非连续张量使用 reshape() - 成功
y_reshape = y.reshape(3, 4)
print("y_reshape:\n", y_reshape)
print("y_reshape is contiguous:", y_reshape.is_contiguous()) # True# 对非连续张量使用 view() - 报错
try:y_view = y.view(3, 4)
except RuntimeError as e:print("\nError with view():", e)# 先用 .contiguous() 再用 view() - 成功
y_view_contiguous = y.contiguous().view(3, 4)
print("y_view_contiguous:\n", y_view_contiguous)
二、交换维度:transpose()
与 permute()
与 reshape
或 view
不同,transpose
和 permute
用于重新排列张量的维度,而不是像“拉伸”或“压缩”数据那样改变形状。
transpose()
transpose()
函数专门用于 交换 张量的两个指定维度。
用法: tensor.transpose(dim0, dim1)
关键点:
- 两两交换:每次只能交换两个维度。
- 共享数据:返回的张量与原张量共享底层数据,但通常会导致张量在内存中变为非连续。
permute()
permute()
函数则提供了更强大的维度重排能力,可以一次性对所有维度进行任意顺序的重新排列。
用法: tensor.permute(dims)
,其中 dims
是一个包含所有原始维度索引的新顺序。
关键点:
- 任意重排:必须为所有维度提供新的顺序。
- 通用性:
transpose(dim0, dim1)
可以看作是permute
的一个特例。 - 共享数据与非连续性:同样地,
permute
返回的也是一个共享数据的视图,并且通常会使张量变为非连续。
用法区别与选择
- 当只需要交换两个维度时,使用
transpose()
更直观。 - 当需要进行更复杂的维度重排,例如将
(B, C, H, W)
变为(B, H, W, C)
时,必须使用permute()
。
重要提示:由于 transpose()
和 permute()
经常产生非连续的张量,如果后续需要使用 view()
,必须先调用 .contiguous()
方法。
使用示例
import torch# 假设张量形状为 (batch, channel, height, width)
x = torch.randn(2, 3, 4, 5) # Shape: [2, 3, 4, 5]# 使用 transpose() 交换 height 和 width 维度
# 原始维度: 0, 1, 2, 3 -> 交换维度 2 和 3
x_transposed = x.transpose(2, 3)
print("Original shape:", x.shape) # torch.Size([2, 3, 4, 5])
print("Transposed shape:", x_transposed.shape) # torch.Size([2, 3, 5, 4])# 使用 permute() 将 (B, C, H, W) 变为 (B, H, W, C)
# 原始维度: 0, 1, 2, 3 -> 新维度顺序: 0, 2, 3, 1
x_permuted = x.permute(0, 2, 3, 1)
print("Permuted shape:", x_permuted.shape) # torch.Size([2, 4, 5, 3])
三、增减维度:unsqueeze()
与 squeeze()
这两个函数用于添加或移除长度为 1 的维度,这在处理批处理数据或需要广播时非常有用。
unsqueeze()
unsqueeze()
用于在指定位置 添加 一个长度为 1 的维度。
用法: tensor.unsqueeze(dim)
关键点:
- 它会在
dim
参数指定的位置插入一个新维度。 - 常用于为单个样本数据添加
batch
维度,或为二维张量添加channel
维度,以符合模型的输入格式。
squeeze()
squeeze()
用于 移除 所有长度为 1 的维度。
用法:
tensor.squeeze()
: 移除所有长度为 1 的维度。tensor.squeeze(dim)
: 只在指定dim
位置移除长度为 1 的维度,如果该维度长度不为 1,则张量不变。
关键点:
- 这是一个降维操作,可以方便地去除多余的、长度为1的维度。
unsqueeze()
和squeeze()
互为逆操作。- 返回的张量同样与原张量共享数据。
使用示例
import torch# 创建一个形状为 (3, 4) 的张量
x = torch.randn(3, 4)
print("Original shape:", x.shape) # torch.Size([3, 4])# 使用 unsqueeze() 在第 0 维添加 batch 维度
x_unsqueezed_0 = x.unsqueeze(0)
print("Unsqueeze at dim 0:", x_unsqueezed_0.shape) # torch.Size([1, 3, 4])# 使用 unsqueeze() 在第 1 维添加 channel 维度
x_unsqueezed_1 = x.unsqueeze(1)
print("Unsqueeze at dim 1:", x_unsqueezed_1.shape) # torch.Size([3, 1, 4])# --- squeeze ---
y = torch.randn(1, 3, 1, 4)
print("\nOriginal y shape:", y.shape) # torch.Size([1, 3, 1, 4])# 使用 squeeze() 移除所有长度为 1 的维度
y_squeezed_all = y.squeeze()
print("Squeeze all ones:", y_squeezed_all.shape) # torch.Size([3, 4])# 使用 squeeze(dim) 只移除指定位置的维度
y_squeezed_dim = y.squeeze(0) # 移除第 0 维
print("Squeeze at dim 0:", y_squeezed_dim.shape) # torch.Size([3, 1, 4])
四、view
/reshape
与 squeeze
/unsqueeze
的关系:可替代性讨论
从最终的形状结果来看,view
和 reshape
在很多情况下确实可以实现与 squeeze
和 unsqueeze
相同的效果。然而,它们的设计理念和使用场景存在显著差异,一般不建议混用。
使用 reshape
替代 unsqueeze
将一个形状为 (A, B)
的张量通过 unsqueeze(0)
变为 (1, A, B)
,可以等价地使用 reshape(1, A, B)
来实现。
关键区别:
- 可读性与意图:
unsqueeze(dim)
的意图非常明确——在指定位置插入一个新维度。这使得代码更易于理解。而reshape()
需要提供完整的最终形状,阅读者需要通过对比新旧形状才能理解其操作意图。 - 便利性:使用
unsqueeze
时,无需知道张量的其他维度尺寸。而使用reshape
,则必须知道所有维度的大小才能构建新的形状参数。
使用 reshape
替代 squeeze
将一个形状为 (1, A, B, 1)
的张量通过 squeeze()
变为 (A, B)
,可以等价地使用 reshape(A, B)
实现。
关键区别:
- 自动化与便利性:
squeeze()
的核心优势在于其自动化。它会自动移除所有大小为 1 的维度,使用者无需预先知道哪些维度是 1。如果使用reshape
,则必须手动计算出目标形状,这在处理动态或未知的输入形状时会非常繁琐。 - 条件性操作:
squeeze(dim)
只在指定维度大小为 1 时才执行操作,否则张量保持不变。reshape
不具备这种条件判断能力,它会强制改变形状,如果元素总数不匹配则会报错。
虽然 reshape
功能更强大,理论上可以模拟 squeeze
和 unsqueeze
的操作,但强烈建议使用专用的函数。
- 当你的意图是添加或移除单个维度时,请使用
unsqueeze
和squeeze
。这不仅使代码更清晰、更具可读性,还能利用squeeze
的自动检测和条件操作特性,让代码更健壮。 - 只有当你需要进行更复杂的、非增减单一维度的形状重塑时,才应使用
reshape
或view
。
使用示例
import torch# --- unsqueeze vs reshape ---
x = torch.randn(3, 4) # Shape: [3, 4]# 目标: 添加 batch 维度 -> (1, 3, 4)
x_unsqueezed = x.unsqueeze(0)
x_reshaped = x.reshape(1, 3, 4)print("Unsqueeze result:", x_unsqueezed.shape) # torch.Size([1, 3, 4])
print("Reshape result:", x_reshaped.shape) # torch.Size([1, 3, 4])
# 结果相同,但 unsqueeze(0) 意图更明确# --- squeeze vs reshape ---
y = torch.randn(1, 3, 1, 4) # Shape: [1, 3, 1, 4]# 目标: 移除所有大小为1的维度 -> (3, 4)
y_squeezed = y.squeeze()
y_reshaped = y.reshape(3, 4) # 需要手动知道结果是 (3, 4)print("\nSqueeze result:", y_squeezed.shape) # torch.Size([3, 4])
print("Reshape result:", y_reshaped.shape) # torch.Size([3, 4])
# squeeze() 自动完成,reshape() 需要手动计算
五、复制与扩展数据:repeat()
与 expand()
除了改变张量的形状,有时还需要沿着某些维度复制数据,以生成一个更大的张量。repeat()
和 expand()
函数都可以实现这一目的,但它们在实现方式和内存使用上有本质的区别。
expand()
expand()
函数通过扩展长度为 1 的维度来创建一个新的、更高维度的张量 视图。它并不会实际分配新的内存来存储复制的数据,因此非常高效。
关键点:
- 内存高效:
expand()
返回的是一个视图,与原张量共享底层数据,不产生数据拷贝。 - 参数含义:
expand()
的参数指定的是张量的 最终目标形状。 - 使用限制:
expand()
只能用于扩展大小为 1 的维度(也称为“单例维度”)。对于大小不为 1 的维度,新尺寸必须与原尺寸相同。如果想保持某个维度不变,可以传入-1
作为该维度的尺寸。
repeat()
repeat()
函数通过在物理内存中 真正地复制 数据来构造一个新张量。它会沿着指定的维度将张量重复指定的次数。
关键点:
- 数据拷贝:
repeat()
会创建一个全新的张量,其内容是原张量数据的重复,因此内存占用会相应增加。 - 参数含义:
repeat()
的参数指定的是每个维度需要 重复的次数。 - 无限制:
repeat()
可以对任意维度的张量进行重复,无论其原始大小是否为 1。
用法区别与选择
特性 | expand() | repeat() |
---|---|---|
内存使用 | 高效,不分配新内存(返回视图) | 内存消耗大,会创建数据的完整副本 |
参数含义 | 扩展后的 目标尺寸 | 每个维度的 重复次数 |
使用限制 | 只能扩展大小为 1 的维度 | 可以重复任意大小的维度 |
推荐使用场景 | 当需要进行广播(Broadcasting)操作且注重内存效率时,例如将一个偏置向量扩展以匹配一个批次的数据。 | 当需要一个独立的、数据重复的张量副本,并且后续可能需要就地修改其中的部分数据时。 |
使用示例
import torch# 创建一个包含单例维度的张量
x = torch.tensor([[1], [2], [3]]) # Shape: [3, 1]
print("Original tensor x:\n", x)
print("Original shape:", x.shape)# 使用 expand() 将大小为 1 的维度扩展到 4
# 目标形状是 (3, 4),-1 表示该维度大小不变
expanded_x = x.expand(-1, 4)
print("\nExpanded x shape:", expanded_x.shape) # torch.Size([3, 4])
print("Expanded x:\n", expanded_x)# 使用 repeat() 实现类似效果
# 维度0重复1次(不变),维度1重复4次
repeated_x = x.repeat(1, 4)
print("\nRepeated x (1, 4) shape:", repeated_x.shape) # torch.Size([3, 4])
print("Repeated x (1, 4):\n", repeated_x)# 使用 repeat() 进行更复杂的复制
# 维度0重复2次,维度1重复3次
complex_repeated_x = x.repeat(2, 3)
print("\nRepeated x (2, 3) shape:", complex_repeated_x.shape) # torch.Size([6, 3])
print("Repeated x (2, 3):\n", complex_repeated_x)# expand() 无法作用于大小不为1的维度
try:# 尝试将大小为3的维度扩展到6,会报错x.expand(6, 4)
except RuntimeError as e:print("\nError with expand():", e)
六、组合技巧:先升维再扩展
在实际应用中,一个常见的需求是将一个一维向量复制多次,以构建一个二维矩阵。例如,将一个权重向量应用到批次中的每一个样本上。这个操作可以通过组合“升维”和“扩展/复制”函数来高效实现。这里介绍两种主流的组合方法:view
/expand
和 reshape
/repeat
。
方法一:view
/reshape
+ expand
(内存高效)
这个组合利用了 expand
函数不复制数据、只创建视图的特性,是实现广播操作的首选。
- 升维:首先,使用
view(1, -1)
或reshape(1, -1)
(或者更直观的unsqueeze(0)
) 将一维向量(N)
变为二维的行向量(1, N)
。 - 扩展:然后,调用
expand(M, -1)
。这会将大小为 1 的第 0 维“扩展” M 次,得到一个(M, N)
的张量。-1
表示该维度的大小保持不变。
关键点:
- 整个过程没有发生数据拷贝,返回的是一个共享原始数据的视图,内存效率极高。
- 由于返回的是视图,并且多个位置共享同一块内存,因此不适合对结果进行就地修改(in-place modification)。
方法二:reshape
/view
+ repeat
(数据独立)
这个组合会创建数据的完整物理副本,适用于需要一个独立的、可修改的新张量的场景。
- 升维:与方法一相同,先将一维向量
(N)
变为(1, N)
。 - 复制:然后,调用
repeat(M, 1)
。这会将张量在第 0 维上复制 M 次,在第 1 维上复制 1 次(即不复制),最终得到一个(M, N)
的张量。
关键点:
repeat
会实际分配新的内存并复制数据,生成的新张量与原始张量完全独立。- 内存开销是
M * N
,但好处是你可以自由地修改新张量中的任何元素,而不会影响原始数据。
对比与选择
特性 | view/reshape + expand | reshape/view + repeat |
---|---|---|
内存使用 | 高效,共享数据,不创建副本 | 消耗大,创建完整的数据副本 |
数据独立性 | 不独立,是原始数据的视图 | 完全独立,是全新的张量 |
修改数据 | 通常不应修改,可能会导致错误 | 可以自由、安全地修改 |
推荐场景 | 广播、只读操作、对内存敏感的场景 | 需要一个可修改的、数据独立的副本时 |
使用示例
import torch# 1. 创建一个初始的一维向量
x = torch.arange(4) # tensor([0, 1, 2, 3]), Shape: [4]
num_repeats = 3
print(f"Original 1D tensor: {x}\n")# --- 方法一: reshape + expand (内存高效) ---
# 先升维 (4) -> (1, 4),再扩展 (1, 4) -> (3, 4)
x_expanded = x.reshape(1, -1).expand(num_repeats, -1)
print("--- reshape + expand ---")
print("Expanded shape:", x_expanded.shape)
print("Expanded tensor:\n", x_expanded)
# 注意:x_expanded 与 x 共享内存# --- 方法二: reshape + repeat (数据独立) ---
# 先升维 (4) -> (1, 4),再复制 (1, 4) -> (3, 4)
x_repeated = x.reshape(1, -1).repeat(num_repeats, 1)
print("\n--- reshape + repeat ---")
print("Repeated shape:", x_repeated.shape)
print("Repeated tensor:\n", x_repeated)# 验证数据独立性
# 修改 repeated tensor 的一个元素
x_repeated[0, 1] = 99
print("\nModified repeated tensor:\n", x_repeated)
print("Original tensor after modifying repeated:", x) # 原始张量不受影响# 尝试修改 expanded tensor 会引发问题,因为它是一个视图
try:x_expanded[0, 1] = 99
except RuntimeError as e:print(f"\nError modifying expanded tensor: {e}")