当前位置: 首页 > news >正文

机器学习深度学习——多层感知机的从零开始实现

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——多层感知机
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助

为了与之前的softmax回归获得的结果进行比较,将继续使用Fashion-MNIST图像分类数据集。

import torch
from torch import nn
from d2l import torch as d2lbatch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

多层感知机的从零开始实现

  • 初始化模型参数
  • 激活函数
  • 模型
  • 损失函数
  • 训练
  • 预测

初始化模型参数

数据集的每个图像由28×28=784个灰度像素值组成。所有图像分为10个类别。
忽略像素间的空间结构,我们可以将每个图像视为具有784个输入特征和10个类的简单分类数据集。
首先,我们将实现一个具有单隐藏层的多层感知机,它包含256个隐藏单元。注意,我们可以将这两个变量都视为超参数。通常,我们选择2的若干次幂作为层的宽度。因为内存在硬件的分配和寻址方式,这么做往往可以在计算上更高效。
我们用几个张量来表示我们的参数。注意,对于每一层我们都要记录一个权重矩阵和一个偏置向量。并要为这些参数的梯度分配内存。

num_inputs, num_outputs, num_hiddens = 784, 10, 256
W1 = nn.Parameter(torch.randn(num_inputs, num_hiddens, requires_grad=True) * 0.01)
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
W2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))
params = [W1, b1, W2, b2]

激活函数

这里就不用内置的了,自己实现一下:

def relu(X):a = torch.zeros_like(X)return torch.max(X, a)

模型

既然忽略了空间结构,那就直接用reshape将每个二维图像转换为一个长度为num_inputs的向量:

def net(X):X = X.reshape((-1, num_inputs))H = relu(X@W1 + b1)  # "@"表示矩阵乘法return (H@W2 + b2)

损失函数

之前已经从零实现过了softmax函数,这里直接用内置函数计算softmax和交叉熵损失(为什么要计算这两个,之前在softmax的简洁实现中曾经证明过)

loss = nn.CrossEntropyLoss(reduction='none')

训练

训练过程和softmax一样,直接调用d2l的train_ch3函数就行了,将迭代周期数设为10,学习率设为0.1。

num_epochs, lr = 10, 0.1
updater = torch.optim.SGD(params, lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

预测

对模型进行评估,我们在测试数据上应用这个模型。

d2l.predict_ch3(net, test_iter)
d2l.plt.show()

在这里插入图片描述

http://www.lryc.cn/news/103186.html

相关文章:

  • Redis的基本使用命令
  • Ts入门到放弃
  • 黑客技术(网络安全)学习笔记
  • Cloud Kernel SIG 月度动态:支持龙芯和申威架构,合入两个内存新特性
  • IDEA中连接虚拟机 管理Docker
  • Debezium日常分享系列之:定制Debezium 信号发送和通知
  • RpcProvider(rpc服务提供者)实现思路
  • GNSS技术知识你知道多少?这些你或许还未掌握
  • YOLOv8教程系列:三、使用YOLOv8模型进行自定义数据集半自动标注
  • AI聊天GPT三步上篮!
  • 如何彻底卸载VMware
  • [个人笔记] Windows配置NTP时间同步
  • Jetson Docker 编译 FFmpeg 支持硬解nvmpi和cuvid
  • 某某某小说app接口抓包分析
  • 开发一个RISC-V上的操作系统(四)—— 内存管理
  • 区块链:可验证随机函数
  • Flask中flask-session
  • react-Native init初始化项目报错”TypeError: cli.init is not a function“
  • 【gitlib】linux系统rpm安装gitlib最新版本
  • iOS开发-检查版本更新与强制更新控制
  • 自动化运维工具——Ansible
  • W2NER详解
  • ElementUI tabs标签页样式改造美化
  • 出海周报|Temu在美状告shein、ChatGPT安卓版上线、小红书回应闪退
  • 2023年7月26日 单例模式
  • [ 容器 ] Docker 安全及日志管理
  • 游游的排列构造
  • 拯救者Y9000K无线Wi-Fi有时不稳定?该如何解决?
  • 【业务功能篇59】Springboot + Spring Security 权限管理 【下篇】
  • 性能优化 - 前端性能监控和性能指标计算方式