EP02:【DL 第二弹】张量的索引、分片、合并以及维度调整
一、张量的符号索引
张量作为 PyTorch 中最核心的数据结构,其本质是有序的多维数据集合。就像我们在生活中通过地址找到具体的房间一样,张量的索引就是通过位置编号定位特定元素的过程。对于不同维度的张量,索引逻辑既存在共性,也有维度带来的差异。
1.1 一维张量索引
一维张量可以简单理解为拉长的数组,它的索引规则与 Python 中的列表、元组等原生序列完全一致,核心格式是[start: end: step]
,这三个参数分别控制索引的起点、终点和步长。
t1 = torch.arange(1, 11)
print(f"t1:{t1}")# 从左到右,从零开始
print(f"t1[0]:{t1[0]}")
print(f"t1[0]的类型:{type(t1[0])}")
# 切片
print(f"t1[1:8] 取2-9号元素,且左闭右开:{t1[1:8]}")
print(f"t1[1:8:2] 取2-9号元素,左闭右开,且隔2取一个数:{t1[1:8:2]}")
print(f"t1[1::2] 从2号开始取完,且隔2取一个数:{t1[1::2]}")
print(f"t1[:8:2]:{t1[:8:2]}")# ×1:step位必须大于0
# print(f"t1[9:1:-1]:{t1[9:1:-1]}") # ValueError: step must be greater than zero
- 运行结果:
t1:tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
t1[0]:1
t1[0]的类型:<class 'torch.Tensor'>
t1[1:8] 取2-9号元素,且左闭右开:tensor([2, 3, 4, 5, 6, 7, 8])
t1[1:8:2] 取2-9号元素,左闭右开,且隔2取一个数:tensor([2, 4, 6, 8])
t1[1::2] 从2号开始取完,且隔2取一个数:tensor([ 2, 4, 6, 8, 10])
t1[:8:2]:tensor([1, 3, 5, 7])
- 示例解读:
一维张量的索引从0开始计数(即从左到右,从零开始)。需要注意的是,即使取出的是单个元素,其类型仍然是
torch.Tensor
(张量类型),而非 Python 原生的整数,这是因为 PyTorch 的运算始终围绕张量进行,单个元素也会被封装为张量以保持一致性。
切片操作是索引的延伸,通过start:end
可以选取一个连续的子序列,且遵循左闭右开原则 —— 即包含start
对应的元素,不包含end
对应的元素。步长step
则用于控制间隔选取;如果start
或end
省略,则表示从起点开始或到终点结束。
另外,需要特别注意的是,步长step
必须为正数。如果尝试使用负数步长(如t1[9:1:-1]),会直接报错,这是因为 PyTorch 的一维张量索引暂不支持通过负步长实现从右到左的反向切片(与 Python 列表不同)。
1.2 二维张量索引
二维张量可以想象成一张表格,包含行和列两个维度。它的索引逻辑是一维索引的扩展 —— 通过逗号分隔两个维度的索引参数,分别对应行索引和列索引,格式为[行索引, 列索引]
,其中每个维度的索引规则与一维张量完全一致。
t2 = torch.arange(1, 10).reshape(3, 3)
print(f"t2:{t2}")
print(f"t2[0, 1] 第1行第2列:{t2[0, 1]}")
print(f"t2[0, ::2] 第1行,且隔2取一个数:{t2[0, ::2]}")
print(f"t2[0, [0, 2]] 第1行的第1列和第3列:{t2[0, [0, 2]]}")
print(f"t2[::2, ::2]:{t2[::2, ::2]}")
print(f"t2[[0, 2], 0]:{t2[0, [0, 2]]}")
- 运行结果:
t2:tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
t2[0, 1] 第1行第2列:2
t2[0, ::2] 第1行,且隔2取一个数:tensor([1, 3])
t2[0, [0, 2]] 第1行的第1列和第3列:tensor([1, 3])
t2[::2, ::2]:tensor([[1, 3],[7, 9]])
t2[[0, 2], 0]:tensor([1, 3])
- 示例解读:
除了切片,还可以通过列表指定具体的索引位置,实现非连续选取。比如t2[0, [0, 2]]直接指定第 0 行的第 0 列和第 2 列,结果同样是 [1, 3],这种方式比切片更灵活,适合选取分散的元素。
需要注意的是,二维索引中两个维度的操作是独立的,行索引决定选取哪些行,列索引决定选取这些行中的哪些列,最终结果的形状由两个维度的选取结果共同决定。
1.3 三维张量索引
三维张量可以理解为一叠表格—— 即多个二维张量(矩阵)堆叠而成的序列。因此,它的索引需要三个参数,分别对应第几个矩阵 矩阵中的行 矩阵中的列,格式为[矩阵索引, 行索引, 列索引]
,每个维度的索引规则仍与一维张量一致。
三维索引的核心是逐层定位:先确定要操作的矩阵,再在该矩阵中通过行和列索引定位元素,维度的增加只是多了一层选择范围,但每层的索引逻辑与一维、二维保持一致。
t3 = torch.arange(1, 28).reshape(3, 3, 3)
print(f"t3:{t3}")
print(f"t3[1, 1, 1] 第2个矩阵第2行第2列:{t3[1, 1, 1]}")
print(f"t3[1, ::2, ::2]:{t3[1, ::2, ::2]}")
- 运行结果:
t3:tensor([[[ 1, 2, 3],[ 4, 5, 6],[ 7, 8, 9]],[[10, 11, 12],[13, 14, 15],[16, 17, 18]],[[19, 20, 21],[22, 23, 24],[25, 26, 27]]])
t3[1, 1, 1] 第2个矩阵第2行第2列:14
t3[1, ::2, ::2]:tensor([[10, 12],[16, 18]])
二、张量的函数索引
除了通过[]符号进行索引,PyTorch 还提供了torch.index_select()
函数用于更灵活的索引操作。这种函数式索引的核心优势是可以通过索引张量批量指定要选取的位置,尤其适合处理非连续、不规则的索引需求。
torch.index_select(input, dim, index)
- 功能:通过指定索引在张量的特定维度上选取元素,返回一个新的张量(原张量的视图)。
- 参数说明:
input
:输入的张量。dim
:指定进行索引操作的维度。index
:包含索引值的张量,用于指定在dim
维度上选取哪些元素。
函数索引与符号索引的核心区别在于:符号索引中的[start:end:step]
更适合连续切片,而index_select
通过index
张量可以轻松实现非连续索引,且索引逻辑更清晰(明确指定维度)。此外,index
必须是张量类型,这也符合 PyTorch 中张量操作张量的设计理念,确保运算可以在 GPU 上高效执行。
import torcht1 = torch.arange(1, 11)
print(f"t1:{t1}")indices = torch.tensor([1, 2])
t1_ = torch.index_select(t1, 0, indices)
print(f"t1_:{t1_}")print('--'*50)t2 = torch.arange(12).reshape(4, 3)
print(f"t2:{t2}")
t2_ = torch.index_select(t2, 0, indices)
t2_1 = torch.index_select(t2, 1, indices)
print(f"t2_:{t2_}")
print(f"t2_1:{t2_1}")
- 运行结果:
t1:tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
t1_:tensor([2, 3])
----------------------------------------------------------------------------------------------------
t2:tensor([[ 0, 1, 2],[ 3, 4, 5],[ 6, 7, 8],[ 9, 10, 11]])
t2_:tensor([[3, 4, 5],[6, 7, 8]])
t2_1:tensor([[ 1, 2],[ 4, 5],[ 7, 8],[10, 11]])
三、view()
方法
在处理张量时,我们经常需要改变其形状(如将 2x3 的张量转为 3x2),但又希望避免复制数据以节省内存 ——view()
方法就是为解决这个问题而生的。它的核心作用是生成原张量的视图(view),即一个与原张量共享数据存储空间、但形状不同的新张量。
视图的本质是浅拷贝:原张量和视图共用同一块内存,因此修改其中一个,另一个会同步变化。
torch.Tensor.view(*shape)
- 功能:返回一个与原张量共享数据存储空间的新视图,可改变张量的形状(结构),但不改变数据本身。
- 参数说明:
*shape
:表示新张量的形状,为一个或多个整数组成的序列,需保证新形状的元素总数与原张量相同。
import torcht1 = torch.arange(6).reshape(2, 3)
print(f"t1: {t1}")
t1_ = t1.view(3, 2)
print(f"t1_: {t1_}")
t1_1 = t1.view(1, 2, 3)
print(f"t1_1: {t1_1}")print('--'*50)
# *1. view() 构建的是一个数据相同,但形状不同的视图
t1[0] = 100
print(f"t1: {t1}")
print(f"t1_: {t1_}")
- 运行结果:
t1: tensor([[0, 1, 2],[3, 4, 5]])
t1_: tensor([[0, 1],[2, 3],[4, 5]])
t1_1: tensor([[[0, 1, 2],[3, 4, 5]]])
----------------------------------------------------------------------------------------------------
t1: tensor([[100, 100, 100],[ 3, 4, 5]])
t1_: tensor([[100, 100],[100, 3],[ 4, 5]])
view()
的使用非常灵活,只要新形状的元素总数与原张量一致即可。需要注意的是,view()
要求原张量的内存布局是连续的(即元素在内存中连续存储),如果原张量经过某些操作(如转置)导致内存不连续,view()
可能会报错,此时可以先用contiguous()
方法整理内存后再使用view()
。
view()
的设计初衷是高效处理形状变换:在深度学习中,张量的形状调整非常频繁(如全连接层前需将图像张量展平),通过共享内存可以避免不必要的数据复制,大幅提升运算效率。这也是为什么后续介绍的很多张量切分方法(如chunk()
、split()
)返回的都是视图 —— 它们都依赖view()
的内存共享机制。
四、张量的分片函数
当需要将一个大张量拆分成多个小张量时,PyTorch 提供了chunk()
和split()
两个常用函数,它们的共同特点是返回原张量的视图(而非复制数据),但在切分逻辑上有所区别。
4.1 chunk()
分块
torch.chunk(input, chunks, dim=0)
- 功能:按照指定维度对张量进行均匀切分,返回切分后的张量组成的元组(各元素均为原张量的视图)。
- 参数说明:
input
:输入的张量。chunks
:表示要切分的块数。dim
:指定进行切分操作的维度,默认为0。
t1 = torch.arange(12).reshape(4, 3)
print(f"t1:{t1}")
t1_ = torch.chunk(t1, 4, dim=0)
print(f"t1_ 在第0维度上进行4等分:{t1_}")# *1. chunk() 函数返回的结果是一个视图,而非一个新对象
t1_[0][0][0] = 100
print(f"t1_:{t1_}")
print(f"t1:{t1}")
# *2. 若原张量无法均分,chunk() 函数不会报错,而是返回其他均分结果
t1_1 = torch.chunk(t1, 8, dim=0)
print(f"t1_1:{t1_1}")
- 运行结果:
t1:tensor([[ 0, 1, 2],[ 3, 4, 5],[ 6, 7, 8],[ 9, 10, 11]])
t1_ 在第0维度上进行4等分:(tensor([[0, 1, 2]]), tensor([[3, 4, 5]]), tensor([[6, 7, 8]]), tensor([[ 9, 10, 11]]))
t1_:(tensor([[100, 1, 2]]), tensor([[3, 4, 5]]), tensor([[6, 7, 8]]), tensor([[ 9, 10, 11]]))
t1:tensor([[100, 1, 2],[ 3, 4, 5],[ 6, 7, 8],[ 9, 10, 11]])
t1_1:(tensor([[100, 1, 2]]), tensor([[3, 4, 5]]), tensor([[6, 7, 8]]), tensor([[ 9, 10, 11]]))
- 示例解读:
chunk()
的一个重要特性是无法均分时的容错处理:如果原张量在指定维度的长度不能被chunks整除,它不会报错,而是尽可能均匀地分配,最后一份可能比其他份少。例如t1在第 0 维长度为 4,当chunks=8
时,由于 4 < 8,无法切分 8 份,因此torch.chunk(t1, 8, dim=0)
会返回与原张量行数相同的 4 份(即不进行额外切分)。
此外,chunk()
返回的是视图,因此修改分片会影响原张量。比如代码中t1_[0][0][0] = 100
,原张量t1
的第 0 行第 0 列元素也会变成 100,这再次体现了视图共享内存的特性。
4.2 split()
拆分
split()
比chunk()
更灵活,它既可以像chunk()
一样进行均匀切分,也可以通过列表指定自定义切分长度。
torch.split(tensor, split_size_or_sections, dim=0)
- 功能:按照指定维度对张量进行拆分,既可以进行均匀拆分,也可以进行自定义拆分,返回拆分后的张量组成的元组(各元素均为原张量的视图)。
- 参数说明:
tensor
:输入的张量。split_size_or_sections
:若为整数,则表示在指定维度上每个拆分块的大小(均匀拆分);若为列表,则列表中的元素表示各拆分块在指定维度上的大小(自定义拆分)。dim
:指定进行拆分操作的维度,默认为0。
t2 = torch.arange(12).reshape(4, 3)
print(f"t2:{t2}")
t2_ = torch.split(t2, 2, 0)
print(f"t2_:{t2_}")
t2_1 = torch.split(t2, [1, 3], 0)
print(f"t2_1:{t2_1}")
print(f"[1, 1, 2]:{torch.split(t2, [1, 1, 2], 0)}")
t2_2 = torch.split(t2, [1, 2], 1)
print(f"t2_2:{t2_2}")# *1. split() 函数返回的结果也是一个视图
t2_2[0][0] = 100
print(f"t2:{t2}")
- 运行结果:
t2:tensor([[ 0, 1, 2],[ 3, 4, 5],[ 6, 7, 8],[ 9, 10, 11]])
t2_:(tensor([[0, 1, 2],[3, 4, 5]]), tensor([[ 6, 7, 8],[ 9, 10, 11]]))
t2_1:(tensor([[0, 1, 2]]), tensor([[ 3, 4, 5],[ 6, 7, 8],[ 9, 10, 11]]))
[1, 1, 2]:(tensor([[0, 1, 2]]), tensor([[3, 4, 5]]), tensor([[ 6, 7, 8],[ 9, 10, 11]]))
t2_2:(tensor([[0],[3],[6],[9]]), tensor([[ 1, 2],[ 4, 5],[ 7, 8],[10, 11]]))
t2:tensor([[100, 1, 2],[ 3, 4, 5],[ 6, 7, 8],[ 9, 10, 11]])
- 示例解读:
split()
同样支持对不同维度切分,例如torch.split(t2, [1, 2], 1)
在第 1 维(列)上切分,第一份 1 列,第二份 2 列,结果为两个 4x1 和 4x2 的张量。
与chunk()
类似,split()
返回的也是视图,修改分片会同步影响原张量(如代码中修改t2_2[0][0] = 100
后,t2
的第 0 行第 0 列元素变为 100)。
split()
与chunk()
的核心区别在于灵活性:chunk()
只能指定份数,split()
可以直接指定每份的长度,因此在需要非均匀切分时,split()
是更优选择。
五、张量的合并操作
与分片相反,合并操作用于将多个张量组合成一个更大的张量,PyTorch 中常用的有cat()
(拼接)和stack()
(堆叠),二者的核心区别在于是否改变张量的维度。
5.1 cat()
拼接
torch.cat(tensors, dim=0)
- 功能:将多个张量在指定维度上进行拼接(元素堆积),返回一个新的张量,拼接后张量的维度与原张量相同。
- 参数说明:
tensors
:需要进行拼接的张量组成的序列(如列表),这些张量需在除拼接维度外的其他维度上形状相同。dim
:指定进行拼接操作的维度,默认为0。
a = torch.zeros(2, 3)
b = torch.ones(2, 3)
c = torch.zeros(3, 3)
print(f"a+b 按行:{torch.cat([a,b])}")
print(f"a+b 按列:{torch.cat([a,b], dim=1)}")
- 运行结果:
a+b 按行:tensor([[0., 0., 0.],[0., 0., 0.],[1., 1., 1.],[1., 1., 1.]])
a+b 按列:tensor([[0., 0., 0., 1., 1., 1.],[0., 0., 0., 1., 1., 1.]])
- 示例解读:
cat()
的关键要求是:除了拼接维度外,其他维度的长度必须完全一致。例如a
(2x3)和c
(3x3)可以在dim=0
拼接(因为列数都是 3),结果是 5 行 3 列;但如果尝试在dim=1
拼接,会因行数不同(2 vs 3)而报错。
5.2 stack()
堆叠
torch.stack(tensors, dim=0)
- 功能:将多个形状相同的张量在指定维度上进行堆叠,生成一个新的更高维度的张量(维度数比原张量多1)。
- 参数说明:
tensors
:需要进行堆叠的张量组成的序列(如列表),这些张量必须具有完全相同的形状。dim
:指定进行堆叠操作的维度,即新维度插入的位置,默认为0。
a = torch.zeros(2, 3)
b = torch.ones(2, 3)
c = torch.zeros(3, 3)
print(f"a+b 堆到一个三维张量中:{torch.stack([a, b])}")# ×1. 堆叠时,必须保证两个张量的形状一致
print(f"a+c 拼接:{torch.cat([a, c])}")
# print(f"a+c 堆叠:{torch.stack([a, c])}") # RuntimeError: stack expects each tensor to be equal size, but got [2, 3] at entry 0 and [3, 3] at entry 1
- 运行结果:
a+b 堆到一个三维张量中:tensor([[[0., 0., 0.],[0., 0., 0.]],[[1., 1., 1.],[1., 1., 1.]]])
a+c 拼接:tensor([[0., 0., 0.],[0., 0., 0.],[0., 0., 0.],[0., 0., 0.],[0., 0., 0.]])
- 示例解读:
stack()
的要求比cat()
更严格:所有输入张量的形状必须完全相同。例如a
(2x3)和c
(3x3)形状不同,torch.stack([a, c])
会直接报错,而cat()
在满足其他维度一致时可以拼接。这是因为堆叠是将整个张量作为元素放入新维度,形状不一致会导致元素无法对齐。
5.3 两者区别
简单来说,cat()
是平面拼接,不增加维度,只是在现有维度上延长;stack()
是立体堆叠,增加维度,将输入张量作为更高维度的子元素。例如 2x3 的a
和b
:cat后是 4x3 或 2x6(二维),stack后是 2x2x3(三维)。
cat()
适合将同结构的数据拼接成更长的序列(如将多个批次的样本合并);stack()
适合将多个同形状的张量打包成一个整体(如将多张相同大小的图片堆叠成一个批量图片张量)。
六、张量的维度变换
在深度学习中,张量的维度往往需要根据运算需求调整(如增加批次维度、删除冗余维度),squeeze()
和unsqueeze()
是实现这一目标的常用工具,分别用于降维和升维。
squeeze()
和unsqueeze()
是维度调整的轻量工具,通过增加或删除尺寸为 1 的维度,让张量形状适配不同的运算场景,且操作过程中不会改变张量的元素数据。
6.1 squeeze()
删除不必要的维度
torch.squeeze(input, dim=None)
- 功能:删除张量中所有大小为1的维度(若未指定
dim
),或删除指定维度中大小为1的维度(若指定dim
),返回一个新的张量(可能为原张量的视图)。 - 参数说明:
input
:输入的张量。dim
(可选):指定要删除的维度,若该维度大小不为1,则不进行操作;若不指定,默认删除所有大小为1的维度。
t1 = torch.arange(4)
print(f"t1:{t1}")
t1_ = t1.reshape(1, 4)
print(f"t1_:{t1_}")
print(f"t1_的维度:{t1_.ndim}")
t1_1 = torch.squeeze(t1_)
print(f"t1_1:{t1_1}")
print(f"t1_1的维度:{t1_1.ndim}")t2 = torch.zeros(1, 1, 3, 1)
print(f"t2:{t2}")
print(f"t2 降维:{torch.squeeze(t2)}")t3 = torch.ones(1, 1, 3, 2, 1, 2)
print(f"t3:{t3}")
print(f"t3 降维:{torch.squeeze(t3)},形状:{torch.squeeze(t3).shape}")
- 运行结果:
t1:tensor([0, 1, 2, 3])
t1_:tensor([[0, 1, 2, 3]])
t1_的维度:2
t1_1:tensor([0, 1, 2, 3])
t1_1的维度:1
t2:tensor([[[[0.],[0.],[0.]]]])
t2 降维:tensor([0., 0., 0.])
t3:tensor([[[[[[1., 1.]],[[1., 1.]]],[[[1., 1.]],[[1., 1.]]],[[[1., 1.]],[[1., 1.]]]]]])
t3 降维:tensor([[[1., 1.],[1., 1.]],[[1., 1.],[1., 1.]],[[1., 1.],[1., 1.]]]),形状:torch.Size([3, 2, 2])
- 示例解读:
squeeze()
的核心用途是精简张量结构:在实际运算中,有时会因操作产生不必要的维度(如reshape(1, n)
),这些维度不携带有效信息,反而会增加运算复杂度,squeeze()
可以快速去除它们。
6.2 unsqueeze()
手动升维
torch.unsqueeze(input, dim)
- 功能:在指定维度上为张量增加一个大小为1的新维度,返回一个新的张量(原张量的视图)。
- 参数说明:
input
:输入的张量。dim
:指定要增加新维度的位置,新维度的大小为1。
t4 = torch.zeros(1, 2, 1, 2)
print(f"t4:{t4}")
print(f"在第1个维度上升高一个维度:{torch.unsqueeze(t4, 0)},形状:{torch.unsqueeze(t4, 0).shape}")
print(f"在第3个维度上升高一个维度:{torch.unsqueeze(t4, 2)},形状:{torch.unsqueeze(t4, 2).shape}")
- 运行结果:
t4:tensor([[[[0., 0.]],[[0., 0.]]]])
在第1个维度上升高一个维度:tensor([[[[[0., 0.]],[[0., 0.]]]]]),形状:torch.Size([1, 1, 2, 1, 2])
在第3个维度上升高一个维度:tensor([[[[[0., 0.]]],[[[0., 0.]]]]]),形状:torch.Size([1, 2, 1, 1, 2])
- 示例解读:
unsqueeze()
的核心用途是适配运算维度要求:例如在深度学习中,模型输入通常需要包含批次维度(batch dim),如果单个样本是 (3,32,32) 的图像张量,可以用unsqueeze(0)
将其转为 (1,3,32,32),表示1 个样本,以符合模型的输入格式。
微语录:犹豫了,谁都救不了。——《罗小黑战记》