当前位置: 首页 > news >正文

PyTorch快速入门教程【小土堆】之完整模型训练套路

视频地址完整的模型训练套路(一)_哔哩哔哩_bilibili

import torch
import torchvision
from model import *
from torch import nn
from torch.utils.data import DataLoader# 准备数据集
train_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=False, transform=torchvision.transforms.ToTensor(),download=True)
# Length 长度
train_data_size = len(train_data)
test_data_size = len(test_data)
# 如果train_data_size=10,训练数据集的长度为:10
# print("训练数据集的长度为: {}".format(train_data_size))
# print("测试数据集的长度为: {}".format(test_data_size))# 利用 DataLoader 来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# 创建网络模型
tudui = Tudui()# 损失函数
loss_fn = nn.CrossEntropyLoss()# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)# 设置训练网络的一些参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10for i in range(epoch):print("--------第{}轮训练开始---------".format(i + 1))# 训练步骤开始for data in train_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)# 优化器优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 == 0:print("训练次数:{},Loss:{}".format(total_train_step, loss.item()))# 测试步骤开始total_test_loss =0total_accuracy = 0with torch.no_grad():# 保证不会调优for data in test_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracyprint("整体测试集上的Loss: {}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy / test_data_size))torch.save(tudui, "tudui_{}.pth".format(i))print("模型已保存")
http://www.lryc.cn/news/514033.html

相关文章:

  • 【AIGC】 ChatGPT实战教程:如何高效撰写学术论文引言
  • TTL 传输中过期问题定位
  • 非docker方式部署openwebui过程记录
  • 大模型的prompt的应用二
  • ubuntu 22.04安装ollama
  • 从企业级 RAG 到 AI Assistant,阿里云 Elasticsearch AI 搜索技术实践
  • Redis--高可用(主从复制、哨兵模式、分片集群)
  • 框架(Mybatis配置日志)
  • 人工智能-Python上下文管理器-with
  • 每天40分玩转Django:Django类视图
  • 自动化测试之Pytest框架(万字详解)
  • 基于51单片机(STC32G12K128)和8X8彩色点阵屏(WS2812B驱动)的小游戏《贪吃蛇》
  • 2011-2020年各省粗离婚率数据
  • C++高级编程技巧:模板元编程与性能优化实践
  • Mac 版本向日葵退出登录账号
  • SOLIDWORKS Composer在产品设计、制造与销售中的应用
  • Win11+WLS Ubuntu 鸿蒙开发环境搭建(一)
  • [CSAW/网络安全] Git泄露+命令执行 攻防世界 mfw 解题详析
  • MySQL 锁那些事
  • Linux中常用的基本指令和一些配套的周边知识详解
  • 深入理解Java中的Set集合:特性、用法与常见操作指南
  • Oracle 使用 sql profile 固定执行计划
  • 数字电路期末复习
  • 正则表达式 - 使用总结
  • 通过Xshell远程连接wsl2
  • 【ubuntu】安装OpenSSH服务器
  • CESS 的 2024:赋能 AI,塑造去中心化数据基础
  • Redission红锁
  • 使用 CSS 的 `::selection` 伪元素来改变 HTML 文本选中时的背景颜色
  • Spring Boot AOP日志打印实现