(4)pokeman_用图片对模型进行测试
1、用图片对模型进行测试
编写的测试函数如下:
"""函数说明: 根据训练结果对模型进行测试:param img_test: 待测试的图片:return: y: 测试结果,分类序号"""
def model_test_img(model,img_test):model.eval()img = Image.open(img_test).convert('RGB') # resize = transforms.Resize([224,224])x = transforms.Resize([img_resize,img_resize])(img)x = transforms.ToTensor()(x)x = x.to(device)x = x.unsqueeze(0)x = transforms.Normalize(mean,std)(x)# print(x.shape)with torch.no_grad():logits = model(x)pred = logits.argmax(dim=1)return pred
2、用自己实现的resnet18模型进行测试
参考第二篇中自己实现的resnet18:
(2)pokeman分类的例子_chencaw的博客-CSDN博客
代码如下
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from resnet import ResNet18
from PIL import Image

# Input image size and per-channel normalization statistics (the standard
# ImageNet mean/std); presumably identical to the values used in training.
img_resize = 224
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# device = torch.device('cuda')
device = torch.device('cpu')


def model_test_img(model, img_test):
    """Classify a single image with a trained model.

    Relies on the module-level settings ``img_resize``, ``mean``, ``std``
    and ``device``.

    :param model: trained classification network; it is put in eval mode here.
    :param img_test: path of the image file to classify.
    :return: tensor holding the predicted class index for the single image
        (shape (1,), assuming the model returns [1, num_classes] logits).
    """
    model.eval()
    # Close the file handle as soon as the pixels are decoded; convert()
    # returns an independent image, so it stays valid after the file closes.
    with Image.open(img_test) as im:
        img = im.convert('RGB')
    x = transforms.Resize([img_resize, img_resize])(img)
    x = transforms.ToTensor()(x)
    x = x.to(device)
    x = x.unsqueeze(0)  # add a batch dimension of size 1
    x = transforms.Normalize(mean, std)(x)
    with torch.no_grad():  # inference only: no gradient bookkeeping needed
        logits = model(x)
    pred = logits.argmax(dim=1)
    return pred


def main():
    """Load the from-scratch ResNet18 checkpoint and classify one test image."""
    # (1) If you use Anaconda, activate your own environment first:
    # conda env list
    # conda activate chentorch_cp310

    # Class names; the order presumably matches the label indices used in training.
    class_name = ['bulbasaur', 'charmander', 'mewtwo', 'pikachu', 'squirtle']
    # image_file = "D:/pytorch_learning2022/data/pokeman/train/bulbasaur/00000002.jpg"
    image_file = "D:/pytorch_learning2022/data/pokeman/test/pikachu/00000117.jpg"

    model = ResNet18(5).to(device)
    # map_location lets a checkpoint saved from a GPU be loaded onto `device`.
    model.load_state_dict(torch.load('best_scratch.mdl', map_location=device))

    y = model_test_img(model, image_file)
    print(y)
    # y is a one-element tensor; turn it into a plain int before indexing.
    print("detect result is: ", class_name[y.item()])


if __name__ == '__main__':
    main()
3、用迁移学习的resnet18模型进行测试
(1)更好的做法是把图片的读取和预处理封装成一个transforms.Compose
def _load_rgb(path):
    """Open the image file at *path* and return it as an RGB PIL image.

    Used instead of a lambda: the file handle is closed right away, and a
    module-level function (unlike a lambda) can be pickled, e.g. if this
    pipeline is ever used by DataLoader worker processes.
    """
    with Image.open(path) as im:
        return im.convert('RGB')


# Full preprocessing pipeline: image path -> resized, normalized tensor.
tf = transforms.Compose([
    _load_rgb,  # string path => image data
    transforms.Resize((img_resize, img_resize)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
(2)调用方式也很简单:
# Pass the image file path straight to tf: its first step opens the file,
# and the remaining steps resize, convert to a tensor and normalize.
x =tf(img_test)
(3)完整的代码实现
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
# from resnet import ResNet18
from torchvision.models import resnet18
from PIL import Image
# from utils import Flatten

# Input image size and per-channel normalization statistics (the standard
# ImageNet mean/std); presumably identical to the values used in training.
img_resize = 224
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# device = torch.device('cuda')
device = torch.device('cpu')


def _load_rgb(path):
    """Open the image file at *path* and return it as an RGB PIL image.

    Used instead of a lambda: the file handle is closed right away, and a
    module-level function (unlike a lambda) can be pickled, e.g. if this
    pipeline is ever used by DataLoader worker processes.
    """
    with Image.open(path) as im:
        return im.convert('RGB')


# Full preprocessing pipeline: image path -> resized, normalized tensor.
tf = transforms.Compose([
    _load_rgb,  # string path => image data
    transforms.Resize((img_resize, img_resize)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])


def model_test_img(model, img_test):
    """Classify a single image with a trained model.

    :param model: trained classification network; it is put in eval mode here.
    :param img_test: path of the image file to classify. ``tf`` loads and
        preprocesses it, including normalization, so normalization must not
        be applied a second time here.
    :return: tensor holding the predicted class index for the single image
        (shape (1,), assuming the model returns [1, num_classes] logits).
    """
    model.eval()
    x = tf(img_test)
    x = x.to(device)
    x = x.unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():  # inference only: no gradient bookkeeping needed
        logits = model(x)
    pred = logits.argmax(dim=1)
    return pred


def main():
    """Load the transfer-learning ResNet18 checkpoint and classify one test image."""
    # (1) If you use Anaconda, activate your own environment first:
    # conda env list
    # conda activate chentorch_cp310

    # Class names; the order presumably matches the label indices used in training.
    class_name = ['bulbasaur', 'charmander', 'mewtwo', 'pikachu', 'squirtle']
    # image_file = "D:/pytorch_learning2022/data/pokeman/train/bulbasaur/00000002.jpg"
    image_file = "D:/pytorch_learning2022/data/pokeman/test/pikachu/00000117.jpg"

    # Rebuild the transfer-learning architecture: the resnet18 layers without
    # the final fc layer, then a flatten and a new 5-way classifier.
    # Pretrained ImageNet weights are not needed here: load_state_dict below
    # (strict by default) overwrites every parameter, so skip the download.
    trained_model = resnet18()
    model = nn.Sequential(
        *list(trained_model.children())[:-1],  # [b, 512, 1, 1]
        nn.Flatten(),                          # [b, 512, 1, 1] => [b, 512]
        nn.Linear(512, 5),
    ).to(device)
    # map_location lets a checkpoint saved from a GPU be loaded onto `device`.
    model.load_state_dict(torch.load('best_transfer.mdl', map_location=device))

    y = model_test_img(model, image_file)
    print(y)
    # y is a one-element tensor; turn it into a plain int before indexing.
    print("detect result is: ", class_name[y.item()])


if __name__ == '__main__':
    main()