当前位置: 首页 > news >正文

4种feature classification在代码的实现上是怎么样的?Linear / MLP / CNN / Attention-Based Heads

具体的分类效果可以看:【Arxiv 2023】Diffusion Models Beat GANs on Image Classification


1、线性分类器 (Linear, A)

使用一个简单的线性层,通常与一个激活函数结合使用。

import torch.nn as nnclass LinearClassifier(nn.Module):def __init__(self, input_size, num_classes):super(LinearClassifier, self).__init__()self.linear = nn.Linear(input_size, num_classes)def forward(self, x):return self.linear(x)

2、多层感知机 (Multi-Layer Perceptron, B)

包括多个线性层,每层之间可能有激活函数和dropout层。

class MLPClassifier(nn.Module):def __init__(self, input_size, hidden_size, num_classes):super(MLPClassifier, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)self.relu = nn.ReLU()self.fc2 = nn.Linear(hidden_size, num_classes)def forward(self, x):x = self.relu(self.fc1(x))x = self.fc2(x)return x

3、卷积神经网络 (Convolutional Neural Network, CNN, C)

使用一系列卷积层,通常包括池化层和全连接层。

class CNNClassifier(nn.Module):def __init__(self, num_classes):super(CNNClassifier, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1)self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)self.conv2 = nn.Conv2d(32, 64, 3, 1, 1)self.fc = nn.Linear(64 * 7 * 7, num_classes)  # Assuming input size is 28x28def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(x.size(0), -1)  # Flatten the tensorx = self.fc(x)return x

4、基于注意力机制的头部 (Attention-Based Heads, D)

使用注意力机制,如Transformer的头部结构。

from torch.nn import TransformerEncoder, TransformerEncoderLayerclass AttentionClassifier(nn.Module):def __init__(self, input_size, num_classes, nhead, nhid, nlayers):super(AttentionClassifier, self).__init__()self.model_type = 'Transformer'self.encoder_layer = TransformerEncoderLayer(d_model=input_size, nhead=nhead, dim_feedforward=nhid)self.transformer_encoder = TransformerEncoder(self.encoder_layer, num_layers=nlayers)self.decoder = nn.Linear(input_size, num_classes)def forward(self, src):output = self.transformer_encoder(src)output = self.decoder(output.mean(dim=1))return output

http://www.lryc.cn/news/266513.html

相关文章:

  • 最新Unity DOTS Physics物理引擎碰撞事件处理
  • springboot集成websocket全全全!!!
  • SpringMVC:整合 SSM 中篇
  • oracle即时客户端(Instant Client)安装与配置
  • POP3协议详解
  • 电子病历编辑器源码,提供电子病历在线制作、管理和使用的一体化电子病历解决方案
  • WT2605C高品质音频蓝牙语音芯片:外接功放实现双声道DAC输出的优势
  • IntelliJ IDEA 2023.3 最新版如何如何配置?IntelliJ IDEA 2023.3 最新版试用方法
  • 如何查看内存卡使用记录-查看的设备有:U盘、移动硬盘、MP3、SD卡等-供大家学习研究参考
  • 九、W5100S/W5500+RP2040之MicroPython开发<HTTPOneNET示例>
  • 在 Laravel 中,清空缓存大全
  • 【贪心】单源最短路径Python实现
  • Spark Shell的简单使用
  • Springsecurty【2】认证连接MySQL
  • .Net 访问电子邮箱-LumiSoft.Net,好用
  • 谷粒商城-商品服务-新增商品功能开发(商品图片无法展示问题没有解决)
  • Open3D 点云数据处理基础(Python版)
  • 使用vue-qr,报错in ./node_modules/vue-qr/dist/vue-qr.js
  • 百川2大模型微调问题解决
  • MySQL的事务-原子性
  • D3839|完全背包
  • Java之Synchronized与锁升级
  • kitex出现:open conf/test/conf.yaml: no such file or directory
  • sql server多表查询
  • 如何利用PPT绘图并导出清晰图片
  • 1.倒排索引 2.逻辑斯提回归算法
  • Kafka消费者组
  • 四. 基于环视Camera的BEV感知算法-BEVDepth
  • CentOS系统环境搭建(二十五)——使用docker compose安装mysql
  • 协作机器人(Collaborative-Robot)安全碰撞的速度与接触力