当前位置: 首页 > news >正文

学习pytorch18 pytorch完整的模型训练流程

pytorch完整的模型训练流程

  • 1. 流程
    • 1. 整理训练数据 使用CIFAR10数据集
    • 2. 搭建网络结构
    • 3. 构建损失函数
    • 4. 使用优化器
    • 5. 训练模型
    • 6. 测试数据 计算模型预测正确率
    • 7. 保存模型
  • 2. 代码
    • 1. model.py
    • 2. train.py
  • 3. 结果
    • tensorboard结果
      • 以下图片 颜色较浅的线是真实计算的值,颜色较深的线是做了平滑处理的值
      • 训练loss
      • 测试loss
      • 测试集正确率
  • 4. 需要注意的细节

1. 流程

1. 整理训练数据 使用CIFAR10数据集

train_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),download=True)

2. 搭建网络结构

在这里插入图片描述
model.py

3. 构建损失函数

loss_fn = nn.CrossEntropyLoss()

4. 使用优化器

learing_rate = 1e-2 # 0.01
optimizer = torch.optim.SGD(net.parameters(), lr=learing_rate)

5. 训练模型

output = net(imgs)    # 数据输入模型
loss = loss_fn(output, targets)  # 损失函数计算损失 看计算的输出和真实的标签误差是多少
# 优化器开始优化模型  1.梯度清零  2.反向传播  3.参数优化
optimizer.zero_grad()  # 利用优化器把梯度清零 全部设置为0
loss.backward()        # 设置计算的损失值的钩子,调用损失的反向传播,计算每个参数结点的参数
optimizer.step()       # 调用优化器的step()方法 对其中的参数进行优化  

6. 测试数据 计算模型预测正确率

output = net(imags)
# 计算测试集的正确率
preds = (output.argmax(1)==targets).sum()
accuracy += preds 
rate = accuracy/len(test_data)

调用模型输出tensor 数据类型的 argmax方法, argmax或获取一行或者一列数值中最大数值的下标位置,argmax(0) 是从列的维度取一列数值的最大值的下标,argmax(1) 是从行的维度取一行数值的最大值的下标
output.argmax(1)==targets 会输出如下图最后一行 [false, ture], 对应位置相同则为true,对应位置不同则为false;
调用sum()方法,计算求和,false值为0,true值为1.
最后计算得出测试集整体正确率: rate = accuracy/len(test_data)
在这里插入图片描述

7. 保存模型

torch.save(net, './net_epoch{}.pth'.format(i))

2. 代码

1. model.py

import torch
from torch import nn# 2. 搭建模型网络结构--神经网络
class Cifar10Net(nn.Module):def __init__(self):super(Cifar10Net, self).__init__()self.net = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(kernel_size=2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(kernel_size=2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self, x):x = self.net(x)return xif __name__ == '__main__':net = Cifar10Net()input = torch.ones((64, 3, 32, 32))output = net(input)print(output.shape)

2. train.py

import torch
import torchvision
from torch import nn
from torch.utils.tensorboard import SummaryWriterfrom p24_model import *# 1. 准备数据集
# 训练数据
from torch.utils.data import DataLoadertrain_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),download=True)
# 测试数据
test_data = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),download=True)# 查看数据大小--size
print("训练数据集大小:", len(train_data))
print("测试数据集大小:", len(test_data))
# 利用DataLoader来加载数据集
train_loader = DataLoader(dataset=train_data, batch_size=64)
test_loader = DataLoader(dataset=test_data, batch_size=64)# 2. 导入模型结构 创建模型
net = Cifar10Net()# 3. 创建损失函数  分类问题--交叉熵
loss_fn = nn.CrossEntropyLoss()# 4. 创建优化器
# learing_rate = 0.01
# 1e-2 = 1 * 10^(-2) = 0.01
learing_rate = 1e-2
print(learing_rate)
optimizer = torch.optim.SGD(net.parameters(), lr=learing_rate)# 设置训练网络的一些参数
epoch = 10   # 记录训练的轮数
total_train_step = 0  # 记录训练的次数
total_test_step = 0   # 记录测试的次数# 利用tensorboard显示训练loss趋势
writer = SummaryWriter('./train_logs')for i in range(epoch):# 训练步骤开始net.train()  # 可以加可以不加  只有当模型结构有 Dropout BatchNorml层才会起作用for data in train_loader:imgs, targets = data  # 获取数据output = net(imgs)    # 数据输入模型loss = loss_fn(output, targets)  # 损失函数计算损失 看计算的输出和真实的标签误差是多少# 优化器开始优化模型  1.梯度清零  2.反向传播  3.参数优化optimizer.zero_grad()  # 利用优化器把梯度清零 全部设置为0loss.backward()        # 设置计算的损失值,调用损失的反向传播,计算每个参数结点的参数optimizer.step()       # 调用优化器的step()方法 对其中的参数进行优化# 优化一次 认为训练了一次total_train_step += 1if total_train_step % 100 == 0:print('训练次数: {}   loss: {}'.format(total_train_step, loss))# 直接打印loss是tensor数据类型,打印loss.item()是打印的int或float真实数值, 真实数值方便做数据可视化【损失可视化】# print('训练次数: {}   loss: {}'.format(total_train_step, loss.item()))writer.add_scalar('train-loss', loss.item(), global_step=total_train_step)# 利用现有模型做模型测试# 测试步骤开始total_test_loss = 0accuracy = 0net.eval()  # 可以加可以不加  只有当模型结构有 Dropout BatchNorml层才会起作用with torch.no_grad():for data in test_loader:imags, targets = dataoutput = net(imags)loss = loss_fn(output, targets)total_test_loss += loss.item()# 计算测试集的正确率preds = (output.argmax(1)==targets).sum()accuracy += preds# writer.add_scalar('test-loss', total_test_loss, global_step=i+1)writer.add_scalar('test-loss', total_test_loss, global_step=total_test_step)writer.add_scalar('test-accracy', accuracy/len(test_data), total_test_step)total_test_step += 1print("---------test loss: {}--------------".format(total_test_loss))print("---------test accuracy: {}--------------".format(accuracy))# 保存每一个epoch训练得到的模型torch.save(net, './net_epoch{}.pth'.format(i))writer.close()

