当前位置: 首页 > news >正文

CNN卷积神经网络预测手写数字的Pytorch实现

一、导入第三方库

import torch
import os
from PIL import Image
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置中文字体
plt.rcParams['axes.unicode_minus'] = False   # 正常显示负号
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

二、手写数据集准备

#数据集类
class MNISTDataset(Dataset):def __init__(self,files,root_dir,transform=None):self.files=filesself.root_dir=root_dirself.transform=transformself.labels=[]for f in files:parts=f.split("_")p=parts[2].split(".")[0]self.labels.append(int(p))def __len__(self):return len(self.files)def __getitem__(self, idx):img_path=os.path.join(self.root_dir,self.files[idx])img=Image.open(img_path).convert('L')if self.transform:img=self.transform(img)label=self.labels[idx]return img,label

三、CNN模型的pytorch实现

class CNN(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Sequential(nn.Conv2d(1, 10, kernel_size=5),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.conv2 = nn.Sequential(nn.Conv2d(10, 20, kernel_size=5),nn.ReLU(),nn.MaxPool2d(kernel_size=2))self.fc = nn.Sequential(nn.Linear(320, 50),nn.ReLU(),nn.Linear(50, 10))def forward(self, x):batch_size=x.size(0)x=self.conv1(x)x=self.conv2(x)x=x.view(batch_size, -1)x=self.fc(x)return x

四、主程序

if __name__ == '__main__':#路径base_dir='C:\\Users\\Administrator\\PycharmProjects\\CNN'train_dir=os.path.join(base_dir,"minist_train")test_dir=os.path.join(base_dir,"minist_test")#获取文件夹里图像的名称train_files=[f for f in os.listdir(train_dir) if f.endswith('.jpg')]test_files=[f for f in os.listdir(test_dir) if f.endswith('.jpg')]#数据转换transform=transforms.Compose([transforms.Resize((28, 28)),  #统一尺寸transforms.ToTensor(),transforms.Normalize(mean=[0.5],std=[0.5])])#创建数据集和数据加载器train_dataset=MNISTDataset(train_files,train_dir,transform=transform)test_dataset=MNISTDataset(test_files,test_dir,transform=transform)train_loader=DataLoader(train_dataset,batch_size=64,shuffle=True)test_loader=DataLoader(test_dataset,batch_size=64,shuffle=False)model=CNN()criterion=nn.CrossEntropyLoss()optimizer=torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.5)#训练函数def train_cnn(epoch):model.train()train_loss = []for epoch_idx in range(epoch):running_loss=0.0for batch_idx,(data, target) in enumerate(train_loader):optimizer.zero_grad()output=model(data)loss=criterion(output,target)loss.backward()optimizer.step()running_loss+=loss.item()if batch_idx%100==0:print(f'Epoch: {epoch_idx + 1}, Batch: {batch_idx}, Loss: {loss.item():.6f}')avg_loss=running_loss/len(train_loader)train_loss.append(avg_loss)print(f'Epoch {epoch_idx + 1}/{epoch}, Average Loss: {avg_loss:.6f}')#损失函数值曲线图plt.figure(figsize=(12, 6))plt.plot(train_loss)plt.title("训练过程中损失函数值变化")plt.xlabel("Epoch")plt.ylabel("损失函数值")plt.grid()#保存loss_plot_path=os.path.join(base_dir,"training_loss_curve.jpg")plt.savefig(loss_plot_path,dpi=300,bbox_inches='tight')plt.close()#对测试集def test_cnn():model.eval()correct=0total=0with torch.no_grad():for data,target in test_loader:outputs=model(data)_, predicted=torch.max(outputs.data, 1)total+=target.size(0)correct+=(predicted==target).sum().item()accuracy=100*correct/totalprint(f'测试集准确率: {accuracy:.2f}%')return accuracy#训练和测试epoch=10train_cnn(epoch)test_accuracy=test_cnn()#显示测试集第一张图像的预测结果model.eval()  #进入评估阶段with torch.no_grad():test_img,test_label=test_dataset[0]output=model(test_img.unsqueeze(0))  # 添加批次维度_,pred=torch.max(output.data, 1)plt.imshow(test_img.squeeze(), cmap='gray')plt.title(f"真实数字: {test_label}, 预测数字: {pred.item()}")plt.axis('off')pred_plot_path=os.path.join(base_dir,"first_test_pred.jpg")plt.savefig(pred_plot_path,dpi=300,bbox_inches='tight')plt.close()

五、运行结果

5.1 损失函数曲线图

5.2 测试集第一张图像的预测结果

http://www.lryc.cn/news/619959.html

相关文章:

  • games101 第三讲 Transformation(变换)
  • 人工到智能:塑料袋拆垛的自动化革命 —— 迁移科技的实践与创新
  • AI一键抠图软件--Digiarty.AIArty.Image.Matting
  • MySQL数据库知识体系总结 20250813
  • 数据库连接池如何进行空闲管理
  • TeamViewer 以数字化之力,赋能零售企业效率与客户体验双提升
  • “我店模式”:零售转型中的场景化突围
  • 【k8s】k8s pod调度失败原因列表、Pod 完整的状态类型列表
  • TDengine IDMP 基本功能(4. 实时分析)
  • 【金仓数据库产品体验官】_从实践看金仓数据库与 MySQL 的兼容性
  • Java开发主流框架搭配详解及学习路线指南
  • Pytest项目_day14(参数化、数据驱动)
  • VR中image或者文字一直浮现在眼前
  • Flutter 多模块 + 组件化架构设计实践
  • 使用HtmlAgilityPack+PuppeteerSharp+iText7抓取Selenium帮助文档
  • PCIE 配置空间 拓展能力 定义
  • mac环境下安装git并配置密钥等
  • 20250813测试开发岗(凉)面
  • 19. 重载的方法能否根据返回值类型进行区分
  • 完整源码+技术文档!基于Hadoop+Spark的鲍鱼生理特征大数据分析系统免费分享
  • Java Spring框架最新版本及发展史详解(截至2025年8月)-优雅草卓伊凡
  • 【C#】利用数组实现大数数据结构
  • 云电竞盒子对游戏性能有影响吗?
  • 《Python学习之基础语法1:从零开始的编程之旅》
  • 向量相似度计算与Softmax概率分布对比
  • 2025盛夏AI热浪:八大技术浪潮重构数字未来
  • String里常用的方法
  • el-table合并相同名称的列
  • java中在多线程的情况下安全的修改list
  • 基于C#、.net、asp.net的心理健康咨询系统设计与实现/心理辅导系统设计与实现