AI学习记录 - 最简单的专家模型 MOE
代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tupleclass BasicExpert(nn.Module):# 一个 Expert 可以是一个最简单的, linear 层即可# 也可以是 MLP 层# 也可以是 更复杂的 MLP 层(active function 设置为 swiglu)def __init__(self, feature_in, feature_out):super().__init__()self.linear = nn.Linear(feature_in, feature_out)def forward(self, x):return self.linear(x)class BasicMOE(nn.Module):# 创建了一个 BasicMOE 模型,输入特征维度为 6, 输出特征维度为 3, 专家数量为 2。def __init__(self, feature_in, feature_out, expert_number):super().__init__()self.experts = nn.ModuleList([BasicExpert(feature_in, feature_out) for _ in range(expert_number)])# gate 就是选一个 expert self.gate = nn.Linear(feature_in, expert_number)def forward(self, x):# 两个专家数量, expert_weight 就是两个数字expert_weight = self.gate(x) # shape 是 (batch, expert_number)print("expert_weight", expert_weight)expert_out_list = [expert(x).unsqueeze(1) for expert in self.experts] # 里面每一个元素的 shape 是: (batch, ) ??# concat 起来 (batch, expert_number, feature_out)# 每个专家输出的特征是3个维度expert_output = torch.cat(expert_out_list, dim=1)print("expert_output.size()", expert_output.size())print("expert_weight", expert_weight.size())expert_weight = expert_weight.unsqueeze(1) # (batch, 1, expert_nuber)print("expert_weight", expert_weight.size())# expert_weight * expert_out_listoutput = expert_weight @ expert_output # (batch, 1, feature_out)return output.squeeze()def test_basic_moe():x = torch.rand(2, 6)# x 是一个形状为 (2, 6) 的输入张量 (2 个样本, 每个样本 6 个特征)。# 创建了一个 BasicMOE 模型,输入特征维度为 6, 输出特征维度为 3, 专家数量为 2。basic_moe = BasicMOE(6, 3, 2)out = basic_moe(x)# 表示 2 个样本,2 个专家,每个专家输出 3 个特征。print(out)test_basic_moe()
代码对应的配图解释: