当前位置: 首页 > news >正文

Day 33: 动手实现一个简单的 MLP

知识点回顾:

1.PyTorch和cuda的安装
2.查看显卡信息的命令行命令(cmd中使用)
3.cuda的检查
4.简单神经网络的流程

  • a.数据预处理(归一化、转换成张量)
  • b.模型的定义
    • i.继承nn.Module类
    • ii.定义每一个层
    • iii.定义前向传播流程

c.定义损失函数和优化器
d.定义训练流程
e.可视化loss过程

torchcuda的安装及anaconda的安装与环境配置自行搜索相关教程完成


代码实现

首先检查torch是否安装

import torchif torch.cuda.is_available():print(f'可用的GPU数量: {torch.cuda.device_count()}')print(f'当前设备: {torch.cuda.current_device()}')

导入鸢尾花数据集

# 导入鸢尾花数据集
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import numpy as npiris = load_iris()
X = iris.data
y = iris.targetX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)print(X_train.shape)
print(y_train.shape)
print('\n')
print(X_test.shape)
print(y_test.shape)

数据归一化

# 数据归一化
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()# train_set & test_set 是相同的缩放
X_train = scaler.fit_transform(X_train) # fit_transform会重新计算数据集中的最大最小值进行归一化操作
X_test = scaler.transform(X_test) # 注意 不是 fit_transform函数 确保训练集和测试集应用了相同的缩放

类型转换

# 将数据转换为 Pytorch 张量 (Pytorch 使用张量进行训练)X_train = torch.FloatTensor(X_train)
y_train = torch.LongTensor(y_train)
X_test = torch.FloatTensor(X_test)
y_test = torch.LongTensor(y_test)

模型定义

# 模型架构定义 简单的全连接神经网络模型 包含一个输入层 一个隐藏层 一个输出层
# 并定义层数 前向传播顺序import torch.nn as nn
import torch.optim as optim# define model
class MyMLP(nn.Module):def __init__(self):super().__init__()self.input_layer = nn.Linear(4, 10)self.relu = nn.ReLU()self.output_layer = nn.Linear(10, 3)def forward(self, x):out = self.output_layer(self.relu(self.input_layer(x)))return out
# 实例化
model = MyMLP()

定义损失函数和优化函数

loss = nn.CrossEntropyLoss()optimizer = optim.SGD(model.parameters(), lr=0.01)

定义训练函数并绘制损失图像

def train(epochs: int):losses = []import matplotlib.pyplot as pltfor e in range(epochs):# 前向传播计算得到结果output = model.forward(X_train)# 跟据预测结果与标签计算损失los = loss(output, y_train)# 梯度清零 Pytorch自动积累梯度 因此每次迭代都需要清零梯度optimizer.zero_grad()# 开始反向传播los.backward()# 更次年模型参数optimizer.step()# 记录损失losses.append(los.item())if (e + 1) % 10 == 0:print(f'Epoch: [{e + 1}/{epochs}], Loss: {los.item():.4f}')# 训练结果可视化plt.plot(range(epochs), losses)plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Training Loss Over All Epochs')plt.show()

执行训练train(20000)

输出为
在这里插入图片描述

在这里插入图片描述
@浙大疏锦行

http://www.lryc.cn/news/611869.html

相关文章:

  • 《深入浅出Embedding》这本书
  • 【LeetCode 热题 100】347. 前 K 个高频元素——(解法三)桶排序
  • 深入理解C++中的stack、queue和priority_queue
  • 【docker】namespace 命名空间
  • LangChain4j检索增强生成RAG
  • Anthropic于本周一推出了其旗舰模型的升级版Claude Opus 4.1
  • 第十八天:C++进制之间的转换
  • 17.9 ChatGLM3-6B开源!32K长文本+推理提速45%,多任务性能飙升29.4%
  • Transwell 细胞迁移与侵袭实验:从原理到操作的详细指南
  • VSCode:基础使用 / 使用积累
  • QML开发:QML中的基本元素
  • 大数据之Flume
  • AT32的freertos下modbus TCP移植
  • #C语言——学习攻略:探索内存函数--memcpy、memmove的使用和模拟实现,memset、memcmp函数的使用
  • flex布局:容器的justify-content属性
  • CEH、OSCP、CISP、CISSP 四大网络安全认证攻略
  • 【hot100】无重复字符的最长子串-Python3
  • duiLib 编译时复制资源目录到exe同级目录
  • 推动本地流智能:基于 Apache Kafka 与 Flink 的实时机器学习实践
  • 无需SCADA/OPC,实现直接与西门子PLC Web API通讯实现数据读写(一)
  • Mysql如何迁移数据库数据
  • 【自动驾驶】《Sparse4Dv3 Advancing End-to-End 3D Detection and Tracking》论文阅读笔记
  • 工业协议转换终极武器:EtherCAT转PROFINET网关的连接举例
  • Spring Boot全局异常处理与日志监控实战指南
  • 从Navisworks到定制化BIM系统:HOOPS Exchange如何实现高效3D格式解析?
  • 【公考】----申论篇
  • 测试单节点elasticsearch配置存储压缩后的比率
  • 20250806给PRO-RK3566开发板在Buildroot系统下扩大rootfs分区2GB
  • 移动端网页调试实战,跨设备兼容与触控交互问题排查全流程
  • Class30数据增广