当前位置: 首页 > article >正文

(4)pokeman_用图片对模型进行测试

1、用图片对模型进行测试

编写的测试函数如下:

"""函数说明: 根据训练结果对模型进行测试:param    img_test: 待测试的图片:return:  y: 测试结果,分类序号"""
def model_test_img(model,img_test):model.eval()img = Image.open(img_test).convert('RGB') # resize = transforms.Resize([224,224])x = transforms.Resize([img_resize,img_resize])(img)x = transforms.ToTensor()(x)x = x.to(device)x = x.unsqueeze(0)x = transforms.Normalize(mean,std)(x)# print(x.shape)with torch.no_grad():logits = model(x)pred = logits.argmax(dim=1)return pred

2、自己的resnet18模型进行测试

参考第二篇,自己的resnet18

(2)pokeman分类的例子_chencaw的博客-CSDN博客

代码如下

import  torch
from    torch import optim, nnfrom    torch.utils.data import DataLoader
from    torchvision import transforms,datasetsfrom    resnet import ResNet18
from    PIL import Imageimg_resize = 224
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]# device = torch.device('cuda')
device = torch.device('cpu')"""函数说明: 根据训练结果对模型进行测试:param    img_test: 待测试的图片:return:  y: 测试结果,分类序号"""
def model_test_img(model,img_test):model.eval()img = Image.open(img_test).convert('RGB') # resize = transforms.Resize([224,224])x = transforms.Resize([img_resize,img_resize])(img)x = transforms.ToTensor()(x)x = x.to(device)x = x.unsqueeze(0)x = transforms.Normalize(mean,std)(x)# print(x.shape)with torch.no_grad():logits = model(x)pred = logits.argmax(dim=1)return preddef main():#(1)如用anaconda激活你自己的环境# conda env list# conda activate chentorch_cp310#分类名称class_name = ['bulbasaur', 'charmander', 'mewtwo', 'pikachu', 'squirtle']# image_file = "D:/pytorch_learning2022/data/pokeman/train/bulbasaur/00000002.jpg"image_file = "D:/pytorch_learning2022/data/pokeman/test/pikachu/00000117.jpg"model = ResNet18(5).to(device)model.load_state_dict(torch.load('best_scratch.mdl'))y = model_test_img(model,image_file)print(y)print("detect result is: ",class_name[y])if __name__ == '__main__':main()

3、迁移学习的resnet18模型进行测试

(1)将图片的转换封装为一个transforms更好

tf = transforms.Compose([#匿名函数lambda x:Image.open(x).convert('RGB'), # string path= > image datatransforms.Resize((img_resize, img_resize)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])

(2)调用方式也很简单

  x =tf(img_test)

(3)完整的代码实现

import  torch
from    torch import optim, nnfrom    torch.utils.data import DataLoader
from    torchvision import transforms,datasets# from    resnet import ResNet18
from    torchvision.models import resnet18
from    PIL import Image
# from    utils import Flattenimg_resize = 224
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]# device = torch.device('cuda')
device = torch.device('cpu')tf = transforms.Compose([#匿名函数lambda x:Image.open(x).convert('RGB'), # string path= > image datatransforms.Resize((img_resize, img_resize)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])"""函数说明: 根据训练结果对模型进行测试:param    img_test: 待测试的图片:return:  y: 测试结果,分类序号"""
def model_test_img(model,img_test):model.eval()# img = Image.open(img_test).convert('RGB') # resize = transforms.Resize([224,224])# x = transforms.Resize([img_resize,img_resize])(img)# x = transforms.ToTensor()(x)x =tf(img_test)x = x.to(device)x = x.unsqueeze(0)#x = transforms.Normalize(mean,std)(x)  #上次不小心忘了,chen20221104# print(x.shape)with torch.no_grad():logits = model(x)pred = logits.argmax(dim=1)return preddef main():#(1)如用anaconda激活你自己的环境# conda env list# conda activate chentorch_cp310#分类名称class_name = ['bulbasaur', 'charmander', 'mewtwo', 'pikachu', 'squirtle']# image_file = "D:/pytorch_learning2022/data/pokeman/train/bulbasaur/00000002.jpg"image_file = "D:/pytorch_learning2022/data/pokeman/test/pikachu/00000117.jpg"trained_model = resnet18(pretrained=True)model = nn.Sequential(*list(trained_model.children())[:-1], #[b, 512, 1, 1]nn.Flatten(),#   Flatten(), # [b, 512, 1, 1] => [b, 512]nn.Linear(512, 5)).to(device)model.load_state_dict(torch.load('best_transfer.mdl'))y = model_test_img(model,image_file)print(y)print("detect result is: ",class_name[y])if __name__ == '__main__':main()

http://www.lryc.cn/news/2419750.html

相关文章:

  • 什么是TTL电平,什么是CMOS电平
  • “boost::get_property的用法示例“:使用Boost库的get_property方法可以方便地获取C++对象的属性值
  • sockaddr和sockaddr_in结构体、以及inet_ntoa()和inet_addr()函数的用法
  • rownum,row_number區別。 执行顺序
  • 最新BIOS设置中英文对照表
  • P2P原理与实践
  • erpc的设计和工作机制
  • MD5:介绍与应用
  • Win10 VC++6 无法启动此程序,因为计算机中丢失mfc42d.dll 需要提升
  • Vim的全面配置
  • 谈安全测试的重要性
  • Oracle 视图详解
  • 浅谈快速沃尔什变换(FWT)快速莫比乌斯变换(FMT)
  • Android 二级列表控件ExpandableListView 的简单使用
  • FlashFXP的使用
  • stm32平衡小车--(1)JGB-520减速电机+tb6612(附测试代码)
  • Linux磁盘配额(EXT4XFS)
  • html简单网页代码:HTML+CSS茶叶官网网页设计实例 企业网站制作
  • Red5 流媒体技术(初级了解)
  • VRRP原理和配置
  • case when的使用方法
  • 探秘Proton:统一的实时数据分析引擎
  • 不能通过“www.baidu.com”访问百度解决方法
  • Nginx 简单的负载均衡配置示例
  • portlet示例_Java Portlet示例教程
  • C#让程序运行更稳健——异常、调试和测试(代码没看懂)
  • 探索数据的新型画布 - OrientDB Studio 深度解析与应用
  • Editplus如何设置中文页面
  • JAVA开发基础-XML
  • 查看电脑内存个数、主频(工作频率)、容量、位宽等的方法总结