当前位置: 首页 > news >正文

PyTorch学习笔记(十七)——完整的模型验证(测试,demo)套路

完整代码:

import torch
import torchvision
from PIL import Image
from torch import nnimage_path = "../imgs/dog.png"
image = Image.open(image_path)
print(image)# 因为png格式是四个通道,除了RGB三通道外,还有一个透明度通道
image = image.convert("RGB")
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),torchvision.transforms.ToTensor()])
image = transform(image)
print(image.shape)class MyNN(nn.Module):def __init__(self):super(MyNN, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return xmodel = torch.load("mynn_0.pth")
print(model)image = torch.reshape(image,(1,3,32,32))
model.eval()with torch.no_grad():output = model(image.cuda())
print(output)
print(output.argmax(1))

 采用GPU训练的模型,两种方法

(1)在CPU上加载,要从GPU映射到CPU,即把model = torch.load("mynn_9.pth")改为:

model = torch.load("mynn_9.pth",map_location=torch.device('cpu'))

(2)将image转到GPU中,即把output = model(image)改为:

output = model(image.cuda())

 

 预测错误的原因可能是训练次数不够多

 改成:

model = torch.load("mynn_9.pth")

 

 

 

 

http://www.lryc.cn/news/136393.html

相关文章:

  • WPF开篇
  • linux 压缩解压缩
  • centos9 mysql8修改数据库的存储路径
  • 【C++】<Windows编程中消息即事件的处理>
  • 数据库SQL语句使用
  • 从零开始 Spring Cloud 12:Sentinel
  • @Resurce和@Autowired的区别
  • ResNet简介
  • 了解单例模式,工厂模式(简单易懂)
  • 【中危】 Apache NiFi 连接 URL 验证绕过漏洞 (CVE-2023-40037)
  • 【Git版本控制工具使用---讲解一】
  • NLP | 基于LLMs的文本分类任务
  • 攻防世界-base÷4
  • 【Java转Go】快速上手学习笔记(三)之基础篇二
  • 【vue 引入pinia与pinia的详细使用】
  • USACO18DEC Fine Dining G
  • fckeditor编辑器的两种使用方法
  • 数据结构,查找算法(二分,分块,哈希)
  • C++(Qt)软件调试---gdb调试入门用法(12)
  • shell和Python 两种方法分别画 iostat的监控图
  • 设计模式(9)建造者模式
  • PHP 创业感悟交流平台系统mysql数据库web结构apache计算机软件工程网页wamp
  • 工作流程引擎之flowable(集成springboot)
  • leetcode54. 螺旋矩阵(java)
  • go gorm 查询
  • Flutter GetXController 动态Tabbar 报错问题
  • Redis(缓存预热,缓存雪崩,缓存击穿,缓存穿透)
  • UE4/5Niagara粒子特效学习(使用UE5.1,适合新手)
  • from moduleA import * 语句 和import moduleA 的区别
  • 【leetcode 力扣刷题】交换链表中的节点