当前位置: 首页 > news >正文

GoogLeNet-水果分类

GoogLeNet-水果分类

1.数据集

官方下载地址:https://www.kaggle.com/datasets/karimabdulnabi/fruit-classification10-class?resource=download

备用下载地址:https://www.123684.com/s/xhlWjv-pRAPh

介绍:

十个类别:苹果、橙子、鳄梨、猕猴桃、芒果、凤梨、草莓、香蕉、樱桃、西瓜

2.训练

import copy
import time
import torch
from torch import nn
import torchvision
from torchvision import transforms
import torch.utils.data as Data
import matplotlib.pyplot as plt
import numpy as np
from model import GoogLeNet, Inception
import pandas as pd


def train_val_data_process(root='./02框架学习/04经典卷积神经网络与实战-pao哥/06_GoogLeNet_fruit/dataset/train',
                           batch_size=64):
    """Build the train and validation DataLoaders from an ImageFolder directory.

    Every image is resized to 224x224 and converted with ToTensor (values in
    [0, 1], no Normalize). The dataset is split randomly into ~80% train and
    ~20% validation.

    Args:
        root: ImageFolder root (one sub-folder per class).
        batch_size: batch size of both loaders.

    Returns:
        (train_dataloader, val_dataloader)
    """
    train_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    full_data = torchvision.datasets.ImageFolder(root=root, transform=train_transform)

    # Take the validation size as the remainder so both lengths always add up
    # to len(full_data), as random_split requires.
    train_len = round(0.8 * len(full_data))
    val_len = len(full_data) - train_len
    train_data, val_data = Data.random_split(dataset=full_data, lengths=[train_len, val_len])

    train_dataloader = Data.DataLoader(dataset=train_data, batch_size=batch_size,
                                       shuffle=True, num_workers=3)
    # Batch order does not change validation metrics, so no shuffling here.
    val_dataloader = Data.DataLoader(dataset=val_data, batch_size=batch_size,
                                     shuffle=False, num_workers=3)
    return train_dataloader, val_dataloader


def _run_epoch(model, dataloader, loss_fn, device, optimizer=None):
    """Run one pass over *dataloader*; train when *optimizer* is given, else evaluate.

    Returns:
        (mean loss per sample, accuracy) over all samples yielded.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss = 0.0
    total_correct = 0
    total_num = 0
    # Gradients are only needed when updating weights. Disabling them for
    # validation avoids building an autograd graph for every batch.
    with torch.set_grad_enabled(training):
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)
            output = model(x)
            loss = loss_fn(output, y)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Index of the highest score in each row = predicted class.
            pre_lab = torch.argmax(output, dim=1)
            # loss is a batch mean; weight it by the batch size to get a per-sample mean later.
            total_loss += loss.item() * x.size(0)
            total_correct += int(torch.sum(pre_lab == y).item())
            total_num += x.size(0)
    if total_num == 0:
        raise ValueError('dataloader yielded no samples')
    return total_loss / total_num, total_correct / total_num


def train_model_process(model, train_dataloader, val_dataloader, epochs,
                        save_path='best_model.pth', lr=0.001):
    """Train *model* with Adam + cross-entropy and keep the best-validation weights.

    After each epoch the model is evaluated on *val_dataloader*. The weights with
    the highest validation accuracy (or the initial weights, if accuracy never
    rises above 0) are saved as a state_dict to *save_path*.

    Args:
        model: the network to train (moved to CUDA when available).
        train_dataloader / val_dataloader: loaders yielding (inputs, class-index labels).
        epochs: number of epochs to run.
        save_path: where to write the best state_dict.
        lr: Adam learning rate.

    Returns:
        pandas.DataFrame with columns epoch, train_loss_all, train_acc_all,
        val_loss_all, val_acc_all (one row per epoch).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    # state_dict must be *called*. Deep-copying the bound method (the old
    # code) would save a method object instead of the weight tensors.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    train_loss_all = []
    val_loss_all = []
    train_acc_all = []
    val_acc_all = []

    since = time.time()
    for epoch in range(epochs):
        print(f'epoch:{epoch} / {epochs-1}')
        print('-' * 10)

        train_loss, train_acc = _run_epoch(model, train_dataloader, loss_fn, device, optimizer)
        val_loss, val_acc = _run_epoch(model, val_dataloader, loss_fn, device)

        train_loss_all.append(train_loss)
        train_acc_all.append(train_acc)
        val_loss_all.append(val_loss)
        val_acc_all.append(val_acc)
        print(f'epoch:{epoch} train loss:{train_loss_all[-1]:.4f} train acc:{train_acc_all[-1]:.4f}')
        print(f'epoch:{epoch} val loss:{val_loss_all[-1]:.4f} val acc:{val_acc_all[-1]:.4f}')

        # Keep a snapshot of the weights with the best validation accuracy so far.
        if val_acc_all[-1] > best_acc:
            best_acc = val_acc_all[-1]
            best_model_wts = copy.deepcopy(model.state_dict())

    # Floor the minutes (the old `time_use / 60:.0f` rounded them up).
    minutes, seconds = divmod(round(time.time() - since), 60)
    print(f'训练和验证耗费的时间{minutes}m:{seconds}s')

    torch.save(best_model_wts, save_path)
    train_process = pd.DataFrame(data={'epoch': range(epochs),
                                       'train_loss_all': train_loss_all,
                                       'train_acc_all': train_acc_all,
                                       'val_loss_all': val_loss_all,
                                       'val_acc_all': val_acc_all})
    return train_process
