当前位置: 首页 > news >正文

y _hat[ [ 0, 1], y ]语法——pytorch张量花式索引

目录

1. y _hat[ [ 0, 1]例子

2.pytorch花式索引

(1)简单行、列索引

(2)列表索引

(3)范围索引 

 (4)布尔索引

(5)多维索引 

3.张量拼接 

(1)torch.cat 函数的使用

 (2)torch.stack 函数的使用


1. y _hat[ [ 0, 1]例子

import torch
y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y_hat[[0, 1], y]

简单阐述我对第四行代码的理解

y_hat是一个2*3的数组
y_hat[[0,1],y]中的[0,1]指的是第一行和第二行的索引,后面的y等价于[0,2]。那么可以这么理解y_hat[0,0]和y_hat[1,2]。最后的结果也证明了我的理解。

2.pytorch花式索引

(1)简单行、列索引
import torchdata = torch.randint(0, 10, [4, 5])  # 四行五列的二维张量
print(data)
print(data[2])     # 获取第三行数据,返回一维张量
print(data[:, 1])  # 获取第二列数据,返回一维张量
print(data[1, 2])  # 获取第二行的第三列数据,返回零维张量
print(data[1][2])  # 同上
(2)列表索引
import torchdata = torch.randint(0, 10, [4, 5])  # 四行五列的二维张量
print(data)
print(data[[1,0,2])                # 返回下标为1行、0行、2行共三行数据组成的3行5列的二维张量
print(data[[0,1,3], [3,2,4]])      # 返回下标为0行3列、1行2列、3行4列三个数据组成的一维张量
print(data[[[0],[1]], [[3],[4]]])  # 返回下标为0行3列、1行4列两个数据组成的2行1列的二维张量
print(data[[0,1], [[3],[4]]])      # 返回下标为0行3列、1行3列、0行4列、1行4列四个数据组成的2行2列的二维张量
print(data[[0,1], [[1,2],[0,4]]])  # 返回下标为0行1列、1行2列、0行0列、1行4列四个数据组成的2行2列的二维张量
print(data[[[1],[0]], [3,4]])      # 返回下标为1行3列、1行4列、0行3列、0行4列四个数据组成的2行2列的二维张量
print(data[[[1,3],[0,2]], [3,4]])  # 返回下标为1行3列、3行4列、0行3列、2行4列四个数据组成的2行2列的二维张量
(3)范围索引 
import torchdata = torch.randint(0, 10, [4, 5])  # 四行五列的二维张量
print(data)
print(data[:3, 4])   # 返回前三行的第五列数据组成的一维张量
print(data[:3, [0,2,4]])  # 返回前三行的第一三五列数据组成的二维张量
print(data[:3, :4])  # 返回前三行的前四列数据组成的二维张量
print(data[2:, :4])  # 返回第三行到末行的前四列数据组成的二维张量
 (4)布尔索引
import torchdata = torch.randint(0, 10, [4, 5])  # 四行五列的二维张量
print(data)
print(data[data > 5])  # 返回所有大于5的元素组成的一维张量
print(data[[True,False,True,False]])  # 返回第一行与第三行数据组成的二维张量
print(data[1:, [True,False,True,False,True]])  # 返回第二行到末行的第一三五列数据组成的二维张量
print(data[data[:, 2] > 5])  # 返回第三列大于5的行数据组成的二维张量
print(data[:, data[1] > 5])  # 返回第二行大于5的列数据组成的二维张量
(5)多维索引 
data = torch.randint(0, 10, [3, 4, 5])  # 三片四行五列的三维张量
print(data)
print(data[0, :, :])  # 返回第一片所有数据,四行五列的二维张量
print(data[:, 0, :])  # 返回所有片的第一行数据,三行五列的二维张量
print(data[:, :, 0])  # 返回所有片的第一列数据,三行四列的二维张量

3.张量拼接 

(1)torch.cat 函数的使用
import torchdata1 = torch.randint(0, 10, [3, 5, 4])
data2 = torch.randint(0, 10, [3, 5, 4])
print(data1)
print(data2)new_data = torch.cat([data1, data2], dim=0)  # 1. 按0维度拼接
print(new_data)  # shape:torch.Size([6, 5, 4])new_data = torch.cat([data1, data2], dim=1)  # 2. 按1维度拼接
print(new_data)  # shape:torch.Size([3, 10, 4])new_data = torch.cat([data1, data2], dim=2)  # 3. 按2维度拼接
print(new_data)  # shape:torch.Size([3, 5, 8])
 (2)torch.stack 函数的使用
import torchdata1= torch.randint(0, 10, [4, 5])
data2= torch.randint(0, 10, [4, 5])
print(data1)
print(data2)new_data = torch.stack([data1, data2], dim=0)  # 在0维度叠加,升维!
print(new_data)  # shape:torch.Size([2, 4, 5])new_data = torch.stack([data1, data2], dim=1)  # 在1维度叠加,升维!
print(new_data)  # shape:torch.Size([4, 2, 5])new_data = torch.stack([data1, data2], dim=2)  # 在2维度叠加,升维!
print(new_data)  # shape:torch.Size([4, 5, 2])

http://www.lryc.cn/news/190270.html

相关文章:

  • 高级岗位面试问题
  • 区块链游戏的开发框架
  • Windows Nginx 服务器部署(保姆级)
  • 常用的Linux命令及其用法
  • linux总结
  • java - 设计模式 - 状态模式
  • c/c++--编译指令(预处理之后) #pragma
  • 黑马JVM总结(三十二)
  • 接口自动化测试框架【reudom】
  • 【数据库问题】删除数据库失败,提示:there is 1 other session using the database
  • 【技术干货】如何快速创建商用照明 OEM APP?
  • 阿里云ModelScope 是一个“模型即服务”(MaaS)平台
  • Nodejs内置模块process
  • Vue2 修改了数组哪些方法,为什么
  • 均值滤波算法及例程
  • 拥抱产业发展机遇 兑现5G商业价值
  • Layui合计自定义列
  • Tomcat自启动另一种方法
  • C语言,标志法
  • 适合自学的网络安全基础技能“蓝宝书”:《CTF那些事儿》
  • 软件设计师学习笔记12-数据库的基本概念+数据库的设计过程+概念设计+逻辑设计
  • distcc分布式编译
  • Java面试题-0919
  • WPF列表性能提高技术
  • 掌握 BERT:自然语言处理 (NLP) 从初级到高级的综合指南(2)
  • 【算法优选】 二分查找专题——贰
  • SQL 的优化
  • 华为云云耀云服务器L实例评测|华为云上的CentOS性能监测与调优指南
  • Go If流程控制与快乐路径原则
  • yolov8 strongSORT多目标跟踪工具箱BOXMOT