当前位置: 首页 > news >正文

大模型推理——MLA实现方案

1.整体流程

先上一张图来整体理解下MLA的计算过程

2.实现代码

import math
import torch
import torch.nn as nn# rms归一化
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (as used in LLaMA / DeepSeek).

    Normalizes the last dimension by its RMS (no mean subtraction) and
    applies a learned per-channel scale.

    Args:
        hidden_size: size of the last (feature) dimension.
        eps: small constant added to the variance for numerical stability.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Compute in float32 for numerical stability.  NOTE: the result
        # stays float32 — this implementation never casts back to the
        # input dtype (the original's trailing `.float()` was a no-op and
        # has been removed).
        hidden_states = hidden_states.float()
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states


def rotate_half(x):
    """Rotate the last dimension by half: (x1, x2) -> (-x2, x1)."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
    """Apply rotary position embedding to q and k.

    cos/sin are broadcast over the head axis via ``unsqueeze_dim``
    (dim 2 is the heads axis for [bs, seq, heads, head_dim] tensors).
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# Rotary position embedding
class RotaryEmbedding(nn.Module):
    """Precomputed rotary position embedding (RoPE) tables.

    Caches cos/sin tables for positions up to ``max_seq_len`` and applies
    them to query/key tensors of shape [bs, seq, heads, head_dim].
    """

    def __init__(self, dim, max_seq_len=1024):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        # Standard RoPE frequencies: theta_i = 10000^(-2i/dim).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_seq_len).float()
        angles = torch.outer(positions, inv_freq)
        # Duplicate so the table spans the full head dimension.
        angles = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", angles.cos())
        self.register_buffer("sin_cached", angles.sin())

    def forward(self, q, k):
        # q.shape[1] is the sequence length; slice the cached tables to it.
        seq = q.shape[1]
        cos = self.cos_cached[:seq, :].unsqueeze(0)
        sin = self.sin_cached[:seq, :].unsqueeze(0)
        return apply_rotate_pos_emb(q, k, cos, sin)


