当前位置: 首页 > news >正文

机器学习之多层感知机 MLP简洁实现 《动手深度学习》实例


🎈 作者:Linux猿

🎈 简介:CSDN博客专家🏆,华为云享专家🏆,Linux、C/C++、云计算、物联网、面试、刷题、算法尽管咨询我,关注我,有问题私聊!

🎈 欢迎小伙伴们点赞👍、收藏⭐、留言💬


本篇文章主要介绍《动手深度学习》实例中的多层感知机 MLP 的简洁实现。

一、代码实现

多层感知机(MLP)的简洁实现如下所示。

import torch
from torch import nn
from d2l import torch as d2l'''
1. 设置网络模型
'''
# nn.Sequential() 用于网络模型的各层
# nn.Flatten() 它用来将输入张量展平为 [batch_size, features] 的形式
# nn.Linear() 用于设置网络中的全连接层,需要注意的是全连接层的输入与输出都是二维张量
# https://www.cnblogs.com/douzujun/p/13366939.html
net = nn.Sequential(nn.Flatten(),   # 输入的是 [batch_size, n, m], 经过 nn.Flatten 后变为 [batch_size, nxm], 其中, n和m分别为图像的行和列# [batch_size, in_features], [batch_size, out_features]nn.Linear(784, 256),# 激活函数 ReLU(x) = max(0,x)nn.ReLU(),# [batch_size, in_features], [batch_size, out_features]nn.Linear(256, 10))'''
2. 设置和应用权重
'''
# 用于初始化权重
def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)# 用于初始化权重
net.apply(init_weights)'''
3. 设置超参数、损失函数、优化函数
'''
batch_size, lr, num_epochs = 256, 0.1, 10
# 交叉熵损失函数,一般用于多分类
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=lr)'''
4. 训练模型
'''
# 获取训练集和测试集
# load_data_fashion_mnist 函数加载 Fashion-MNIST 数据集(服饰数据集)
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)# 训练模型
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
d2l.plt.show()

二、代码解析

在上述代码中,需要注意的是 nn.Flatten() 函数,默认情况下将输入的多维矩阵除第一维度外的部分扩展为一维,例如:假设有 [n,m,k] 的多维矩阵,调用 nn.Flatten() 后变成 [n, m*k] 的矩阵。上述代码中,大多数都是调用的封装的函数,可以通过调节超参数以及替换对应的函数(例如:激活函数替换)来提升训练精度。


🎈 感觉有帮助记得「一键三连支持下哦!有问题可在评论区留言💬,感谢大家的一路支持!🤞猿哥将持续输出「优质文章回馈大家!🤞🌹🌹🌹🌹🌹🌹🤞


http://www.lryc.cn/news/217394.html

相关文章:

  • 使用C#在Windows上调用7-zip压缩文件
  • 京东数据平台:2023年Q3季度黄金市场数据分析
  • https原理
  • FFmpeg直播能力更新计划与新版本发布
  • 面试算法55:二叉搜索树迭代器
  • Linux Crontab 定时任务
  • HiveSQL高级进阶技巧
  • 【Flutter】Flutter 动画深入解析(1):掌握 AnimationController 的使用
  • 安装富文本组件
  • Tomcat下载地址(详细)
  • 领星ERP如何无需API开发轻松连接OA、电商、营销、CRM、用户运营、推广、客服等近千款系统
  • Django实战项目-学习任务系统-自定义URL拦截器
  • [已解决]该主机与 Cloudera Manager Server 失去联系的时间过长。 该主机未与 Host Monitor 建立联系。
  • 通过在Z平面放置零极点的来设计数字滤波器
  • linux环境docker部署nginx对生产日志按日切割并压缩处理
  • 【Spring Boot】发送邮件功能
  • ELK问题整理
  • 《黑客帝国:破解编程密码》——探索编程世界的奥秘
  • 【优选算法系列】【专题六模拟】第一节.1576. 替换所有的问号和495. 提莫攻击
  • 路由器基础(十二):IPSEC VPN配置
  • Python 获取cpu、内存利用率
  • Apache ECharts简介和相关操作
  • 怎么看待工信部牵头推动人形机器人发展
  • Hikari源码分析
  • 修改YOLOv5的模型结构
  • React 与 React Native 区别
  • Android 12.0 系统system模块开启禁用adb push和adb pull传输文件功能
  • 基于单片机的衣物消毒清洗机系统设计
  • 将 UniLinks 与 Flutter 集成(安卓 AppLinks + iOS UniversalLinks)
  • Spring-Spring 之底层架构核心概念解析