当前位置: 首页 > news >正文

Llama中模块参数大小

LLama2中,流程中数据大小的变换如下

Transformer模块

第一次输入,进行prefill,输入x维度为[1, 8, 4096]

1. 构建wq,wk,wv,wo,尺寸均为[4096,4096], 与x点乘,得到xq, xk, xv

2. 构建KV cache, 尺寸为 [batch size, max_seq_len, local_kv_heads, head_dim],对应 [1, 8, 32, 128]

3.基于kv  cache构造 keys, alues,对应的尺寸还是[1,8,32,128]

4. 在最后两个维度对于xq和key进行点乘,得到scores,维度变成【1, 32, 8, 8】

5. 将mask与scores相加

6. 对于scores进行softmax

7. 将scores [1, 32, 8, 8]与values [1, 32, 8, 128]进行乘法

8. 得到output [1, 8, 4096]

9. 将output再与wo进行乘法[1, 8, 4096]

10. 接下来对于输出进行 ffn_norm的操作

Feedforward模块

11.然后进行feed_forward.得到当前transformer模块的输出 [1, 8, 4096]

feed_forward的操作如下,虽然代码很小,但是计算量却很大。

    def forward(self, x):return self.w2(F.silu(self.w1(x)) * self.w3(x))

其中,w1的维度为[11008, 4096], w2的维度为[4096, 11008], w3的维度为[11008, 4096]

kv cache的表达如下

        self.cache_k = torch.zeros((args.max_batch_size,args.max_seq_len,self.n_local_kv_heads,self.head_dim,)).cuda()self.cache_v = torch.zeros((args.max_batch_size,args.max_seq_len,self.n_local_kv_heads,self.head_dim,)

关于kv cache的细节讨论

llama2设定 local_kv_heads为32,head_dim为128。所以,kv cache的尺寸为 [1, 512,32, 128] * 2

对于一个batch的数据来说哦,因为llama2 7B 包含32个transformer,所以,当使用FP32表达时, 对应一个batch的kv cache的大小为128 * 32 * 128 *2 * 32 * 4byte= 0.5GB.

这里,也可以看到几个变量:

* 当batch变大时,kv cache线性增长

* 当batch 的最大长度增大时, Kv cache线性增长。

参考链接:

https://arxiv.org/pdf/1911.02150

http://www.lryc.cn/news/406367.html

相关文章:

  • Modbus转EtherCAT网关将Modbus协议的数据格式转换为EtherCAT协议
  • 【开发实战】QT5 + OpenCV4 开发环境配置应用演示
  • “微软蓝屏”事件暴露的网络安全问题及应对策略
  • 白骑士的PyCharm教学基础篇 1.3 调试与运行
  • 爬虫学习1:初学者简单了解爬虫的基本认识和操作(详细参考图片)
  • WHAT - 通过 shadcn 组件源码学习 React
  • grafana对接zabbix数据展示
  • C++ 学习补充 1:短链算法
  • 硅纪元视角 | 语音克隆突破:微软VALL-E 2,Deepfake新纪元!
  • 没有51基础,能不能学好STM32?
  • Web开发:VUE3小白开发入门基础笔记
  • 技术周总结 2024.07.15~07.21周日(Spark性能优化)
  • 提高性能的常见技术
  • LeetCode206 反转链表
  • nginx通过nginx_upstream_check_module实现后端健康检查
  • FastGPT 知识库搜索测试功能解析(二)
  • 双向链表<数据结构 C版>
  • react18+
  • rk3568 OpenHarmony4.1 Launcher定制开发—桌面壁纸替换
  • MySQL:送分or送命 varchar(30) 与 int(10)
  • 【odoo17】后端py方法触发右上角提示组件
  • 1775D - Friendly Spiders
  • 【python】OpenCV—Point Polygon Test
  • 6 Go语言的常量、枚举、作用域
  • 第十一章 数据结构
  • LeetCode704 二分查找
  • [言简意赅] Matlab生成FPGA端rom初始化文件.coe
  • 【QAC】分布式部署下其他机器如何连接RLM
  • 从等保测评看行业安全趋势:洞察与预测
  • HTTP模块(二)