当前位置: 首页 > news >正文

机器学习入门【经典的CIFAR10分类】

模型

神经网络采用下图
在这里插入图片描述

我使用之后发现迭代多了之后一直最高是正确率65%左右,然后我自己添加了一些Relu激活函数和正则化,现在正确率可以有80%左右。

模型代码

import torch
from torch import nnclass YmModel(nn.Module):def __init__(self):super(YmModel, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),nn.BatchNorm2d(32),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),nn.BatchNorm2d(32),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),nn.BatchNorm2d(64),nn.ReLU(),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 512),nn.ReLU(),nn.Dropout(0.5),nn.Linear(512, 64),nn.ReLU(),nn.Dropout(0.5),nn.Linear(64, 10),)def forward(self, x):return self.model(x)

训练

有一点要说明的是,数据集中并没有验证集,你可以从训练集扣个1w张出来

import torch
import torchvision
from torchvision import transformsfrom models.YMModel import YmModel
from torch.utils.data import DataLoadertransform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])# 数据集
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform_train, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=torchvision.transforms.ToTensor(), download=True)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64)
print(len(train_loader), len(test_loader))print(len(train_dataset), len(test_dataset))model = YmModel()
#迭代次数
train_epochs = 300
#优化器
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# 损失函数
loss_fn = torch.nn.CrossEntropyLoss()train_epochs_step = 0
best_accuracy = 0.for epoch in range(train_epochs):model.train()print(f'Epoch is {epoch}')for images, labels in train_loader:outputs = model(images)loss = loss_fn(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()if train_epochs_step % 100 == 0:print(f'Train_Epoch is {train_epochs_step}\t Loss is {loss.item()}')train_epochs_step += 1train_epochs_step = 0with torch.no_grad():loss_running_total = 0.acc_running_total = 0.for images, labels in test_loader:outputs = model(images)loss = loss_fn(outputs, labels)loss_running_total += loss.item()acc_running_total += (outputs.argmax(1) == labels).sum().item()acc_running_total /= len(test_dataset)if acc_running_total > best_accuracy:best_accuracy = acc_running_totaltorch.save(model.state_dict(), './best_model.pth')print('accuracy is {}'.format(acc_running_total))print('total loss is {}'.format(loss_running_total))print('best accuracy is {}'.format(best_accuracy))

验证

import osimport numpy as np
import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transformsfrom models.TestColor import TextColor
from models.YMModel import YmModeltest_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=torchvision.transforms.ToTensor(), download=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)classes = ('airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck')
model = YmModel()model.load_state_dict(torch.load('best_model.pth'))model.eval()
with torch.no_grad():correct = 0.for images, labels in test_loader:outputs = model(images)_, predicted = torch.max(outputs, 1)correct += (predicted == labels).sum().item()print('Accuracy : {}'.format(100 * correct / len(test_dataset)))
folder_path = './images'
files_names = os.listdir(folder_path)
transform_test = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),
])for file_name in files_names:image_path = os.path.join(folder_path, file_name)image = Image.open(image_path)image = transform_test(image)image = np.reshape(image, [1, 3, 32, 32])output = model(image)_, predicted = torch.max(output, 1)source_name = os.path.splitext(file_name)[0]predicted_class = classes[predicted.item()]colors = TextColor.GREEN if predicted_class == source_name else TextColor.REDprint(f"Source is {TextColor.BLUE}{source_name}{TextColor.RESET}, and predicted is {colors}{predicted_class}{TextColor.RESET}")

结果

TextColor是自定义字体颜色的类,image中就是自己的图片。
结果如下:测试集的正确率有82.7%

在这里插入图片描述

http://www.lryc.cn/news/402569.html

相关文章:

  • 01 安装
  • AI 模型本地推理 - YYPOLOE - Python - Windows - GPU - 吸烟检测(目标检测)- 有配套资源直接上手实现
  • 全国媒体邀约,主流媒体到场出席采访报道
  • 计算机视觉8 图像增广
  • Transformer中的自注意力是怎么实现的?
  • LabVIEW鼠标悬停在波形图上的曲线来自动显示相应点的坐标
  • 操作系统发展简史(Unix/Linux 篇 + DOS/Windows 篇)+ Mac 与 Microsoft 之风云争霸
  • 钡铼分布式 IO 系统 OPC UA边缘计算耦合器BL205
  • 实现了一个心理测试的小程序,微信小程序学习使用问题总结
  • vue是如何进行监听数据变化的?vue2和vue3分别是什么?vue3为什么要更换?
  • 数据结构day3
  • 免费的数字孪生平台助力产业创新,让新质生产力概念有据可依
  • mtsys2 编译 qemu 记录
  • 【Python数据分析】数据分析三剑客:NumPy、SciPy、Matplotlib中常用操作汇总
  • STM32智能家居电力管理系统教程
  • C# 邮件发送
  • Kotlin 协程简化回调
  • 帝王蝶算法(EBOA)及Python和MATLAB实现
  • 【学术会议征稿】第六届信息与计算机前沿技术国际学术会议(ICFTIC 2024)
  • PHP MySQL 读取数据
  • 点亮 LED-I.MX6U嵌入式Linux C应用编程学习笔记基于正点原子阿尔法开发板
  • 从0到1搭建数据中台(4):neo4j初识及安装使用
  • 【20】读感 - 架构整洁之道(二)
  • js vue axios post 数组请求参数获取转换, 后端go参数解析(gin框架)全流程示例
  • 揭秘郭采洁浪漫升级
  • 数据结构(Java):力扣牛客 二叉树面试OJ题(一)
  • 在国产芯片上实现YOLOv5/v8图像AI识别-【1.3】YOLOv5的介绍及使用(训练、导出)更多内容见视频
  • 逻辑门的题目怎么做?
  • CentOS 7报错:yum命令报错 “ Cannot find a valid baseurl for repo: base/7/x86_6 ”
  • 51单片机STC89C52RC——18.1 HC-SR04超声波测距