当前位置: 首页 > news >正文

pytorch 笔记:index_select

1 基本使用方法

index_select 是 PyTorch 中的一个非常有用的函数,允许从给定的维度中选择指定索引的张量值

torch.index_select(input, dim, index, out=None) -> Tensor
input从中选择数据的源张量
dim从中选择数据的维度
index

一个 1D 张量,包含你想要从 dim 维度中选择的索引

此张量应该是 LongTensor 类型

out

一个可选的参数,用于指定输出张量。

如果没有提供,将创建一个新的张量。

2 举例

import torch
import numpy as npx = torch.tensor(np.arange(16).reshape(4,4))
index=torch.LongTensor([1,3])
x
'''
tensor([[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11],[12, 13, 14, 15]], dtype=torch.int32)
'''torch.index_select(x,dim=0,index=index)
'''
tensor([[ 4,  5,  6,  7],[12, 13, 14, 15]], dtype=torch.int32)
'''torch.index_select(x,dim=1,index=index)
'''
tensor([[ 1,  3],[ 5,  7],[ 9, 11],[13, 15]], dtype=torch.int32)
'''

3 index_select保存梯度

import torch
import numpy as npx = torch.tensor(np.arange(16).reshape(4,4),dtype=torch.float32, requires_grad=True)
index=torch.LongTensor([1,3])
x
'''
tensor([[ 0.,  1.,  2.,  3.],[ 4.,  5.,  6.,  7.],[ 8.,  9., 10., 11.],[12., 13., 14., 15.]], requires_grad=True)
'''torch.index_select(x,dim=0,index=index)
'''
tensor([[ 4.,  5.,  6.,  7.],[12., 13., 14., 15.]], grad_fn=<IndexSelectBackward0>)
'''torch.index_select(x,dim=1,index=index)
'''
tensor([[ 1.,  3.],[ 5.,  7.],[ 9., 11.],[13., 15.]], grad_fn=<IndexSelectBackward0>)
'''

http://www.lryc.cn/news/211043.html

相关文章:

  • 面试算法43:在完全二叉树中添加节点
  • Python算法例3 检测2的幂次
  • 线扫相机DALSA--采集卡Base模式设置
  • Gitee 发行版
  • python面向对象
  • Go基础——数组、切片、集合
  • Error: no matching distribution found for tensorflow-cpu==2.6.*
  • nginx 进程模型
  • TypeScript - 枚举类型 -字符型枚举
  • 分布式锁-Redis红锁解决方案
  • 【Ubuntu 终端终结者Ctrl shift e无法垂直分页解决办法】
  • Error: error:0308010C:digital envelope routines::unsupported
  • RTMP在智能眼镜行业应用方案有哪些?
  • 【每日一题】合并两个有序数组
  • MySQL---表的增查改删(CRUD进阶)
  • 《HelloGitHub》第 91 期
  • jvm线上异常排查流程
  • python项目之酒店客房入侵检测系统的设计与实现
  • C++ 学习系列 -- 标准库常用得 algorithm function
  • [论文笔记]E5
  • k8s 1.28版本:使用StorageClass动态创建PV,SelfLink 问题修复
  • 漏洞复现-dedecms文件上传(CVE-2019-8933)
  • vue分片上传
  • 【大数据Hive】hive 表数据优化使用详解
  • 京东平台数据分析(京东销量):2023年9月京东吸尘器行业品牌销售排行榜
  • 基于springboot实现休闲娱乐代理售票平台系统项目【项目源码+论文说明】计算机毕业设计
  • jvm对象内存划分
  • 网络原理之TCP/IP
  • Docker:数据卷挂载
  • 你会处理 go 中的 nil 吗