LLaMA-Adapter Source Code Analysis

Pseudocode

def transformer_block_with_llama_adapter(x, gating_factor, soft_prompt):
    residual = x
    y = zero_init_attention(soft_prompt, x)  # llama-adapter: prepend learnable prefix
    x = self_attention(x)
    x = x + gating_factor * y  # llama-adapter: apply zero-init attention
    x = LayerNorm(x + residual)
    residual = x
    x = FullyConnectedLayers(x)
    x = LayerNorm(x + residual)
    return x
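The key point is that gating_factor is initialized to zero, so at the start of finetuning the adapter branch contributes nothing and the block behaves exactly like the frozen pretrained model. Below is a minimal, runnable sketch of that mechanism in plain PyTorch, single head, with the KV cache, model parallelism, and causal mask omitted for brevity; the class and variable names here are illustrative, not taken from the LLaMA-Adapter repo:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ZeroInitGatedAttention(nn.Module):
    """Toy single-head attention with a LLaMA-Adapter-style gated soft prompt."""

    def __init__(self, dim: int, adapter_len: int):
        super().__init__()
        self.scale = dim ** -0.5
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        # learnable soft prompt, attended to as extra key/value positions
        self.prompt = nn.Parameter(torch.randn(adapter_len, dim))
        # zero-init gate: the adapter branch is an exact no-op at step 0
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, x):  # x: (bsz, seqlen, dim)
        q, k, v = self.wq(x), self.wk(x), self.wv(x)

        # ordinary self-attention over the input tokens (no causal mask here)
        attn = F.softmax(q @ k.transpose(1, 2) * self.scale, dim=-1)
        out = attn @ v

        # separate attention from the same queries onto the soft prompt, gated
        pk = self.wk(self.prompt)  # (adapter_len, dim)
        pv = self.wv(self.prompt)
        p_attn = F.softmax(q @ pk.T * self.scale, dim=-1)
        return out + self.gate * (p_attn @ pv)

# sanity check: with gate == 0 the output matches plain attention exactly
m = ZeroInitGatedAttention(dim=8, adapter_len=4)
x = torch.randn(2, 5, 8)
baseline = F.softmax(m.wq(x) @ m.wk(x).transpose(1, 2) * m.scale, dim=-1) @ m.wv(x)
assert torch.allclose(m(x), baseline)

Because the gate starts at zero, the randomly initialized prompt cannot disturb the pretrained model's outputs early in training; its influence grows only as the gate is learned.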

Source Code

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import fairscale.nn.model_parallel.initialize as fs_init
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear

# ModelArgs and apply_rotary_emb are defined elsewhere in the LLaMA model file.

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        self.cache_k = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()
        self.cache_v = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        ).cuda()

        # llama-adapter: zero-initialized gate that scales the adapter branch
        self.gate = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor,
                mask: Optional[torch.Tensor], adapter=None):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        if adapter is not None:
            # llama-adapter: project the adaption prompt with the frozen wk/wv
            # and broadcast it across the batch
            adapter_len = adapter.shape[1]
            adapter_k = self.wk(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)
            adapter_v = self.wv(adapter).view(1, adapter_len, self.n_local_heads, self.head_dim).repeat(bsz, 1, 1, 1)
            adapter_k = adapter_k.transpose(1, 2)
            adapter_v = adapter_v.transpose(1, 2)

        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, slen, cache_len + slen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)

        if adapter is not None:
            # llama-adapter: a separate softmax over the prompt positions, scaled
            # by the zero-init gate, added on top of the ordinary attention output
            adapter_scores = torch.matmul(xq, adapter_k.transpose(2, 3)) / math.sqrt(self.head_dim)
            adapter_scores = self.gate * F.softmax(adapter_scores.float(), dim=-1).type_as(xq)
            output = output + torch.matmul(adapter_scores, adapter_v)

        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
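For context, here is a hedged sketch of how an Attention module like this is typically driven: the adaption prompts live in one small learnable table, and only the topmost adapter_layer transformer blocks receive a prompt while the lower blocks run vanilla attention. The names adapter_query, adapter_len, and adapter_layer mirror the idea in the LLaMA-Adapter repo, but treat the exact slicing below as illustrative:

import torch
import torch.nn as nn

dim, n_layers = 4096, 32
adapter_len, adapter_layer = 10, 30  # prompt length; number of adapted top layers

# one learnable prompt per adapted layer, stored in a single embedding table
adapter_query = nn.Embedding(adapter_len * adapter_layer, dim)
prompts = adapter_query.weight.reshape(adapter_layer, adapter_len, dim)

for i in range(n_layers):
    adapter_idx = i - (n_layers - adapter_layer)
    if adapter_idx >= 0:
        # top layers: pass this layer's prompt; Attention.forward broadcasts
        # it over the batch via repeat(bsz, 1, 1, 1)
        adapter = prompts[adapter_idx].unsqueeze(0)  # (1, adapter_len, dim)
    else:
        adapter = None  # bottom layers run vanilla attention
    # h = blocks[i](h, start_pos, freqs_cis, mask, adapter=adapter)

During finetuning only the prompt table and the per-layer gates are trained (roughly 1.2M parameters for LLaMA-7B); all of the original LLaMA weights, including wq/wk/wv/wo above, stay frozen.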