PyTorch topk() 用法详解:取最大值
torch.topk(input, k)
返回张量中最大的 k 个元素以及它们在原张量中的 索引。
函数原型
torch.topk(input, k, dim=None, largest=True, sorted=True)
参数说明:
参数 | 说明 |
---|---|
input | 输入张量 |
k | 要取出的前 k 个值 |
dim | 指定沿哪个维度取值(默认是最后一维) |
largest | 是否取最大值(默认是 True ,为 False 时返回最小值) |
sorted | 返回的结果是否排序(默认是 True ,按值从大到小) |
示例:二维张量中使用 topk()
(dim=0
vs dim=1
)
我们来通过一个具体的 3x3
张量示例,观察在不同维度上使用 topk()
的结果。
import torch# 创建一个 3x3 的二维张量
x = torch.tensor([[0.1, 0.8, 0.6],[0.9, 0.2, 0.3],[0.5, 0.4, 0.7]
])
沿行取 Top-k:dim=1
print(torch.topk(x, k=2, dim=1))# 输出:
# values=tensor([[0.8000, 0.6000],
# [0.9000, 0.3000],
# [0.7000, 0.5000]]),
# indices=tensor([[1, 2],
# [0, 2],
# [2, 0]]))
每一行分别取出前两个最大值及其列索引
沿列取 Top-k:dim=0
print(torch.topk(x, k=2, dim=0))# 输出:
# values=tensor([[0.9000, 0.8000, 0.7000],
# [0.5000, 0.4000, 0.6000]]),
# indices=tensor([[1, 0, 2],
# [2, 2, 0]]))
每一列分别取出前两个最大值及其对应的“行号”。
理解维度的直觉图示
dim=1
:按行取 top-k(对每一行,从左往右选 k 个最大值)dim=0
:按列取 top-k(对每一列,从上往下选 k 个最大值)
操作 | 意图 | 方向 |
---|---|---|
topk(x, k, dim=1) | 每行选前 k 个最大 | ⟶ 横向 |
topk(x, k, dim=0) | 每列选前 k 个最大 | ⬇ 纵向 |
topk 与largest、sorted操作的组合
1. 取最小值:largest=False
d = torch.tensor([5, 3, 8, 1, 2])
smallest, indices = torch.topk(d, k=2, largest=False)
print("前2小的值:", smallest)
# 输出: tensor([1, 2])
2. 不排序:sorted=False
e = torch.tensor([3, 1, 4, 2, 5])
values, indices = torch.topk(e, k=3, sorted=False)
print("前3大的值(未排序):", values)
# 输出: tensor([3, 4, 5])
print("对应索引:", indices)
# 输出: tensor([0, 2, 4])