当前位置: 首页 > news >正文

无脑入门pytorch系列(四)—— scatter_

本系列教程适用于没有任何pytorch的同学(简单的python语法还是要的),从代码的表层出发挖掘代码的深层含义,理解具体的意思和内涵。pytorch的很多函数看着非常简单,但是其中包含了很多内容,不了解其中的意思就只能【看懂代码】,无法【理解代码】。

目录

  • 官方定义
  • demo
  • one-hot

官方定义

torch.tensor.scatter_是PyTorch中的一个函数,用于将指定索引处的值替换为给定的值。

函数定义:

Tensor.scatter_(dim, index, src, reduce=None) → Tensor

官方解释:

  • 将张量src中的所有值写入索引张量中指定的index处的self。

  • 对于src中的每个值,它的输出索引由其在src中的索引(dimension != dim)和在index中对应的值(dimension = dim)指定。

非常难以理解,十分抽象,从我个人的角度来说就是:

  • 第一个参数dim表示维度,即在第几维度处理数据,保持其它维度不变。
  • reduce参数是一个可选参数,用于指定如何在执行散射(scatter)操作时对重复的索引值进行合并或聚合。
  • index则是需要填充的列的索引,即根据维度从src中取对应的值填充到tensor中去。

怎么映射的,比如一个一个3维张量:

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

官方的文档如下,TORCH.TENSOR.SCATTER_:

image-20230818104242738

即使如此理解起来也是很复杂,下面从例子中去理解:

demo

下面是一个官方文档给出的例子:

import torchsrc = torch.Tensor([[-1.0276,  0.2673, -1.1752, -0.8823],[-0.6447, -0.8256,  0.1542, -0.4242]])
print(src)output = torch.zeros(2, 5)
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])output = output.scatter(1, index, src)
print(output)

输出的结果:

image-20230818142004545

我们一步步理解代码:

  1. 首先,定义了一个src张量,后续output即从src中取值。
  2. 其次,定义了output,其值为二行五列的全零张量,后续对output进行修改。
  3. 接着,定义了index,即从src取值的索引。
  4. 最后,根据index从src取值填充到output中,即完成操作。

那么具体是如何取值的呢?

首先,dim = 1,意味着从维度值为1的地方取值,维度值为0的地方不变,那就是:

self[i][index[i][j]] = src[i][j]  # if dim == 1

具体来说:

i = 0, j = 0时,output[0][index[0][0]] = src[0][0],因为index[0][0] = 3,所以output[0][3] = src[0][0] = -1.0276,这时候我们检查输出的output值,确实是-1.0276

同理:

i = 0, j = 1: output[0][index[0][1]] = output[0][1] = src[0][1] = 0.2673

i = 0, j = 2: output[0][index[0][2]] = output[0][2] = src[0][2] = -1.1752

one-hot

作者在学习该函数时实在遇到one-hot编码时遇到的,而该函数在one-hot中应用很广:

index = torch.tensor([[3], [2], [0], [1]])
onehot = torch.zeros(4, 4)
onehot.scatter_(1, index, 1)
print(onehot)

image-20230818143854519

http://www.lryc.cn/news/130146.html

相关文章:

  • 【Spring源码】Spring扩展点及顺序
  • 广州华锐互动:3D数字孪生开发编辑器助力企业高效开发数字孪生应用
  • 【脚踢数据结构】图(纯享版)
  • [leetcode] 707 设计链表
  • JIRA:项目管理的秘密武器
  • ARM 作业1
  • 【解析postman工具的使用---基础篇】
  • Elasticsearch:如何在 Ubuntu 上安装多个节点的 Elasticsearch 集群 - 8.x
  • 记录win 7旗舰版 “VMware Alias Manager and Ticket Service‘(VGAuhService)启动失败。
  • git 开发环境配置
  • Tableau画图
  • nginx上web服务的基本安全优化、服务性能优化、访问日志优化、目录资源优化和防盗链配置简介
  • himall3.0商城源码
  • 【LeetCode75】第二十九题 删除链表的中间节点
  • Floyd(多源汇最短路)
  • Pycharm找不到Conda可执行文件路径(Pycharm无法导入Anaconda已有环境)
  • 国产之光:讯飞星火最新大模型V2.0
  • 通讯录实现【C语言】
  • pcl欧式聚类
  • macOS Ventura 13.5.1(22G90)发布(附黑/白苹果系统镜像地址)
  • 分布式监控平台——Zabbix
  • 【OpenGauss源码学习 —— 列存储(创建表)】
  • Jenkins 监控dist.zip文件内容发生变化 触发自动部署
  • Linux系列讲解 —— FTP协议的应用
  • Rancher-RKE-install 部署k8s集群
  • PHP8的正则表达式-PHP8知识详解
  • SpringCloud实用篇7——深入elasticsearch
  • uni-app 经验分享,从入门到离职(二)—— tabBar 底部导航栏实战篇
  • Java虚拟机(JVM):内存区域
  • 11 - git stash 开发中临时加塞了紧急任务怎么处理