当前位置: 首页 > news >正文

飞桨paddle API函数scatter详解

飞桨的scatter函数,是通过基于 updates 来更新选定索引 index 上的输入来获得输出,具体官网api文档见:

scatter-API文档-PaddlePaddle深度学习平台

官网给的例子如下:

 >>> import paddle>>> x = paddle.to_tensor([[1, 1], [2, 2], [3, 3]], dtype='float32')>>> index = paddle.to_tensor([2, 1, 0, 1], dtype='int64')>>> updates = paddle.to_tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32')>>> output1 = paddle.scatter(x, index, updates, overwrite=False)>>> print(output1)Tensor(shape=[3, 2], dtype=float32, place=Place(cpu), stop_gradient=True,[[3., 3.],[6., 6.],[1., 1.]])>>> output2 = paddle.scatter(x, index, updates, overwrite=True)>>> # CPU device:>>> # [[3., 3.],>>> #  [4., 4.],>>> #  [1., 1.]]>>> # GPU device maybe have two results because of the repeated numbers in index>>> # result 1:>>> # [[3., 3.],>>> #  [4., 4.],>>> #  [1., 1.]]>>> # result 2:>>> # [[3., 3.],>>> #  [2., 2.],>>> #  [1., 1.]]

但是如果是初学者,看官网的例子可能还是无法明白scatter的运算方式,下面就结合一个更加明白的例子来说明:

import paddle
x = paddle.to_tensor([[100, 200], [300, 400], [500, 600]], dtype='float32')
index = paddle.to_tensor([2, 1, 0, 1], dtype='int64')
updates = paddle.to_tensor([[10, 11], [21, 22], [33, 34], [40, 41]], dtype='float32')output1 = paddle.scatter(x, index, updates, overwrite=False)
print(output1)
output2 = paddle.scatter(x, index, updates, overwrite=True)
print(output2)

输出结果:

Tensor(shape=[3, 2], dtype=float32, place=Place(cpu), stop_gradient=True,[[33., 34.],[61., 63.],[10., 11.]])
Tensor(shape=[3, 2], dtype=float32, place=Place(cpu), stop_gradient=True,[[33., 34.],[40., 41.],[10., 11.]])

scatter详解

输入是三个值:源值x ,索引index, 变量updates,分析函数输出,可以得出以下结论:

1 scatter函数的输出shape是和x一致的

2 函数输出的值跟x没关系

3 函数输出值跟变量updates值有关

4 输出值跟updates的有关,具体取值的索引跟index有关

具体来说,就是不需要x的值,只使用了它的维度信息,然后根据索引,将变量updates的值填入x的维度中。比如index值是[2, 1, 0, 1],第一位是2,那么就把updates的第一组数[10, 11](也就是updates[0])取出来放到x[2]里;index第二位是1 ,就把的updates的第2组数[21, 22](也就是updates[1])取出来放到x[1]里。以此类推,index第三位是0, 那么就把updates的第3组数[33, 34](也就是updates[2])放到x[0]里。

到了index第四位数,它是1 ,那么就需要把updates的第四组数[40, 41](也就是updates[3]),放入到x[1]中 。这时候有个问题,就是前面x[1]中已经放入了[21,22]。这时候就看函数的overwrite参数的设置了,如果设置overwrite=True ,那么直接用现在的值[40, 41]取代以前的值,最终函数返回结果就是[[33, 34], [40, 41], [10, 11]] 。如果函数设为overwrite=False ,那么就将值[40, 41]与以前的x[1](21, 22)相加,结果是[61, 63],最终返回值就是[[33, 34], [61, 63], [10, 11]]

好了,这样大家就明白scatter的运算机制了吧? 

小贴士

飞桨官网给出了scatter函数的python代码实现,其中因为用了巧妙的思路来提高速度,可读性略有下降:

 >>> import paddle>>> #input:>>> x = paddle.to_tensor([[1, 1], [2, 2], [3, 3]], dtype='float32')>>> index = paddle.to_tensor([2, 1, 0, 1], dtype='int64')>>> # shape of updates should be the same as x>>> # shape of updates with dim > 1 should be the same as input>>> updates = paddle.to_tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32')>>> overwrite = False>>> # calculation:>>> if not overwrite:...     for i in range(len(index)):...         x[index[i]] = paddle.zeros([2])>>> for i in range(len(index)):...     if (overwrite):...         x[index[i]] = updates[i]...     else:...         x[index[i]] += updates[i]>>> # output:>>> out = paddle.to_tensor([[3, 3], [6, 6], [1, 1]])>>> print(out.shape)[3, 2]

scatter的运算机制不管overwrite是否为True,x的值都不参与运算,理论上应该都清除,也就是在循环里置0 :x[index[i]] = paddle.zeros([2])

实际上如果overwrite是True,那么在赋值的时候本身可以直接写入 x[index[i]] = updates[i],这样就可以省略x[index[i]] = paddle.zeros([2])这句,这就是为什么这段置0代码放到了条件:if not overwrite: 这句里面的原因。

http://www.lryc.cn/news/424074.html

相关文章:

  • RCE漏洞复现
  • Qt QTabWidget之创建标签页的多页面切换
  • 【RISC-V设计-14】- RISC-V处理器设计K0A之打印输出
  • 时序预测|基于变分模态分解-时域卷积-双向长短期记忆-注意力机制多变量时间序列预测VMD-TCN-BiLSTM-Attention
  • Python知识点:如何使用Godot与Python进行游戏脚本编写
  • Spring MVC数据绑定和响应学习笔记
  • Vulnhub JIS-CTF靶机详解
  • FPGA资源评估
  • REST framework中Views API学习
  • Vue(四)——总结
  • 计算机毕业设计 招生宣传管理系统 Java+SpringBoot+Vue 前后端分离 文档报告 代码讲解 安装调试
  • 练习题PHP5.6+变长参数 ⇒ usort回调后门 ⇒ 任意代码执行
  • EPLAN关于PLC的输入输出模块绘制
  • 【Linux】sersync 实时同步
  • Unity 资源分享 之 恐龙Ceratosaurus资源模型携 82 个动画来袭
  • 【AI绘画】 学习内容简介
  • 树形结构查找(B树、B+树)
  • 网络通信(TCP/UDP协议 三次握手四次挥手 )
  • C# ADO.Net 通用按月建表插入数据
  • 19-ESP32-C3加大固件储存区
  • 【STL】stack/queue 容器适配器 deque
  • (回溯) LeetCode 17. 电话号码的组合
  • Ghidra:开源软件逆向工程框架
  • Spring AI 更新:支持OpenAI的结构化输出,增强对JSON响应的支持
  • java.util.ConcurrentModificationException 并发修改异常
  • Flask数据库操作(第四阶段)
  • C语言问答进阶--5、基本表达式和基本语句
  • uniapp3.0实现图片上传公用组件上传uni-file-picker,uni.uploadFile
  • Unity游戏开发002
  • MySQL基础练习题38-每位教师所教授的科目种类的数量