当前位置: 首页 > news >正文

大模型笔记:pytorch实现MOE

0 导入库

import torch
import torch.nn as nn
import torch.nn.functional as F

1  专家模型

#一个简单的专家模型,可以是任何神经网络架构
class Expert(nn.Module):def __init__(self, input_size, output_size):super(Expert, self).__init__()self.fc = nn.Linear(input_size, output_size)def forward(self, x):return self.fc(x)

2 MOE

class MoE(nn.Module):def __init__(self, num_experts, input_size, output_size,topk):super(MoE, self).__init__()self.num_experts = num_experts self.topk=topkself.experts = nn.ModuleList([Expert(input_size, output_size) for _ in range(num_experts)])#创建多个专家self.gating_network = nn.Linear(input_size, num_experts)  #门控网络def forward(self, x):#假设x的维度是(batch,input_size)gating_scores = self.gating_network(x)# 门控网络决定权重 (选择每一个专家的概率)#输出维度是(batch_size, num_experts)topk_gate_scores,topk_gate_index=gating_scores.topk(topk,-1)#选取topk个专家#(batch_size,topk)gating_scores_filtered=torch.full_like(gating_scores,fill_value=float("-inf"))gating_scores_filtered=gating_scores_filtered.scatter(-1,topk_gate_index,topk_gate_scores)gating_scores_filtered=F.softmax(gating_scores_filtered,dim=-1)##创建一个全为负无穷的张量 zeros,并将 topk_gate_scores 的值插入到这个张量的对应位置#(batch_size,num_experts)expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)  # 专家网络输出#每个expert的输出维度是 (batch_size, output_size)#stack沿着第二个维度堆叠,之后expert_outputs的维度是(batch_size,num_experts,output_size)moe_output = torch.bmm(gating_scores_filtered.unsqueeze(1), expert_outputs).squeeze(1) # 加权组合专家输出 #gating_scores.unsqueeze(1)——>(batch_size, 1,num_experts)#torch.bmm(gating_scores.unsqueeze(1), expert_outputs)——>(batch_size,1,output_size)#moe_output——>(batch_size,output_size)return moe_output

3  输入举例


input_size = 10  
# 输入特征是大小为10的向量
output_size = 5  
# 输出大小为5的向量
num_experts = 3  
# 3个专家moe_model = MoE(num_experts, input_size, output_size)
# 初始化MOE模型input_vector = torch.randn(1, input_size)
# 创建一个输入向量output_vector = moe_model(input_vector)
# 前向传递print(output_vector.shape,output_vector)
# 打印输出
'''
torch.Size([1, 5]) tensor([[ 2.7343e-04,  4.0966e-01, -3.6634e-01, -8.9064e-01,  4.0759e-01]],grad_fn=<SqueezeBackward1>)
'''

http://www.lryc.cn/news/535540.html

相关文章:

  • HAL库USART中断接收的相关问题
  • @Transational事务注解底层原理以及什么场景事务会失效
  • Linux扩容磁盘
  • 全面解析鸿蒙(HarmonyOS)开发:从入门到实战,构建万物互联新时代
  • Uniapp 原生组件层级过高问题及解决方案
  • Android adb测试常用命令大全
  • linux的基础入门2
  • 19.4.8 数据库综合运用
  • JAVA中的抽象学习
  • 在 Go 中实现事件溯源:构建高效且可扩展的系统
  • 加解密 | AES加、解密学习
  • 【学术投稿-2025年计算机视觉研究进展与应用国际学术会议 (ACVRA 2025)】CSS样式解析:行内、内部与外部样式的区别与优先级分析
  • MongoDB 基本操作
  • Eclipse JSP/Servlet 深入解析
  • Hyperledger caliper 性能测试
  • Record-Mode 备案免关站插件,让 WordPress 备案不影响 SEO 和收录
  • 【Java 面试 八股文】Redis篇
  • 介绍几款免费的显示器辅助工具!
  • django配置跨域
  • web前端第三次作业
  • 【Pandas】pandas Series align
  • DeepSeek-V3网络模型架构图解
  • Linux系统管理小课堂
  • 明远智睿核心板在智能家居与工业网关中的应用实践
  • Windows 系统 GDAL库 配置到 Qt 上
  • 部署onlyoffice后,php版的callback及小魔改(logo和关于)
  • 《qt open3d网格拉普拉斯平滑》
  • 【愚公系列】《Python网络爬虫从入门到精通》004-请求模块urllib3
  • 网络安全技术复习总结
  • 初阶c语言(while循环二分法)