class MLA(nn.Module):
    """Multi-head Latent Attention (DeepSeek style).

    Queries and keys/values are first projected down to low-rank latents
    (``q_lora_rank`` / ``kv_lora_rank``) and then expanded per head.  Each
    q/k head splits into a non-rotary part (``qk_nope_head_dim``) and a
    rotary part (``qk_rope_head_dim``).

    Two execution modes:
      * ``'naive'`` — materialize full per-head K and V and cache them;
      * anything else — absorb ``wkv_b`` into the score/value computation so
        only the compressed latent (plus the head-shared rotary key) is
        cached.
    """

    def __init__(self, dim, n_heads, q_lora_rank, kv_lora_rank,
                 qk_nope_head_dim, qk_rope_head_dim, v_head_dim,
                 max_seq_len, max_batch_size, mode):
        super().__init__()
        self.dim = dim                              # model hidden size
        self.n_heads = n_heads                      # number of attention heads
        self.q_lora_rank = q_lora_rank              # rank of the query latent
        self.kv_lora_rank = kv_lora_rank            # rank of the key/value latent
        self.qk_nope_head_dim = qk_nope_head_dim    # non-rotary part of each q/k head
        self.qk_rope_head_dim = qk_rope_head_dim    # rotary part of each q/k head
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # full q/k head size
        self.v_head_dim = v_head_dim                # value head size
        self.mode = mode
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

        # Query path: down-project, RMS-normalize, up-project to all heads.
        self.wq_a = nn.Linear(self.dim, self.q_lora_rank)
        self.q_norm = RMSNorm(self.q_lora_rank)
        self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)

        # Key/value path: a single down-projection yields the compressed
        # latent plus the rotary key part shared by all heads.
        self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = nn.Linear(
            self.kv_lora_rank,
            self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))

        self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim)
        self.rotary_emb = RotaryEmbedding(self.qk_rope_head_dim)

        if self.mode == 'naive':
            # No matrix absorption: cache the fully expanded per-head K and V.
            self.register_buffer(
                'k_cache',
                torch.zeros(self.max_batch_size, self.max_seq_len,
                            self.n_heads, self.qk_head_dim),
                persistent=False)
            self.register_buffer(
                'v_cache',
                torch.zeros(self.max_batch_size, self.max_seq_len,
                            self.n_heads, self.v_head_dim),
                persistent=False)
        else:
            # With absorption only the compressed latent and the shared
            # rotary key need caching — this is MLA's memory saving.
            self.register_buffer(
                'kv_cache',
                torch.zeros(self.max_batch_size, self.max_seq_len, self.kv_lora_rank),
                persistent=False)
            self.register_buffer(
                'pe_cache',
                torch.zeros(self.max_batch_size, self.max_seq_len, self.qk_rope_head_dim),
                persistent=False)

    def forward(self, x, mask=None):
        """Attend over ``x`` ([bs, seq, dim]); ``mask``, if given, is an
        additive [bs, seq, seq] mask.  Returns [bs, seq, dim]."""
        batch, slen, _ = x.shape

        # ----- queries ---------------------------------------------------
        q = self.wq_b(self.q_norm(self.wq_a(x)))    # [b, s, h * qk_head_dim]
        q = q.view(batch, slen, self.n_heads, self.qk_head_dim)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # ----- compressed keys/values ------------------------------------
        latent = self.wkv_a(x)                      # [b, s, kv_rank + rope]
        kv, k_pe = torch.split(
            latent, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = k_pe.unsqueeze(2)                    # one rotary key shared by every head
        q_pe, k_pe = self.rotary_emb(q_pe, k_pe)

        scale = math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)

        if self.mode == 'naive':
            q_full = torch.cat([q_nope, q_pe], dim=-1)      # [b, s, h, qk_head_dim]
            expanded = self.wkv_b(self.kv_norm(kv))
            expanded = expanded.view(batch, slen, self.n_heads,
                                     self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(
                expanded, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1)
            self.k_cache[:batch, :slen, :, :] = k
            self.v_cache[:batch, :slen, :, :] = v
            # [b, h, s, d] @ [b, h, d, t] -> [b, h, s, t], then back to [b, s, h, t].
            keys = self.k_cache[:batch, :slen, :, :].transpose(1, 2).transpose(2, 3)
            scores = torch.matmul(q_full.transpose(1, 2), keys / scale)
            scores = scores.transpose(1, 2)
        else:
            k_pe = k_pe.squeeze(2)
            # Absorb wkv_b's k-part into the query so attention runs directly
            # against the compressed latent:
            #   q @ k^T = (x wq) (c wkv_b_k)^T = (x wq wkv_b_k^T) c^T.
            w = self.wkv_b.weight.view(self.n_heads, -1, self.kv_lora_rank)
            q_nope = torch.einsum("bshd,hdc->bshc",
                                  q_nope, w[:, :self.qk_nope_head_dim])
            kv = self.kv_norm(kv)
            self.kv_cache[:batch, :slen, :] = kv    # [b, s, kv_lora_rank]
            self.pe_cache[:batch, :slen, :] = k_pe  # [b, s, qk_rope_head_dim]
            scores_nope = torch.einsum("bshc,btc->bsht",
                                       q_nope, self.kv_cache[:batch, :slen, :])
            scores_pe = torch.einsum("bshr,btr->bsht",
                                     q_pe, self.pe_cache[:batch, :slen, :])
            scores = (scores_nope + scores_pe) / scale  # [b, s, h, t]

        if mask is not None:
            scores += mask.unsqueeze(2)             # broadcast over the head axis
        scores = scores.softmax(dim=-1)

        if self.mode == 'naive':
            out = torch.einsum("bsht,bthd->bshd",
                               scores, self.v_cache[:batch, :slen])
        else:
            # scores @ V = scores @ c @ wkv_b_v : apply the value half of the
            # absorbed up-projection after attending over the latent.
            out = torch.einsum("bsht,btc->bshc",
                               scores, self.kv_cache[:batch, :slen])
            out = torch.einsum("bshc,hdc->bshd", out, w[:, -self.v_head_dim:])

        out = out.contiguous().view(batch, slen, -1)
        return self.wo(out)


if __name__ == '__main__':
    torch.manual_seed(0)
    torch.set_printoptions(precision=3, sci_mode=False)
    x = torch.randn(1, 4, 16)
    mla = MLA(dim=16,
              n_heads=2,
              q_lora_rank=10,
              kv_lora_rank=6,
              qk_nope_head_dim=8,
              qk_rope_head_dim=4,
              v_head_dim=8,
              max_seq_len=10,
              max_batch_size=4,
              mode='none')  # any value other than 'naive' selects the fused path
    print(mla(x))
    print(mla.kv_cache)

参考资料:

https://zhuanlan.zhihu.com/p/16730036197

https://github.com/wyf3/llm_related/tree/main/deepseek_learn

http://www.lryc.cn/news/533387.html

相关文章:

  • redis之GEO 模块
  • 21.2.7 综合示例
  • 使用Docker + Ollama在Ubuntu中部署deepseek
  • 【C语言标准库函数】三角函数
  • CNN-day9-经典神经网络ResNet
  • 淘宝分类详情数据获取:Python爬虫的高效实现
  • 机器学习 —— 深入剖析线性回归模型
  • 33.日常算法
  • #渗透测试#批量漏洞挖掘#微商城系统 goods SQL注入漏洞
  • 【翻译+论文阅读】DeepSeek-R1评测:粉碎GPT-4和Claude 3.5的开源AI革命
  • Vision Transformer学习笔记(2020 ICLR)
  • 一步一步生成音乐类小程序的详细指南,结合AI辅助开发的思路
  • 25/2/8 <机器人基础> 阻抗控制
  • golang 开启HTTP代理认证
  • 详解Nginx no live upstreams while connecting to upstream
  • Open3d Qt的环境配置
  • 5.Python字典和元组:字典的增删改查、字典遍历、访问元组、修改元组、集合(set)
  • 深度学习系列--04.梯度下降以及其他优化器
  • 2022java面试总结,1000道(集合+JVM+并发编程+Spring+Mybatis)的Java高频面试题
  • Ubuntu MKL(Intel Math Kernel Library)
  • 消费电子产品中的噪声对TPS54202的影响
  • 第四十章:职场转折:突破困境,重新出发
  • c++ 不定参数,不定类型的 max,min 函数
  • 数据库的关系代数
  • VSCode使用总结
  • 关系模型的数据结构及形式化定义
  • 【C++入门讲解】
  • 数据表中的视图操作
  • BFS算法篇——广度优先搜索,探索未知的旅程(上)
  • mongodb 使用内存过大分析