飞桨paddle API函数scatter详解
飞桨的scatter函数,是通过基于 updates
来更新选定索引 index
上的输入来获得输出,具体官网api文档见:
scatter-API文档-PaddlePaddle深度学习平台
官网给的例子如下:
>>> import paddle>>> x = paddle.to_tensor([[1, 1], [2, 2], [3, 3]], dtype='float32')>>> index = paddle.to_tensor([2, 1, 0, 1], dtype='int64')>>> updates = paddle.to_tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32')>>> output1 = paddle.scatter(x, index, updates, overwrite=False)>>> print(output1)Tensor(shape=[3, 2], dtype=float32, place=Place(cpu), stop_gradient=True,[[3., 3.],[6., 6.],[1., 1.]])>>> output2 = paddle.scatter(x, index, updates, overwrite=True)>>> # CPU device:>>> # [[3., 3.],>>> # [4., 4.],>>> # [1., 1.]]>>> # GPU device maybe have two results because of the repeated numbers in index>>> # result 1:>>> # [[3., 3.],>>> # [4., 4.],>>> # [1., 1.]]>>> # result 2:>>> # [[3., 3.],>>> # [2., 2.],>>> # [1., 1.]]
但是如果是初学者,看官网的例子可能还是无法明白scatter的运算方式,下面就结合一个更加明白的例子来说明:
import paddle
x = paddle.to_tensor([[100, 200], [300, 400], [500, 600]], dtype='float32')
index = paddle.to_tensor([2, 1, 0, 1], dtype='int64')
updates = paddle.to_tensor([[10, 11], [21, 22], [33, 34], [40, 41]], dtype='float32')output1 = paddle.scatter(x, index, updates, overwrite=False)
print(output1)
output2 = paddle.scatter(x, index, updates, overwrite=True)
print(output2)
输出结果:
Tensor(shape=[3, 2], dtype=float32, place=Place(cpu), stop_gradient=True,[[33., 34.],[61., 63.],[10., 11.]])
Tensor(shape=[3, 2], dtype=float32, place=Place(cpu), stop_gradient=True,[[33., 34.],[40., 41.],[10., 11.]])
scatter详解
输入是三个值:源值x ,索引index, 变量updates,分析函数输出,可以得出以下结论:
1 scatter函数的输出shape是和x一致的
2 函数输出的值跟x没关系
3 函数输出值跟变量updates值有关
4 输出值跟updates的有关,具体取值的索引跟index有关
具体来说,就是不需要x的值,只使用了它的维度信息,然后根据索引,将变量updates的值填入x的维度中。比如index值是[2, 1, 0, 1],第一位是2,那么就把updates的第一组数[10, 11](也就是updates[0])取出来放到x[2]里;index第二位是1 ,就把的updates的第2组数[21, 22](也就是updates[1])取出来放到x[1]里。以此类推,index第三位是0, 那么就把updates的第3组数[33, 34](也就是updates[2])放到x[0]里。
到了index第四位数,它是1 ,那么就需要把updates的第四组数[40, 41](也就是updates[3]),放入到x[1]中 。这时候有个问题,就是前面x[1]中已经放入了[21,22]。这时候就看函数的overwrite参数的设置了,如果设置overwrite=True ,那么直接用现在的值[40, 41]取代以前的值,最终函数返回结果就是[[33, 34], [40, 41], [10, 11]] 。如果函数设为overwrite=False ,那么就将值[40, 41]与以前的x[1](21, 22)相加,结果是[61, 63],最终返回值就是[[33, 34], [61, 63], [10, 11]]
好了,这样大家就明白scatter的运算机制了吧?
小贴士
飞桨官网给出了scatter函数的python代码实现,其中因为用了巧妙的思路来提高速度,可读性略有下降:
>>> import paddle>>> #input:>>> x = paddle.to_tensor([[1, 1], [2, 2], [3, 3]], dtype='float32')>>> index = paddle.to_tensor([2, 1, 0, 1], dtype='int64')>>> # shape of updates should be the same as x>>> # shape of updates with dim > 1 should be the same as input>>> updates = paddle.to_tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32')>>> overwrite = False>>> # calculation:>>> if not overwrite:... for i in range(len(index)):... x[index[i]] = paddle.zeros([2])>>> for i in range(len(index)):... if (overwrite):... x[index[i]] = updates[i]... else:... x[index[i]] += updates[i]>>> # output:>>> out = paddle.to_tensor([[3, 3], [6, 6], [1, 1]])>>> print(out.shape)[3, 2]
scatter的运算机制不管overwrite是否为True,x的值都不参与运算,理论上应该都清除,也就是在循环里置0 :x[index[i]] = paddle.zeros([2])
实际上如果overwrite是True,那么在赋值的时候本身可以直接写入 x[index[i]] = updates[i],这样就可以省略x[index[i]] = paddle.zeros([2])这句,这就是为什么这段置0代码放到了条件:if not overwrite: 这句里面的原因。