当前位置: 首页 > news >正文

使用torch模拟 BMM int8量化计算。

在这里插入图片描述
使用torch模型BMM int8计算。
模拟:BMM->softmax->BMM 计算流程

import torch
import numpy as np
torch.manual_seed(777)
def int8_quantize_per_token(x: torch.Tensor, axis: int = -1, attns=False):if x.dtype != torch.float32:x = x.type(torch.float32)xmax = torch.abs(x)xmax = torch.max(xmax, dim=axis, keepdim=True)[0]scale = xmax / 127.0if not attns:# scale = torch.clamp(scale, 1e-5, np.finfo(np.float32).max)passelse:# scale = torch.tensor(1 / 127.0, dtype=torch.float32)passout = x / scaleout = torch.round(out)out = torch.clamp(out, -128, 127)quantized_out = out.type(torch.int8)return quantized_out, scaledef int8_quantize_per_tensor(x, axis=0, attns=False):if x.dtype != torch.float32:x = x.type(torch.float32)xmax = torch.abs(x)xmax = torch.max(xmax, dim=-1, keepdim=True)[0]xmax = torch.max(xmax, dim=-2, keepdim=True)[0]scale = xmax / 127.0if not attns:# scale = torch.clamp(scale, 1e-5, np.finfo(np.float32).max)passelse:# scale = torch.tensor(1 / 127.0, dtype=torch.float32)passout = x / scaleout = torch.round(out)out = torch.clamp(out, -128, 127)quantized_out = out.type(torch.int8)return quantized_out, scaledef matmul_int8(key, query, value):key = key.permute([0, 1, 3, 2])query, q_s = int8_quantize_per_token(query)key, k_s = int8_quantize_per_token(key, -2)attention_scores = torch.matmul(query.type(torch.float32),key.type(torch.float32))scale = q_s * k_sattention_1 = torch.mul(attention_scores, scale)attention_scores = attention_1 / torch.sqrt(torch.tensor(32, dtype=torch.float32))attention_scores = torch.softmax(attention_scores, dim=-1)attention_scores_int8, attn_p_s = int8_quantize_per_token(attention_scores, attns=True)value, v_s = int8_quantize_per_token(value, -2)context = torch.matmul(attention_scores_int8.type(torch.float32),value.type(torch.float32))scale = attn_p_s * v_scontext = torch.mul(context, scale)return attention_1, contextdef matmul_fp(key, query, value):key = key.permute([0, 1, 3, 2])attention_1 = torch.matmul(query.type(torch.float32),key.type(torch.float32))attention_scores = attention_1 / torch.sqrt(torch.tensor(32, dtype=torch.float32))attention_scores = torch.softmax(attention_scores, dim=-1)context = torch.matmul(attention_scores.type(torch.float32),value.type(torch.float32))return attention_1, contextdef mtx_similar1(arr1:np.ndarray, arr2:np.ndarray) ->float:'''计算矩阵相似度的一种方法。将矩阵展平成向量,计算向量的乘积除以模长。注意有展平操作。:param arr1:矩阵1:param arr2:矩阵2:return:实际是夹角的余弦值,ret = (cos+1)/2'''farr1 = arr1.ravel()farr2 = arr2.ravel()len1 = len(farr1)len2 = len(farr2)if len1 > len2:farr1 = farr1[:len2]else:farr2 = farr2[:len1]numer = np.sum(farr1 * farr2)denom = np.sqrt(np.sum(farr1**2) * np.sum(farr2**2))similar = numer / denom # 这实际是夹角的余弦值return  (similar+1) / 2     # 姑且把余弦函数当线性if __name__ == "__main__":key = torch.randn((2, 6, 10, 32))value = torch.randn((2, 6, 10, 32))query = torch.randn((2, 6, 1, 32))i_key = key.clone().detach()i_value = value.clone().detach()i_query = query.clone().detach()fp_score, fp_context = matmul_fp(key, query, value)int8_score, int8_context = matmul_int8(i_key, i_query, i_value)similar1 = mtx_similar1(int8_score.cpu().detach().numpy(),fp_score.cpu().detach().numpy())similar2 = mtx_similar1(int8_context.cpu().detach().numpy(),fp_context.cpu().detach().numpy())print(similar1, similar2)np.testing.assert_allclose(fp_score.detach().cpu().numpy(),int8_score.detach().cpu().numpy(),rtol=1e-02, atol=1e-03)np.testing.assert_allclose(fp_context.detach().cpu().numpy(),int8_context.detach().cpu().numpy(),rtol=1e-02, atol=1e-03)

结论:
Per-token 精度优于per-tensor
BMM1 和 BMM2定点计算之后,输出误差较大

http://www.lryc.cn/news/504185.html

相关文章:

  • 【FreeMarker】实现生成Controller根据模板勾选的内容查询
  • 深入理解 XPath:XML 和 HTML 文档的利器
  • DDR5 中的数据反馈判决均衡(DFE):全面解析与展望
  • Axure高保真数据可视化大屏图表组件库
  • 100个问题学 langchain 入门 (1/10)
  • 0001.基于springmvc简易酒店管理系统后台
  • 每日一题 326. 3 的幂
  • 解码数据有序之道——常见排序算法总结
  • C语言实现图片文件的复制
  • 一、windows上配置ninja环境
  • 我们来编程 -- win11多jdk版本切换
  • JAVA 图形界面编程 AWT篇(1)
  • C语言 字符串输入输出函数、scanf(“%[^\n]“,)可输入空格 、fgets删除换行符
  • 【蓝桥杯每日一题】推导部分和——带权并查集
  • Linux 磁盘满了怎么办?快速排查和清理方法
  • 【专题】2024年中国新能源汽车用车研究报告汇总PDF洞察(附原数据表)
  • 数据结构之链表笔试题详解
  • 结构化的Prompt
  • 【数字化】华为数字化转型架构蓝图
  • 最新全开源IM即时通讯系统源码(PC+WEB+IOS+Android)部署指南
  • go 跨平台打包
  • C++ 给定字符串,然后给出开始要取的位置,返回取到的信息
  • 【树莓派4B】MindSpore lite 部署demo
  • Idea汉化插件Datagrip汉化插件
  • 精彩回顾|Cocos开发者沙龙长沙站
  • 算法日记 49 day 图论(A*算法)
  • 服务器批量清理redis keys,无法适用客户端必须直连的情况
  • Grafana配置告警规则推送企微机器人服务器资源告警
  • 数字货币金融研究,深度学习虚拟币价格预测 数据集 市值top20 (2014年—2024年)
  • druid.properties图标是齿轮