当前位置: 首页 > news >正文

基于pytorch的手写数字识别-训练+使用

import pandas as pd
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoadermatplotlib.use('tkAgg')# 设置图形配置
config = {"font.family": 'serif',"mathtext.fontset": 'stix',"font.serif": ['SimSun'],'axes.unicode_minus': False
}
matplotlib.rcParams.update(config)def mymap(labels):return np.where(labels < 10, labels, 0)# 数据加载
path = "d:\\JD\\Documents\\大学等等等\\自学部分\\机器学习自学画图\\手写数字识别\\ex3data1.xlsx"
data = pd.read_excel(path)
data = np.array(data, dtype=np.float32)
x = data[:, :-1]
labels = data[:, -1]
labels = mymap(labels)# 转换为Tensor
x = torch.tensor(x, dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.long)# 创建Dataset和Dataloader
dataset = TensorDataset(x, labels)
train_loader = DataLoader(dataset, batch_size=20, shuffle=True)# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义模型
my_nn = torch.nn.Sequential(torch.nn.Linear(400, 128),torch.nn.Sigmoid(),torch.nn.Linear(128, 256),torch.nn.Sigmoid(),torch.nn.Linear(256, 512),torch.nn.Sigmoid(),torch.nn.Linear(512, 10)
).to(device)# 加载预训练模型
my_nn.load_state_dict(torch.load('model.pth'))
my_nn.eval()  # 切换至评估模式# 准备选取数据进行预测
sample_indices = np.random.choice(len(dataset), 50, replace=False)  # 随机选择50个样本
sample_images = x[sample_indices].to(device)  # 选择样本并移动到GPU
sample_labels = labels[sample_indices].numpy()  # 真实标签# 进行预测
with torch.no_grad():  # 禁用梯度计算predictions = my_nn(sample_images)predicted_labels = torch.argmax(predictions, dim=1).cpu().numpy()  # 获取预测的标签# 绘制图像
plt.figure(figsize=(10, 10))
for i in range(50):plt.subplot(10, 5, i + 1)  # 10行5列的子图plt.imshow(sample_images[i].cpu().reshape(20, 20), cmap='gray')  # 还原为20x20图像plt.title(f'Predicted: {predicted_labels[i]}', fontsize=8)plt.axis('off')  # 关闭坐标轴plt.tight_layout()  # 调整子图间距
plt.show()

Iteration 0, Loss: 0.8472495079040527
Iteration 20, Loss: 0.014742681756615639
Iteration 40, Loss: 0.00011596851982176304
Iteration 60, Loss: 9.278443030780181e-05
Iteration 80, Loss: 1.3701709576707799e-05
Iteration 100, Loss: 5.019319928578625e-07
Iteration 120, Loss: 0.0
Iteration 140, Loss: 0.0
Iteration 160, Loss: 1.2548344585638915e-08
Iteration 180, Loss: 1.700657230685465e-05
预测准确率: 100.00%

下面使用已经训练好的模型,进行再次测试:

import pandas as pd
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoadermatplotlib.use('tkAgg')# 设置图形配置
config = {"font.family": 'serif',"mathtext.fontset": 'stix',"font.serif": ['SimSun'],'axes.unicode_minus': False
}
matplotlib.rcParams.update(config)def mymap(labels):return np.where(labels < 10, labels, 0)# 数据加载
path = "d:\\JD\\Documents\\大学等等等\\自学部分\\机器学习自学画图\\手写数字识别\\ex3data1.xlsx"
data = pd.read_excel(path)
data = np.array(data, dtype=np.float32)
x = data[:, :-1]
labels = data[:, -1]
labels = mymap(labels)# 转换为Tensor
x = torch.tensor(x, dtype=torch.float32)
labels = torch.tensor(labels, dtype=torch.long)# 创建Dataset和Dataloader
dataset = TensorDataset(x, labels)
train_loader = DataLoader(dataset, batch_size=20, shuffle=True)# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义模型
my_nn = torch.nn.Sequential(torch.nn.Linear(400, 128),torch.nn.Sigmoid(),torch.nn.Linear(128, 256),torch.nn.Sigmoid(),torch.nn.Linear(256, 512),torch.nn.Sigmoid(),torch.nn.Linear(512, 10)
).to(device)# 加载预训练模型
my_nn.load_state_dict(torch.load('model.pth'))
my_nn.eval()  # 切换至评估模式# 准备选取数据进行预测
sample_indices = np.random.choice(len(dataset), 50, replace=False)  # 随机选择50个样本
sample_images = x[sample_indices].to(device)  # 选择样本并移动到GPU
sample_labels = labels[sample_indices].numpy()  # 真实标签# 进行预测
with torch.no_grad():  # 禁用梯度计算predictions = my_nn(sample_images)predicted_labels = torch.argmax(predictions, dim=1).cpu().numpy()  # 获取预测的标签plt.figure(figsize=(16, 10))
for i in range(20):plt.subplot(4, 5, i + 1)  # 4行5列的子图plt.imshow(sample_images[i].cpu().reshape(20, 20), cmap='gray')  # 还原为20x20图像plt.title(f'True: {sample_labels[i]}, Pred: {predicted_labels[i]}', fontsize=12)  # 标题中显示真实值和预测值plt.axis('off')  # 关闭坐标轴plt.tight_layout()  # 调整子图间距
plt.show()

http://www.lryc.cn/news/455058.html

相关文章:

  • SpringBoot接收前端传递参数
  • 【LeetCode周赛】第 418 场
  • Android学习7 -- NDK2 -- 几个例子
  • 问:说说JVM不同版本的变化和差异?
  • 计算机毕业设计 基于Python的社交音乐分享平台的设计与实现 Python+Django+Vue 前后端分离 附源码 讲解 文档
  • 51单片机的水位检测系统【proteus仿真+程序+报告+原理图+演示视频】
  • Python和R及Julia妊娠相关疾病生物剖析算法
  • Web安全 - 重放攻击(Replay Attack)
  • Python项目文档生成常用工具对比
  • 教育领域的技术突破:SpringBoot系统实现
  • RabbitMQ入门3—virtual host参数详解
  • 【Nacos入门到实战十四】Nacos配置管理:集群部署与高可用策略
  • UE5+ChatGPT实现3D AI虚拟人综合实战
  • [图形学]smallpt代码详解(2)
  • vmstat命令:系统性能监控
  • linux部署NFS和autofs自动挂载
  • WPF RadioButton 绑定boolean值
  • 2024 ciscn WP
  • 代码随想录--字符串--重复的子字符串
  • No.5 笔记 | 网络端口协议概览:互联网通信的关键节点
  • 手机地址IP显示不对?别急,这里有解决方案
  • 人工智能对未来工作影响的四种可能性
  • SpringBoot+ElasticSearch7.12.1+Kibana7.12.1简单使用
  • RESTful风格接口+Swagger生成Web API文档
  • 性能测试学习2:常见的性能测试策略(基准测试/负载测试/稳定性测试/压力测试/并发测试)
  • 【C++】—— 继承(上)
  • 【2024保研经验帖】东南大学计算机学院夏令营
  • dz论坛可可积分商城插件价值399元
  • python的extend和append
  • 贪心算法相关知识