当前位置: 首页 > news >正文

PyTorch topk() 用法详解:取最大值

torch.topk(input, k) 返回张量中最大的 k 个元素以及它们在原张量中的 索引

函数原型

torch.topk(input, k, dim=None, largest=True, sorted=True)

参数说明:

参数说明
input输入张量
k要取出的前 k 个值
dim指定沿哪个维度取值(默认是最后一维)
largest是否取最大值(默认是 True,为 False 时返回最小值)
sorted返回的结果是否排序(默认是 True,按值从大到小)

示例:二维张量中使用 topk()dim=0 vs dim=1

我们来通过一个具体的 3x3 张量示例,观察在不同维度上使用 topk() 的结果。

import torch# 创建一个 3x3 的二维张量
x = torch.tensor([[0.1, 0.8, 0.6],[0.9, 0.2, 0.3],[0.5, 0.4, 0.7]
])

沿行取 Top-k:dim=1

print(torch.topk(x, k=2, dim=1))# 输出:
# values=tensor([[0.8000, 0.6000],
#         [0.9000, 0.3000],
#         [0.7000, 0.5000]]),
# indices=tensor([[1, 2],
#         [0, 2],
#         [2, 0]]))

每一行分别取出前两个最大值及其列索引

沿列取 Top-k:dim=0

print(torch.topk(x, k=2, dim=0))#  输出:
# values=tensor([[0.9000, 0.8000, 0.7000],
#        [0.5000, 0.4000, 0.6000]]),
# indices=tensor([[1, 0, 2],
#        [2, 2, 0]]))

每一列分别取出前两个最大值及其对应的“行号”。

理解维度的直觉图示

  • dim=1按行取 top-k(对每一行,从左往右选 k 个最大值)
  • dim=0按列取 top-k(对每一列,从上往下选 k 个最大值)
操作意图方向
topk(x, k, dim=1)每行选前 k 个最大⟶ 横向
topk(x, k, dim=0)每列选前 k 个最大⬇ 纵向

topk 与largest、sorted操作的组合

1. 取最小值:largest=False
d = torch.tensor([5, 3, 8, 1, 2])
smallest, indices = torch.topk(d, k=2, largest=False)
print("前2小的值:", smallest) 
# 输出: tensor([1, 2])

2. 不排序:sorted=False
e = torch.tensor([3, 1, 4, 2, 5])
values, indices = torch.topk(e, k=3, sorted=False)
print("前3大的值(未排序):", values)  
# 输出: tensor([3, 4, 5])
print("对应索引:", indices)         
# 输出: tensor([0, 2, 4])

http://www.lryc.cn/news/574493.html

相关文章:

  • Gym安装
  • 数据结构day2
  • 数组题解——​合并区间【LeetCode】
  • 使用 PyAEDT 设计参数化对数周期偶极子天线 LPDA
  • 如何解决TCP传输的“粘包“问题
  • HTTP面试题——缓存技术
  • Qt面试题汇总
  • 记录一下小程序城市索引栏开发经历
  • ✨从零搭建 Ubuntu22.04 + Python3.11 + PyTorch2.5.1 GPU Docker 镜像并上传 Docker Hub
  • Rocky8使用gvm配置Go多版本管理的微服务开发环境
  • uni-app项目实战笔记24--uniapp实现图片保存到手机相册
  • spring01-简介
  • 618风控战升级,瑞数信息“动态安全+AI”利剑出鞘
  • window显示驱动开发—DirectX 图形基础结构 DDI
  • 【CS创世SD NAND征文】基于全志V3S与CS创世SD NAND的物联网智能路灯网关数据存储方案
  • taro小程序,tailwindcss的bg-x-x,背景颜色不生效,只有自定义的写法颜色才生效
  • C++修炼:异常
  • 解码成都芯谷金融中心文化科技产业园:文化+科技双轮驱动
  • Qt 中使用 gtest 做单元测试
  • 一文读懂微观测量:光学3D轮廓仪与共聚焦显微成像的结合应用
  • cherry-pick除了使用命令,有没有什么工具可以使用,或者更高效的方法
  • Linux 文件 I/O 与标准 I/O 缓冲机制详解
  • Java面试中被深挖过的线程问题
  • 对手机屏中断路和短路的单元进行切割或熔接,实现液晶线路激光修复原理
  • Luckysheet Excel xlsx 导入导出互相转换
  • 02-Linux内核源码编译
  • CentOS 7 编译安装Nginx 1.27.5完整指南及负载均衡配置
  • MinIO中视频转换为HLS协议并进行AES加密
  • Python Polars库详解:高性能数据处理的新标杆
  • pyqt多界面