当前位置: 首页 > news >正文

【深度学习模型移植】用torch普通算子组合替代torch.einsum方法

     首先不得不佩服大模型的强大之处,在算法移植过程中遇到einsum算子在ONNX中不支持,因此需要使用普通算子替代。参考TensorRT - 使用torch普通算子组合替代torch.einsum爱因斯坦求和约定算子的一般性方法。可以写出简单的替换方法,但是该方法会导致训练时还是推理都很慢,并且会消耗大量显存,造成显存溢出的问题。。因此采用提问文心一言,没想到居然真的回答正确了。当然替换需要验证,不是全对的。
1.einsum(delta, A, ‘b l d_in, d_in n -> b l d_in n’) 的替换,以下两个方法均可以

deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
deltaA = torch.exp(delta.unsqueeze(dim=3)*A.unsqueeze(dim=0).unsqueeze(dim=0))
deltaA = torch.exp(delta.unsqueeze(-1).repeat_interleave(A.shape[1], dim=-1) * A)

2.einsum(x, C[:, i, :], ‘b d_in n, b n -> b d_in’),以下两个方法均可以

    y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')y = (x*C[:, i, :].unsqueeze(dim=1)).sum(dim=2)y = torch.matmul(C[:, i, :], x.transpose(-1, -2)).squeeze(1)

3.einsum(delta, B, u, ‘b l d_in, b l n, b l d_in -> b l d_in n’),以下两个方法均可以

deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
deltaB_u1 = delta.unsqueeze(dim=3)*B.unsqueeze(dim=2)*u.unsqueeze(dim=3)

下述方法是提问文心一言的办法,注意需要将答案的结果和einsum的结果进行对比,采用np.testing.assert_allclose(deltaB_u.numpy(),deltaB_u1.numpy(),rtol=1e-05,atol=1e-05)和print(deltaA.equal(deltaA_manual))均可以。

import torch
import numpy as np
from einops import rearrange, repeat, einsum
# 给定的张量
delta = torch.ones([1, 3, 2])
A = torch.ones([2, 4])
deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
deltaA1 = torch.exp(delta.unsqueeze(dim=3)*A.unsqueeze(dim=0).unsqueeze(dim=0))
deltaA_manual = torch.exp(delta.unsqueeze(-1).repeat_interleave(A.shape[1], dim=-1) * A)
np.testing.assert_allclose(deltaA.numpy(),deltaA1.numpy(),rtol=1e-05,atol=1e-05)# 扩展 delta 的维度,以便它可以与 A 进行广播(broadcast)
# 这里我们使用 unsqueeze 和 repeat_interleave 来扩展维度
delta_expanded = delta.unsqueeze(-1).repeat_interleave(A.shape[1], dim=-1)
# 执行逐元素的乘法,然后取指数
deltaA_manual = torch.exp(delta_expanded * A)# 注意:deltaA_manual 的形状是 [1, 3, 2, 4],这与 einsum 的输出形状一致
print(deltaA.equal(deltaA_manual))
print(deltaA1.equal(deltaA_manual))

请添加图片描述
请添加图片描述
请添加图片描述

http://www.lryc.cn/news/319736.html

相关文章:

  • 鸿蒙 Harmony 初体验
  • Jmeter+ant,ant安装与配置
  • 【MySQL基础】MySQL基础操作三
  • 【K8s】肿么办??Kubernetes Secrets并不是Secret哟!!
  • 数星星 刷题笔记 (树状数组)
  • Windows→Linux,本地同步到服务器
  • Pycharm连接远程服务器Anoconda中的虚拟环境
  • 无人机自动返航算法实现与优化
  • 切面条-蓝桥杯?-Lua 中文代码解题第1题
  • WebRTC:真正了解 RTP 和 RTCP
  • vue实现双向绑定原理深度解析
  • C语言 —— memeove函数的模拟实现
  • <el-tab>样式自定义——一个可以触类旁通的小例子
  • XDP学习笔记
  • JavaScript进阶:js的一些学习笔记-4
  • 【可能是全网最丝滑的LangChain教程】三、快速入门LLMChain
  • Oracle Primavera Analytics 是什么,与P6的关系?
  • 在 Amazon Bedrock 上使用 Anthropic Claude 系统 Prompt
  • 【LeetCode】动态规划--题目练习
  • 【LeetCode热题100】101. 对称二叉树(二叉树)
  • VLC抓取m3u8视频
  • 聊聊Python都能做些什么
  • JavaWeb06-MVC和三层架构
  • MySQL数据库实现增删改查基础操作
  • PCM和I2S区别
  • 大模型笔记:吴恩达 ChatGPT Prompt Engineering for Developers(1) prompt的基本原则和策略
  • 设计模式 — — 单例模式
  • C++:菱形继承与虚继承
  • 贡献法:USACO 2021 December Contest Bronze:孤独的照片
  • Java实现简单的通讯录