当前位置: 首页 > news >正文

Pytorch搭建AlexNet 预测实现

1.导包

import torch
import matplotlib.pyplot as plt
import json
from model import AlexNet
from PIL import Image
from torchvision import transforms

2.数据预处理

data_transform = transforms.Compose([transforms.Resize((224, 224)),  # 将图片重新裁剪transforms.ToTensor(),  # 转化为tensortransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # 标准化数据

 3.加载测试图片

# load image
img = Image.open("1.jpeg")  # 网上随便下载,放到好找的路径下
plt.imshow(img)   # 直接载入图像
img = data_transform(img)  在预处理过程中吧channel提到前面
img = torch.unsqueeze(img, dim=0)  # 添加batch维度

4.读取分类文件

# read class_indent
try:# 读取保存在json文件中索引对应的类别名称json_file = open('./class_indices,json', 'r')class_indict = json.load(json_file)  # 将json文件解码成字典格式
except Exception as e:print(e)exit(-1)

5.初始化网络

output = torch.squeeze(model(img)):先将图片通过正向传播得到输出,再把输出的batch压缩

predict = torch.softmax(output, dim=0):通过softmax得到一个概率分布

predict_cla = torch.argmax(predict).numpy():找到概率最大处所对应的索引值

print将类别名称和预测概率输出

# create model
model = AlexNet(num_classes=5)
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))  # 载入网络模型
model.eval()  # 关闭dropout
with torch.no_grad():output = torch.squeeze(model(img))predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()

 6.预测结果

容易把玫瑰识别成郁金香,把蒲公英识别成向日葵,郁金香,向日葵,小雏菊可以很好的识别出来,模型的准确率还是有点低。大家自己尝试测试一下吧哈哈。

 PyTorch搭建AlexNet网络合集:
PyTorch搭建AlexNet网络模型-CSDN博客

PyTorch搭建AlexNet训练集-CSDN博客

Pytorch搭建AlexNet 预测实现-CSDN博客

http://www.lryc.cn/news/316687.html

相关文章:

  • 笔记:使用parfile进行的数据导入导出
  • 基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的行人跌倒检测系统(深度学习+UI界面+完整训练数据集)
  • Ubuntu 14.04:PaddleOCR基于PaddleServing的在线服务化部署(失败)
  • Java JUC 笔记(2)
  • webpack5高级--02_提升打包构建速度
  • MAC M芯片 Anaconda安装
  • 【JS】自动下拉网页刷新,当出现指定关键字,就打印出来
  • 中兴通讯联手新疆移动,开通全疆首个乡农场景700M+900M双频双模基站
  • 爬虫案例4: parsel 模块的运用
  • 数据结构·复杂度
  • 数学建模理论与实践国防科大版
  • Yakit爆破模块应用
  • 【3GPP】【核心网】【5G】NAS连接管理和UE注册管理状态(超详细)
  • 细粒度IP定位参文2(Corr-SLG):A street-level IP geolocation method (2021年)
  • Mac上使用M1或M2芯片的设备安装Node.js时遇到一些问题,比如卡顿或性能问题
  • 学习vue3第四节(ref以及ref相关api)
  • 关于电脑无法开启5G频段热点的解决方案
  • 清理磁盘空间 - Win系统
  • 科技革新的引擎-2024年AI辅助研发趋势
  • 【PTA】L1-021 L1-022 L1-023 L1-024 L1-025(C)第四天
  • Stable Diffusion 如何写好提示词(Prompt)
  • 树莓派Py程序加入开机自启
  • Java EasyExcel注解详解和实战案例
  • AHU 汇编 实验二
  • Spring Boot单元测试与热部署简析
  • 3.12练习题解
  • Java中实现双向链表
  • 【DevOps实战之k8s】使用Prometheus和Grafana监控K8S集群
  • 【读论文】【精读】3D Gaussian Splatting for Real-Time Radiance Field Rendering
  • JVM理解学习