当前位置: 首页 > news >正文

pytorch-01

加载mnist数据集

one-hot编码实现

import numpy as np
import torch
x_train = np.load("../dataset/mnist/x_train.npy") # 从网站提前下载数据集,并解压缩
y_train_label = np.load("../dataset/mnist/y_train_label.npy")
x = torch.tensor(y_train_label[:5],dtype=torch.int64)  # 获取前5个样本的标签数据
# 定义一个张量输入,因为此时有 5 个数值,且最大值为9,类别数为10
# 所以我们可以得到 y 的输出结果的形状为 shape=(5,10),即5行12列
y = torch.nn.functional.one_hot(x, 10)  # 一个参数张量x,10为类别数
print(y)

对于拥有6000个样本的MNIST数据集来说,标签就是一个6000\times 10大小的矩阵张量。

多层感知机模型

#设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = torch.nn.Flatten()  # 拉平图像矩阵self.linear_relu_stack = torch.nn.Sequential(torch.nn.Linear(28*28,312),   # 输入大小为28*28,输出大小为312维的线性变换层torch.nn.ReLU(),   # 激活函数层torch.nn.Linear(312, 256),torch.nn.ReLU(),torch.nn.Linear(256, 10)  # 最终输出大小为10,对应one-hot标签维度)def forward(self, input):   # 构建网络x = self.flatten(input)  #拉平矩阵为1维logits = self.linear_relu_stack(x) # 多层感知机return logits

损失函数

优化函数

model = NeuralNetwork()
loss_fu = torch.nn.CrossEntropyLoss() # 交叉熵损失函数,内置了softmax函数,
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #设定优化函数loss = loss_fu(pred,label_batch)  # 计算损失

完整模型

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #指定GPU编
import torch
import numpy as npbatch_size = 320                        #设定每次训练的批次数
epochs = 1024                           #设定训练次数#device = "cpu"                         #Pytorch的特性,需要指定计算的硬件,如果没有GPU的存在,就使用CPU进行计算
device = "cuda"                         #在这里读者默认使用GPU,如果读者出现运行问题可以将其改成cpu模式#设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = torch.nn.Flatten()self.linear_relu_stack = torch.nn.Sequential(torch.nn.Linear(28*28,312),torch.nn.ReLU(),torch.nn.Linear(312, 256),torch.nn.ReLU(),torch.nn.Linear(256, 10))def forward(self, input):x = self.flatten(input)logits = self.linear_relu_stack(x)return logitsmodel = NeuralNetwork()
model = model.to(device)                #将计算模型传入GPU硬件等待计算
torch.save(model, './model.pth')
#model = torch.compile(model)            #Pytorch2.0的特性,加速计算速度
loss_fu = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #设定优化函数#载入数据
x_train = np.load("../../dataset/mnist/x_train.npy")
y_train_label = np.load("../../dataset/mnist/y_train_label.npy")train_num = len(x_train)//batch_size#开始计算
for epoch in range(20):train_loss = 0for i in range(train_num):start = i * batch_sizeend = (i + 1) * batch_sizetrain_batch = torch.tensor(x_train[start:end]).to(device)label_batch = torch.tensor(y_train_label[start:end]).to(device)pred = model(train_batch)loss = loss_fu(pred,label_batch)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()  # 记录每个批次的损失值# 计算并打印损失值train_loss /= train_numaccuracy = (pred.argmax(1) == label_batch).type(torch.float32).sum().item() / batch_sizeprint("epoch:",epoch,"train_loss:", round(train_loss,2),"accuracy:",round(accuracy,2))

可视化模型结构和参数

model = NeuralNetwork()
print(model)

是对模型具体使用的函数及其对应的参数进行打印。

格式化显示:

param = list(model.parameters())
k=0
for i in param:l = 1print('该层结构:'+str(list(i.size())))for j in i.size():l*=jprint('该层参数和:'+str(l))k = k+l
print("总参数量:"+str(k))

模型保存

model = NeuralNetwork()
torch.save(model, './model.pth')

netron可视化

安装:pip install netron

运行:命令行输入netron

打开:通过网址http://localhost:8080打开

打开保存的模型文件model.pth:

 

 点击颜色块,可以显示详细信息:

http://www.lryc.cn/news/386707.html

相关文章:

  • 梦想CAD二次开发
  • Eureka的介绍与使用
  • ChatGPT之母:AI自动化将取代人类,创意性工作或将消失
  • 【深度学习驱动流体力学】湍流仿真到深度学习湍流预测
  • 如何从0构建一款类似pytest的工具
  • 6.27-6.29 旧c语言
  • Unidbg调用-补环境V3-Hook
  • 从AICore到TensorCore:华为910B与NVIDIA A100全面分析
  • Edge 浏览器退出后,后台占用问题
  • 实验八 T_SQL编程
  • 【爆肝34万字】从零开始学Python第2天: 判断语句【入门到放弃】
  • React 19 新特性集合
  • 耐高温水位传感器有哪些
  • Symfony国际化与本地化:打造多语言应用的秘诀
  • ApolloClient GraphQL 与 ReactNative
  • 【贡献法】2262. 字符串的总引力
  • C#基于SkiaSharp实现印章管理(3)
  • 如何理解泛型的编译期检查
  • 计算机组成原理:海明校验
  • 信息学奥赛初赛天天练-39-CSP-J2021基础题-哈夫曼树、哈夫曼编码、贪心算法、满二叉树、完全二叉树、前中后缀表达式转换
  • 第11章 规划过程组(收集需求)
  • 探索WebKit的守护神:深入Web安全策略
  • unity ScrollRect裁剪ParticleSystem粒子
  • 凤仪亭 | 第7集 | 大丈夫生居天地之间,岂能郁郁久居人下 | 司徒一言,令我拨云见日,茅塞顿开 | 三国演义 | 逐鹿群雄
  • React实战学习(一)_棋盘设计
  • 【LeetCode】每日一题:三数之和
  • 逆风而行:提升逆商,让困难成为你前进的动力
  • 新能源汽车CAN总线故障定位与干扰排除的几个方法
  • 【涵子来信】——社交宝典:克服你心中的内向,世界总有缺陷
  • LabVIEW项目外协时选择公司与个人兼职的比较