当前位置: 首页 > news >正文

【taichi】利用 taichi 编写深度学习算子 —— 以提取右上三角阵为例

本文以取 (bs, n, n) 张量的右上三角阵并展平为向量 (bs, n*(n+1)//2)) 为例,展示如何用 taichi 编写深度学习算子。

在这里插入图片描述
如图,要把形状为 (bs,n,n)(bs,n,n)(bs,n,n) 的张量,转化为 (bs,n(n+1)2)(bs,\frac{n(n+1)}{2})(bs,2n(n+1)) 的向量。我们先写一个最简单的最慢的纯 python 循环实现方法

纯 python for 循环

def get_tensor_up_right_tri_slow(t):# t shape (bs, n, n)# out shape (bs, n*(n+1)//2)out = torch.zeros(t.shape[0], t.shape[1]*(t.shape[1]+1)//2)n = t.shape[1]# k = i*n + j - i*(i+1)//2for b in range(t.shape[0]):# 遍历右上三角阵,包括主对角线for i in range(t.shape[1]):for j in range(i, t.shape[1]):k = i*n + j - i*(i+1)//2out[b, k] = t[b, i, j]return out

可想而知,三层 python for 循环,必然是极慢的了。

转化为 taichi

在此基础上,稍微做一些修改,就可以得到我们的 taichi 版本函数

import taichi as titi.init(arch=ti.gpu)@ti.kernel
def get_tensor_up_right_tri(t: ti.types.ndarray(ndim=3, dtype=ti.f32), out: ti.types.ndarray(ndim=2, dtype=ti.f32)):# t shape (bs, n, n)# out shape (bs, n*(n+1)//2)n = t.shape[1]for b, i, j in t:# 遍历右上三角阵,包括主对角线if i <= j:k = i*n + j - i*(i+1)//2out[b, k] = t[b, i, j]

taichi 支持同时遍历多层循环,将三层循环改为一层循环后,和 python for 循环版本基本没有什么差别。taichi 将此函数转化为 CUDA 版本进行加速,从而提高运算速度。

http://www.lryc.cn/news/5712.html

相关文章:

  • 二进制 k8s 集群下线 worker 组件流程分析和实践
  • Bean的六种作用域
  • Http发展历史
  • 高级Java程序员必备的技术点,你会了吗?
  • 【暴力量化】查找最优均线
  • Java读取mysql导入的文件时中文字段出现�??的乱码如何解决
  • k8s核心概念—Pod Controller Service介绍——20230213
  • Tensorflow的数学基础
  • IT培训就是“包就业”吗?内行人这么看
  • 【算法】【数组与矩阵模块】顺时针旋转打印矩阵
  • Java中的锁概述
  • 微电影行业痛点解决方案
  • 使用Spring框架的好处是什么
  • 【表格单元格可编辑】vue-elementul简单实现table表格点击单元格可编辑,点击单元格变成弹框修改数据
  • vue3.0 响应式数据
  • uni-app ①
  • 20个 Git 命令玩转版本控制
  • SAP NetWeaver版本和SAP Kernel版本的确定
  • 面试23K字节测试开发岗被血虐,到底具有怎样的技术才算高级水平?
  • 智云通CRM:买对了吗——大客户采购的方案实施
  • 前后端开发过程中的跨域问题总结
  • 爬虫:栖落的电影网站,利用requests和re模块
  • 使用burpsuite抓包 + sql工具注入 dvwa靶场
  • 树与图中的dfs和bfs—— AcWing 846. 树的重心 AcWing 847. 图中点的层次
  • 从零开始学数据分析之数据分析概述
  • 十五载厚积薄发,电信级分布式数据库是这样炼成
  • Centos调整分区存储大小
  • 华为OD机试真题JAVA实现【单词接龙】真题+解题思路+代码(20222023)
  • Mapbox Style 规范
  • Java开发学习(五十)----MyBatisPlus快速开发之代码生成器解析