当前位置: 首页 > news >正文

AI学习记录 - 最简单的专家模型 MOE

代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tupleclass BasicExpert(nn.Module):# 一个 Expert 可以是一个最简单的, linear 层即可# 也可以是 MLP 层# 也可以是 更复杂的 MLP 层(active function 设置为 swiglu)def __init__(self, feature_in, feature_out):super().__init__()self.linear = nn.Linear(feature_in, feature_out)def forward(self, x):return self.linear(x)class BasicMOE(nn.Module):# 创建了一个 BasicMOE 模型,输入特征维度为 6, 输出特征维度为 3, 专家数量为 2。def __init__(self, feature_in, feature_out, expert_number):super().__init__()self.experts = nn.ModuleList([BasicExpert(feature_in, feature_out) for _ in range(expert_number)])# gate 就是选一个 expert self.gate = nn.Linear(feature_in, expert_number)def forward(self, x):# 两个专家数量, expert_weight 就是两个数字expert_weight = self.gate(x)  # shape 是 (batch, expert_number)print("expert_weight", expert_weight)expert_out_list = [expert(x).unsqueeze(1) for expert in self.experts]  # 里面每一个元素的 shape 是: (batch, ) ??# concat 起来 (batch, expert_number, feature_out)# 每个专家输出的特征是3个维度expert_output = torch.cat(expert_out_list, dim=1)print("expert_output.size()", expert_output.size())print("expert_weight", expert_weight.size())expert_weight = expert_weight.unsqueeze(1) # (batch, 1, expert_nuber)print("expert_weight", expert_weight.size())# expert_weight * expert_out_listoutput = expert_weight @ expert_output  # (batch, 1, feature_out)return output.squeeze()def test_basic_moe():x = torch.rand(2, 6)# x  是一个形状为  (2, 6)  的输入张量 (2 个样本, 每个样本 6 个特征)。# 创建了一个 BasicMOE 模型,输入特征维度为 6, 输出特征维度为 3, 专家数量为 2。basic_moe = BasicMOE(6, 3, 2)out = basic_moe(x)# 表示 2 个样本,2 个专家,每个专家输出 3 个特征。print(out)test_basic_moe()

代码对应的配图解释:
在这里插入图片描述

http://www.lryc.cn/news/536076.html

相关文章:

  • 急停信号的含义
  • 单调队列queue
  • 【漫话机器学习系列】091.置信区间(Confidence Intervals)
  • UnicodeDecodeError: ‘gbk‘ codec can‘t decode byte 0x99
  • DeepSeek应用——与word的配套使用
  • 递归乘法算法
  • 【免费】2004-2020年各省废气中废气中二氧化硫排放量数据
  • CNN-LSSVM卷积神经网络最小二乘支持向量机多变量多步预测,光伏功率预测
  • 【油猴脚本/Tampermonkey】DeepSeek 服务器繁忙无限重试(20250213优化)
  • 单调栈及相关题解
  • 每日温度问题:如何高效解决?
  • #渗透测试#批量漏洞挖掘#致远互联AnalyticsCloud 分析云 任意文件读取
  • 统计安卓帧率和内存
  • 大数据学习之PB级百战出行网约车二
  • C语言第18节:自定义类型——联合和枚举
  • C++病毒(^_^|)(2)
  • 在vscode中拉取gitee里的项目并运行
  • centos7 防火墙开放指定端口
  • Day42(补)【AI思考】-编译过程中语法分析及递归子程序分析法的系统性解析
  • AI成为基础设施有哪些研究方向:模型的性能、可解释性,算法偏见
  • 写一个鼠标拖尾特效
  • Redisson介绍和入门使用
  • OpenAI推出全新AI助手“Operator”:让人工智能帮你做事的新时代!
  • Python----PyQt开发(PyQt基础,环境搭建,Pycharm中PyQttools工具配置,第一个PyQt程序)
  • 算法笔记 02 —— 入门模拟
  • PyTorch 源码学习:从 Tensor 到 Storage
  • uniapp 使用 鸿蒙开源字体
  • LabVIEW多电机CANopen同步
  • 每日定投40刀BTC(2)20250209 - 20250212
  • 【LeetCode Hot100 子串】和为 k 的子数组、滑动窗口最大值、最小覆盖子串