def matplot(train_process):
    """Plot loss (left) and accuracy (right) curves for train vs. validation.

    *train_process* is the DataFrame returned by train_model_process; it must
    have an 'epoch' column plus train/val loss and accuracy columns.
    """
    # (y-axis label, train column, val column, train legend, val legend)
    panels = (
        ('loss', 'train_loss_all', 'val_loss_all', 'train loss', 'val loss'),
        ('acc', 'train_acc_all', 'val_acc_all', 'train acc', 'val acc'),
    )
    x_axis = train_process['epoch']

    plt.figure(figsize=(12, 4))
    for slot, (y_name, train_col, val_col, train_label, val_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        plt.plot(x_axis, train_process[train_col], 'ro-', label=train_label)
        plt.plot(x_axis, train_process[val_col], 'bo-', label=val_label)
        plt.legend()
        plt.xlabel('epoch')
        plt.ylabel(y_name)
    plt.show()


if __name__ == '__main__':
    # Build the network, load the data, train for 51 epochs (0..50), then plot the curves.
    net = GoogLeNet(Inception)
    train_loader, val_loader = train_val_data_process()
    history = train_model_process(net, train_loader, val_loader, 51)
    matplot(history)

训练了51轮(epoch 0~50),但是最终的效果不是很好

在这里插入图片描述

在这里插入图片描述

3.测试

"""
@author:Lunau
@file:model_test.py
@time:2024/09/19
"""
import torch
from torch import nn
from torchvision import transforms
import torchvision
import torch.utils.data as Data
import matplotlib.pyplot as plt
import numpy as np
from model import GoogLeNet, Inception

# Test data and weights of *this* fruit project. The original script pointed at
# the 05_GoogLeNet_catAndDog project (copy-paste leftover), so it evaluated the
# wrong model on the wrong test set. Forward slashes work on Windows and POSIX.
ROOT_TEST = './02框架学习/04经典卷积神经网络与实战-pao哥/06_GoogLeNet_fruit/dataset/test'
MODEL_WEIGHTS = './02框架学习/04经典卷积神经网络与实战-pao哥/06_GoogLeNet_fruit/best_model.pth'


def test_data_process(root=ROOT_TEST, batch_size=32):
    """Return a DataLoader over the ImageFolder test set at *root*.

    The preprocessing matches training: Resize((224, 224)) then ToTensor
    (pixel values in [0, 1], no Normalize).
    """
    test_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    test_data = torchvision.datasets.ImageFolder(root=root, transform=test_transform)
    # Accuracy does not depend on order, so keep evaluation deterministic.
    return Data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False, num_workers=3)


def test_model_process(model, test_dataloader):
    """Evaluate *model* on *test_dataloader*, print the top-1 accuracy and return it.

    Evaluation is best-effort: if an error occurs part-way, the traceback is
    printed and accuracy is computed over the samples evaluated before it.

    Returns:
        The accuracy as a float, or None if no sample could be evaluated.
    """
    import traceback

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    test_corrects = 0
    test_num = 0
    # Forward passes only: no gradients needed.
    with torch.no_grad():
        try:
            for x, y in test_dataloader:
                x = x.to(device)
                y = y.to(device)
                # Per-class scores (presumably logits; argmax is the same for probabilities).
                output = model(x)
                prd_lab = torch.argmax(output, dim=1)
                test_corrects += int(torch.sum(prd_lab == y).item())
                test_num += x.size(0)
        except Exception:
            # Keep the original best-effort behaviour, but do not hide the cause.
            print('error and skip')
            traceback.print_exc()

    if test_num == 0:
        print('no test samples were evaluated')
        return None
    test_acc = test_corrects / test_num
    print(f'测试的准确率为:{test_acc}')
    return test_acc


