当前位置: 首页 > news >正文

TransnormerLLM 中 FlashLinearAttention 的纯pytorch实现

Github 仓库:https://github.com/One-sixth/flash-linear-attention-pytorch

flash-linear-attention-pytorch

纯 Pytorch 实现 TransnormerLLM 中快速线性注意力算子。
用于学习目的。
如果你希望用于训练模型,你可能要修改为 CUDA 或 Triton 的实现,不然会很慢。

注意

这个算子有精度问题,误差较大,是正常的。
这是因为注意力矩阵没有激活函数,导致注意力矩阵的值很大。
在使用 float16 类型时需要特别小心。

这是一个简单的缓解方法:限制 q 和 k 的值,从而减少float16溢出的可能性。

q = q / q.norm(-1, keepdim=True)
k = k / k.norm(-1, keepdim=True)
o = linear_attention(q, k, v, m)

使用方法

import torch
from flash_linear_attention_ops import flash_linear_attention, normal_linear_attentionbatch_size = 16
seq_len = 1024
dim = 64
n_head = 12
device = 'cuda'
dtype = torch.float32Q = torch.randn(batch_size, n_head, seq_len, dim, requires_grad=True, dtype=dtype, device=device)
K = torch.randn(batch_size, n_head, seq_len, dim, requires_grad=True, dtype=dtype, device=device)
V = torch.randn(batch_size, n_head, seq_len, dim, requires_grad=True, dtype=dtype, device=device)
M = torch.randint(0, 2, (1, 1, seq_len, seq_len), device=device, dtype=dtype)O_flash = flash_linear_attention(Q, K, V, M)
O_normal = normal_linear_attention(Q, K, V, M)print('O_flash.shape', O_flash.shape)
print('O_normal.shape', O_normal.shape)print('O diff', (O_flash - O_normal).abs().max().item())

参考引用

https://github.com/OpenNLPLab/TransnormerLLM
https://github.com/shreyansh26/FlashAttention-PyTorch

http://www.lryc.cn/news/112282.html

相关文章:

  • 从NPM注册中心获取包
  • Elastic的下载
  • day52-Redis
  • 高效处理矢量大数据的高可用解决方案
  • Docker Compose构建lnmp
  • Flutter开发问题记录
  • 如何使用本地mock数据
  • XXL-JOB定时任务框架(Oracle定制版)
  • SpringBoot + ajax 实现分页和增删查改
  • ProxyGenerator-代理类生成器
  • ARM 内存屏障指令
  • 了解Linux 的 mmap --- 笔记
  • docker删除容器(步骤详解)
  • boost beast http server 测试
  • Android 10.0 系统开启禁用adb push和adb pull传输文件功能
  • 浙大数据结构第七周之07-图4 哈利·波特的考试
  • vue2-vue项目中你是如何解决跨域的?
  • 【Paper Reading】DETR:End-to-End Object Detection with Transformers
  • 【rust/入门】windows安装rust gnu环境(折腾)
  • java面试---字符串相关内容
  • MYSQL进阶-事务的基础知识
  • 【C++】C++面向对象,泛型编程总结篇(封装,继承,多态,模板)|(秋招篇)
  • 【Github】作为程序员不得不知道的几款Github加速神器
  • react18之08自定义hook (简单的axios-get、修改浏览器title、localStorage、获取滚动条位置、img转换为base64)
  • 对CommonJS、AMD、CMD、ES Module的理解
  • JVM之类加载与字节码(二)
  • 安装linux操作系统
  • 【SpringBoot】知识
  • react ant add/change created_at
  • OSPF 动态路由协议 路由传递