当前位置: 首页 > news >正文

torch.gather(...)

1. Abstract

对于 pytorch 中的函数

torch.gather(input,  # (Tensor) the source tensordim,    # (int)    the axis along which to indexindex,  # (LongTensor) the indices of elements to gather*,sparse_grad=False,out=None
) → Tensor

有点绕,很多博客画各种图讲各种故事来解释如何input 张量中 gather 位置 index 处的值,乱七八糟,我是都没看明白。所以去官网看了文档:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

从这三行看,意思还是很明晰的:输出 out 和输入 input 之间的差别就是,把相应位置(dim)的下标替换成 index[i][j][k]dim=0,1,2 分别对应替换的位置0,1,2。但这不够直观!

【注】从上面三行代码可以看出,indexinput 的维度尺寸是一样的,即 len(index.shape) == len(input.shape),但不一定是相同的形状:index.shape[dim] ≠ input.shape[dim](其他维度的形状必须满足 index.shape <= input.shape)。

2. 图解

2.1 一维向量

先从简单的一维向量看看:

x = torch.tensor([3, 4, 5, 6, 7])

按规则看,out[i] = input[index[i]] # dim == 0,即,从向量里选取指定位置 index[i] 处的数字,放到输出向量 out[i] 处。这个很好理解,pythonnumpypytorch 都有这样的语法:

x = torch.randn(3)
index = torch.randint(low=0, high=3, size=(5,))
y = x[index]
print(x)
print(index)
print(y)
### output ###
tensor([ 0.8797,  0.2459, -0.1312])
tensor([2, 0, 2, 2, 0])
tensor([-0.1312,  0.8797, -0.1312, -0.1312,  0.8797])

torch.gather(...) 函数,就是这样的:

x = torch.tensor([3, 4, 5, 6, 7])
index = torch.tensor([4, 4, 1, 1, 0, 3])
out = torch.gather(x, dim=0, index=index)
### output ###
tensor([7, 7, 4, 4, 3, 6])

举例来说,上面的 index[4] = 0,那么它会寻找 input[index[4]] = input[0] = 3,然后放入 out[4]。这就是英文单词 gather 的意思。

index 的长度是不受限制的,即 gather 多少元素都可以。

小结:在一维向量下,out = torch.gather(x, dim=0, index=index) 等价于 out = x[index]

2.2 二维矩阵

往上升一个维度,看看对二维矩阵实施 gather 函数的操作:

x = torch.tensor([[3, 4, 5, 6, 7], [9, 8, 7, 6, 5]])
idx = torch.randint(low=0, high=5, size=(2, 6))
y = torch.gather(x, dim=1, index=idx)
print(x)
print(idx)
print(y)
### output ###
tensor([[3, 4, 5, 6, 7],[9, 8, 7, 6, 5]])
tensor([[4, 4, 1, 1, 0, 3],[0, 1, 2, 1, 4, 1]])
tensor([[7, 7, 4, 4, 3, 6],[9, 8, 7, 8, 5, 8]])

按规则看,out[i][j] = input[i][index[i][j]] # dim == 1,即,从向量 input[i] 里选取指定位置 index[i][j] 处的数字,放到输出向量 out[i][j] 处。也许多了一个维度就有点绕了,但仔细观察,我们可以假定 i = 0,此时:

out[0][j] = input[0][index[0][j]]  # 对应上图的左侧

若假定 i = 1,则:

out[1][j] = input[1][index[1][j]]  # 对应上图的右侧

即,输出 out[i] 是对输入 imput[i] 执行了一次与一维向量时一样的操作,其中下标是 index[i]。在二维矩阵上的 gather 操作,不过是并行地执行了多个一维向量的 gather

上面是 dim = 1 时的情况,是沿着矩阵的进行 gather,当 dim = 0 时,就是沿着进行 gather

out[i][0] = input[index[i][0]][0]  # dim == 0
out[i][1] = input[index[i][1]][1]
...


也就是并行地执行多个列向量gather,每列 index 是一个并行分支,并行分支的数量可以小于 input 的列数,但不能超过,超过的话,它 gather 哪一列呢?

小结:二维矩阵的 gather 操作就是并行地执行了多个一维向量的 gather 操作;dim=1 按行 gatherdim=0 按列 gather

2.3 高维张量

弄懂一维到二维的 gather,更高维的操作也就清晰了,就是画图有一点难画。假设

x = tensor([[[ 0,  1,  2,  3,  4],[ 5,  6,  7,  8,  9]],[[10, 11, 12, 13, 14],[15, 16, 17, 18, 19]],[[20, 21, 22, 13, 24],[25, 26, 27, 28, 29]]])

则当 dim == 0 时,是沿着第一维进行 gather 的,那么 index.shape[0] (一个并行分支 gather 的元素的数量) 可为任意数,这里设置为 4,其他 index.shape[i≠0] <= input.shape[i≠0]

index = tensor([[[1, 2, 2],[2, 2, 0]],[[0, 0, 1],[1, 0, 1]],[[2, 0, 0],[0, 1, 2]],[[1, 1, 0],[0, 0, 0]]])

index.shape == (4, 2, 3),执行:

y = torch.gather(x, dim=0, index=index)

的示意图如下:

只画了看得见的前两列(两个并行 gather 分支)。红色和绿色箭头表示两列下标沿着 dim=0 进行 gather 操作,每一列和一维向量的 gather 是一样的,只不过这里有 2*3 个列。

再往高维拓展,也是一样,都是从基本的一维向量 gather 拓到并行 gather

http://www.lryc.cn/news/262615.html

相关文章:

  • vscode如何开发微信小程序?JS与TS的主要区别?
  • 产品入门第五讲:Axure交互和情境
  • Python 自动化之收发邮件(一)
  • Flutter开发笔记 —— sqflite插件数据库应用
  • OxLint 发布了,Eslint 何去何从?
  • 第一次使用ThreadPoolExecutor处理业务
  • Sharding-Jdbc(6):Sharding-Jdbc日志分析
  • centos安装了curl却报 -bash: curl: command not found
  • Re58:读论文 REALM: Retrieval-Augmented Language Model Pre-Training
  • java的json解析
  • Spring事务失效的几种情况
  • filter的用法与使用场景:筛选数据
  • ClickHouse(18)ClickHouse集成ODBC表引擎详细解析
  • 网络攻击(一)--安全渗透简介
  • 视频号小店资金需要多少?
  • 机器学习项目精选 第一期:超完整数据科学资料合集
  • 档案数字化管理可以提供什么服务?
  • 第一周:AI产品经理跳槽准备工作
  • 基于核心素养高中物理“深度学习”策略及其教学研究课题论证设计方案
  • 通过 Java 17、Spring Boot 3.2 构建 Web API 应用程序
  • go原生http开发简易blog(一)项目简介与搭建
  • [足式机器人]Part4 南科大高等机器人控制课 Ch09 Dynamics of Open Chains
  • 概率论复习
  • ES客户端RestHighLevelClient的使用
  • GitHub入门命令介绍
  • EasyExcel 简单导入
  • Termux搭建nodejs环境
  • 喜报丨迪捷软件入选2023年浙江省信息技术应用创新典型案例
  • C语言连接zookeeper客户端(不能完全参考官网教程)
  • python排序