def imshow(img):
    """Show a CHW image tensor with matplotlib.

    The test transform is Resize + ToTensor with no Normalize, so values are
    already in [0, 1]. The old `img / 2 + 0.5` assumed Normalize(0.5, 0.5) and
    washed the image out, so values are only clamped here.
    """
    npimg = img.detach().cpu().clamp(0, 1).numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib
    plt.show()


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GoogLeNet(Inception).to(device)
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(MODEL_WEIGHTS, map_location=device, weights_only=True))
    test_loader = test_data_process()
    test_model_process(model, test_loader)

测试的准确率确实有些低,训练轮次多一些可能会好一点。另外要确认测试脚本中的测试集路径和权重路径都指向本项目(06_GoogLeNet_fruit),而不是之前的 05_GoogLeNet_catAndDog 项目,否则得到的准确率没有参考意义。

在这里插入图片描述

4.预测

对这张图片进行预测

在这里插入图片描述

在这里插入图片描述

"""
@author:Lunau
@file:model_test.py
@time:2024/09/19
"""
import torch
from torch import nn
from torchvision import transforms
import torch.utils.data as Data
import matplotlib.pyplot as plt
import numpy as np
from model import GoogLeNet, Inception
from PIL import Image# 推理单张图片
if __name__ == '__main__':
    # Predict the class of a single image with the trained fruit GoogLeNet.
    import os

    # Forward slashes work on Windows and POSIX (the old paths mixed '\\' and '/').
    project_dir = '02框架学习/04经典卷积神经网络与实战-pao哥/06_GoogLeNet_fruit'
    weights_path = f'{project_dir}/best_model.pth'
    image_path = f'{project_dir}/dataset/predict/0.jpeg'
    train_dir = f'{project_dir}/dataset/train'

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GoogLeNet(Inception).to(device)
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=True))

    # ImageFolder's default loader converts to RGB during training. Do the same
    # so PNG (RGBA) or grayscale inputs still get 3 channels. `with` closes the file.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    image = transform(image).unsqueeze(0)  # add the batch dimension -> (1, C, H, W)
    print(image.shape)

    # ImageFolder assigns label index i to the i-th *sorted* class-folder name,
    # so derive the names the same way instead of relying on a hand-written order.
    fallback_classes = ['Apple', 'Orange', 'Avocado', 'Kiwi', 'Mango', 'Pineapple',
                        'Strawberries', 'Banana', 'Cherry', 'Watermelon']
    if os.path.isdir(train_dir):
        classes = sorted(entry.name for entry in os.scandir(train_dir) if entry.is_dir())
    else:
        print(f'warning: {train_dir} not found; using a hard-coded class order '
              f'that may not match the training labels')
        classes = fallback_classes

    model.eval()
    with torch.no_grad():
        output = model(image.to(device))
        pred = torch.argmax(output, dim=1)  # index of the highest score per row
    print(f'预测结果为:{classes[pred.item()]}')
http://www.lryc.cn/news/481937.html

相关文章:

  • 深度学习入门指南:一篇文章全解
  • java ssm 医院病房管理系统 医院管理 医疗病房信息管理 源码 jsp
  • 钩子函数的使用
  • 【Docker】自定义网络:实现容器之间通过域名相互通讯
  • 护理陪护系统|护理陪护软件|陪护软件
  • 苍穹外卖-账号被锁定怎么办?
  • webpack loader全解析,从入门到精通(10)
  • python机器人Agent编程——实现一个本地大模型和爬虫结合的手机号归属地天气查询Agent
  • 【动态规划】斐波那契数列模型总结
  • EasyUI弹出框行编辑,通过下拉框实现内容联动
  • 国产linux系统(银河麒麟,统信uos)使用 PageOffice 实现word文件在线留痕
  • 使用亚马逊 S3 连接器为 PyTorch 和 MinIO 创建地图式数据集
  • 自动化运维:提升效率与稳定性的关键技术实践
  • Google Go编程风格指南-介绍
  • 思科模拟器路由器配置实验
  • 机器学习—选择激活函数
  • [ Linux 命令基础 4 ] Linux 命令详解-文本处理命令
  • Odoo:免费开源的钢铁冶金行业ERP管理系统
  • 33.Redis多线程
  • 【Python】解析 XML
  • 【复平面】-复数相乘的几何性质
  • 为什么ta【给脸不要脸】:利他是一种选择,善良者的自我救赎与智慧策略
  • mysql 配置文件 my.cnf 增加 lower_case_table_names = 1 服务启动不了的原因
  • SIwave:释放 SIwizard 求解器的强大功能
  • 强化学习不愧“顶会收割机”!2大创新思路带你上大分,毕业不用愁!
  • mac 修改启动图图标数量
  • 网站架构知识之Ansible进阶(day022)
  • VMware调整窗口为可以缩小但不改变显示内容的大小
  • Vue 3 中,ref 和 reactive的区别
  • window 利用Putty免密登录远程服务器