当前位置: 首页 > news >正文

【机器学习】034_多层感知机Part.2_从零实现多层感知机

一、解决XOR问题

1. 回顾XOR问题:

        如图,如何对XOR面进行分割以划分四个输入 x 对应的输出 y 呢?

· 思路:采用两个分类器分类,每次分出两个输入 x,再借助这两个分类从而分出 y

        即采用同或运算,当两次分类的值相同时,输出为1;当两次分类的值不同时,输出为0.

        · 蓝色的线将1、3赋值1,2、4赋值0,从而分隔开;黄色的线将1、2赋值1,3、4赋值0;

        · 那么,如果两次赋值相同,即表示它们是第一类;不同表示他们是第二类,由此分类。

2. 如何利用感知机解决XOR问题

由上述原理可得,既然一层感知机无法处理XOR问题分类,那么可以用多个感知机函数来进行处理。用好几层分类多次,最后对之前的分类结果求和取一个算法,就得到了最终的分类结果。

二、多层感知机的代码实现

代码:

import torch
from torch import nn
from d2l import torch as d2l
# 继续使用fashion_mnist数据集进行分类操作,定义小批量数据
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)# 每张图片为28x28=784像素值,可看作784个特征值的具有10个类别的分类数据集
# 首先实现一个具有单隐藏层的多层感知机,包含256个隐藏单元,有输入->隐藏->输出三层
# W1: 输入层到隐藏层的权重矩阵,大小为 (num_inputs, num_hiddens)
# b1: 隐藏层的偏置项,大小为 (num_hiddens,)
# W2: 隐藏层到输出层的权重矩阵,大小为 (num_hiddens, num_outputs)
# b2: 输出层的偏置项,大小为 (num_outputs,)
# nn.Parameter 表示这些变量是模型参数,需要在训练过程中进行更新
# 乘以 0.01 是为了缩小初始化值的范围,有助于训练的稳定性
num_inputs, num_outputs, num_hiddens = 784, 10, 256W1 = nn.Parameter(torch.randn(num_inputs, num_hiddens, requires_grad=True) * 0.01)
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
W2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))params = [W1, b1, W2, b2]# 实现ReLU激活函数,返回max(0, x)
def relu(X):a = torch.zeros_like(X)return torch.max(X, a)# 实现模型,将输入的二维图像转化为一个一维向量,长度为num_inputs
def net(X):X = X.reshape((-1, num_inputs))H = relu(X@W1 + b1)  # 这里“@”代表矩阵乘法return (H@W2 + b2)# 实现损失函数
# 由于实现了softmax损失函数,使得不必在输出层调用sigmoid激活函数将输出值收缩到概率区间
# Softmax激活函数是sigmoid的推广,用于多分类问题的输出层。它会将输出归一化为概率分布,使得所有类别的预测概率总和为1
loss = nn.CrossEntropyLoss(reduction='none')# 训练模型,迭代10个周期,学习率设定为0.1
num_epochs, lr = 10, 0.1
updater = torch.optim.SGD(params, lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)# 应用模型进行测试与评估
d2l.predict_ch3(net, test_iter)
http://www.lryc.cn/news/240882.html

相关文章:

  • 2023年中职“网络安全“—Web 渗透测试①
  • Android——资源IDnonFinalResIds和“Attribute value must be constant”错误
  • 批量创建表空间数据文件(DM8:达梦数据库)
  • 简单聊聊加密和加签的关系与区别
  • 视频转码方法:多种格式视频批量转FLV视频的技巧
  • 【Java 进阶篇】Redis 数据结构:轻松驾驭多样性
  • 东用科技智能公交识别系统无线传输方案
  • Django批量插入数据及分页器
  • PHP 语法||PHP 变量
  • 【python基础(四)】if语句详解
  • Spring Boot中常用的参数传递注解
  • Quartz .Net 的简单使用
  • 面试Java笔试题精选解答
  • 使用Python画一棵树
  • nginx学习(4)Nginx 负载均衡
  • WSL登录时提示nsenter: cannot open /proc/320/ns/time: No such file or directory的解决办法
  • git修改远程分支名称
  • Django 入门学习总结7-静态文件管理
  • 游戏开发引擎Cocos Creator和Unity如何对接广告-AdSet聚合广告平台
  • 振南技术干货集:制冷设备大型IoT监测项目研发纪实(4)
  • Android线程优化——整体思路与方法
  • 论防火墙的体系结构
  • BeansTalkd 做消息队列服务
  • csv文件添加文件内容和读取
  • 关于禅道的安装配置以及项目管理、团队协同工作
  • 使用Wireshark提取流量中图片方法
  • C#,简单修改Visual Studio 2022设置以支持C#最新版本的编译器,尊享编程之趣
  • 小程序Tab栏与页面滚动联动
  • Java,数据结构与集合源码,关于List接口的实现类(ArrayList、Vector、LinkedList)的源码剖析
  • 算法基础(python版本)