当前位置: 首页 > news >正文

MLP感知机python实现

from torch import nn
from softmax回归 import train_ch3
import torch
import torchvision
from torch.utils import data
from torchvision import transforms# ①准备数据集
def load_data_fashion_mnist(batch_size, resize=None):# PyTorch中的一个转换函数,它的作用是将一个PIL Image或numpy.ndarray图像转换为一个Tensor数据类型。trans = [transforms.ToTensor()]# 是否需要改变大小if resize:trans.insert(0, transforms.Resize(resize))# 函数compose将这些转换操作组合起来trans = transforms.Compose(trans)# 训练数据mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)# 测试数据mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)# 返回值return (torch.utils.data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=4),torch.utils.data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=4))
# 批量大小为256
batch_size = 256
# 获取训练数据集和测试数据集
train_iter, test_iter = load_data_fashion_mnist(batch_size)# ②实现一个具有单隐藏层的多层感知机,它包含256个隐藏单元
# 定义输入,输出,隐藏层大小
num_inputs, num_outputs, num_hiddens = 784, 10, 256
# 定义W1、b1、W2、b2
W1 = nn.Parameter(torch.randn(num_inputs, num_hiddens, requires_grad=True) * 0.01)
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
W2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))params = [W1, b1, W2, b2]# ③实现ReLU激活函数
def relu(X):a = torch.zeros_like(X)return torch.max(X, a)# ④实现模型
def net(X):# x=256*784X = X.reshape((-1, num_inputs))# torch.matmul(X,W1)= 256*256H = relu(torch.matmul(X,W1) + b1)# torch.matmul(H,W2) = 256*10return (torch.matmul(H,W2) + b2)# ⑤定义损失函数
loss = nn.CrossEntropyLoss()# ⑥训练
# 定义学习率
lr =  0.1
# 优化函数
updater = torch.optim.SGD(params, lr=lr)# 训练
if __name__ == '__main__':num_epochs = 10train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

训练结果

训练损失:0.0015049066459139188
训练精度:0.86405
测试精度:0.8453

貌似是比softmax十次好一些

http://www.lryc.cn/news/233955.html

相关文章:

  • Es 拼音搜索无法高亮
  • java线性并发编程介绍-锁(二)
  • Java JPA详解:从入门到精通
  • 使用Open3D库处理3D模型数据的实践指南
  • 代码随想录算法训练营第五十八天丨 动态规划part18
  • Pytest自动化测试框架介绍
  • 基于SpringBoot+Redis的前后端分离外卖项目-苍穹外卖(五)
  • Oracle 监控的指标有哪些和oracle巡检的内容
  • Uniapp有奖猜歌游戏系统源码 带流量主
  • 【算法与数据结构】前言
  • (六)什么是Vite——热更新时vite、webpack做了什么
  • 贝加莱MQTT功能
  • 基于JavaWeb+SSM+购物系统微信小程序的设计和实现
  • 为什么需要Code Review?
  • 【计算机网络笔记】ICMP(互联网控制报文协议)
  • Git教程1:生成和提交SSH公钥到远程仓库
  • 贝茄莱BR AS实时数据采集功能
  • Git的基本操作以及原理介绍
  • 2023安全与软工顶会/刊中区块链智能合约相关论文
  • word文档转换为ppt文件,怎么做?
  • 机器视觉选型-什么时候用远心镜头
  • quartz笔记
  • ER 图是什么
  • PLC电力载波通讯,一种新的IoT通讯技术
  • Elasticsearch:通过摄取管道加上嵌套向量对大型文档进行分块轻松地实现段落搜索
  • OpenCV图像纹理
  • 自媒体写手提问常用的ChatGPT通用提示词模板
  • 分类预测 | Matlab实现PSO-LSTM-Attention粒子群算法优化长短期记忆神经网络融合注意力机制多特征分类预测
  • 3GPP TS38.201 NR; Physical layer; General description (Release 18)
  • 【GitLab】-HTTP 500 curl 22 The requested URL returned error: 500~SSH解决