3. 结果

训练数据集大小: 50000
测试数据集大小: 10000
0.01
训练次数: 100   loss: 2.2905373573303223
训练次数: 200   loss: 2.2878968715667725
训练次数: 300   loss: 2.258394718170166
训练次数: 400   loss: 2.1968581676483154
训练次数: 500   loss: 2.0476632118225098
训练次数: 600   loss: 2.002145767211914
训练次数: 700   loss: 2.016021728515625
---------test loss: 316.382279753685--------------
训练次数: 800   loss: 1.8957302570343018
训练次数: 900   loss: 1.8659226894378662
训练次数: 1000   loss: 1.9004186391830444
训练次数: 1100   loss: 1.9708642959594727
......

tensorboard结果

安装tensorboard运行环境

pip install tensorboard
pip install opencv-python
pip install six
tensorboard --logdir=train_logs

以下图片 颜色较浅的线是真实计算的值,颜色较深的线是做了平滑处理的值

训练loss

在这里插入图片描述

测试loss

在这里插入图片描述

测试集正确率

在这里插入图片描述

4. 需要注意的细节

https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module

所有网络层继承于torch.nn.Module, net.train() net.eval() 在模型训练或测试之初 可以加可以不加 只有当模型结构有 Dropout BatchNorml层才会起作用,当模型有这两个网络层的时候,两个代码需要加上。
在这里插入图片描述

在这里插入图片描述

http://www.lryc.cn/news/255943.html

相关文章:

  • 电子学会C/C++编程等级考试2021年09月(五级)真题解析
  • Halcon联合winform显示以及处理
  • 【设计模式-4.3】行为型——责任链模式
  • 单片机语言--C51语言的数据类型以及存储类型以及一些基本运算
  • 《每天一个Linux命令》 -- (5)通过sshkey密钥登录服务器
  • kubernetes的服务发现(二)
  • 【矩阵论】Chapter 4—特征值和特征向量知识点总结复习
  • Linux 进程地址空间
  • websocket vue操作
  • 腾讯云CentOS8 jenkins war安装jenkins步骤文档
  • Linux: glibc: net/if.h vs linux/if.h
  • 使用Android Studio导入Android源码:基于全志H713 AOSP,方便解决编译、编码问题
  • python random详解
  • java-两个列表进行比较,判断那些是需要新增的、删除的、和更新的
  • 【WPF.NET开发】WPF中的对话框
  • NLP项目实战01之电影评论分类
  • 一款可无限扩展的软件定时器开源框架项目代码
  • GRE与顺丰圆通快递盒子
  • 12.Mysql 多表数据横向合并和纵向合并
  • 线性回归与逻辑回归:深入解析机器学习的基石模型
  • 电脑待机怎么设置?让你的电脑更加节能
  • 数据库对象介绍与实践:视图、函数、存储过程、触发器和物化视图
  • arm平台编译so文件回顾
  • 【数据结构】顺序表的定义和运算
  • idea使用maven的package打包时提示“找不到符号”或“找不到包”
  • MetricBeat监控MySQL
  • Child Mind Institute - Detect Sleep States(2023年第一次Kaggle拿到了银牌总结)
  • Esxi7Esxi8设置VMFSL虚拟闪存的大小
  • vue2+electron桌面端一体机应用
  • 目标检测——OverFeat算法解读