一、导入所需的库
import torch
import os
from PIL import Image
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei'] # 设置中文字体
plt.rcParams['axes.unicode_minus'] = False # 正常显示负号
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
二、手写数字数据集准备



#数据集类
class MNISTDataset(Dataset):
    """Grayscale image dataset whose integer label is encoded in each file name.

    Each file name is split on ``"_"``; the third field, up to its first ``"."``,
    is parsed as the label (e.g. ``"train_00001_5.jpg"`` -> ``5``).

    Args:
        files: Image file names, relative to ``root_dir``.
        root_dir: Directory that contains the image files.
        transform: Optional callable applied to each loaded PIL image.

    Raises:
        ValueError: If a file name does not contain a parsable label field.
    """

    def __init__(self, files, root_dir, transform=None):
        self.files = files
        self.root_dir = root_dir
        self.transform = transform
        # Parse all labels up front so a malformed file name fails at construction time.
        self.labels = [self._parse_label(f) for f in files]

    @staticmethod
    def _parse_label(filename):
        """Return the integer label stored in the third ``_``-separated field of *filename*."""
        try:
            return int(filename.split("_")[2].split(".")[0])
        except (IndexError, ValueError) as err:
            raise ValueError(f"cannot parse label from file name {filename!r}") from err

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.files[idx])
        # The context manager releases the file handle; convert() returns an independent image.
        with Image.open(img_path) as im:
            img = im.convert('L')  # 'L' = single-channel 8-bit grayscale
        if self.transform:
            img = self.transform(img)
        return img, self.labels[idx]
三、CNN模型的PyTorch实现
class CNN(nn.Module):
    """Small two-convolution classifier for single-channel 28x28 images.

    Args:
        num_classes: Number of output logits (default 10, one per digit).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # 1x28x28 -> conv(5) -> 10x24x24 -> maxpool(2) -> 10x12x12
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # 10x12x12 -> conv(5) -> 20x8x8 -> maxpool(2) -> 20x4x4
        self.conv2 = nn.Sequential(
            nn.Conv2d(10, 20, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # 20 * 4 * 4 = 320 flattened features; this ties the model to 28x28 inputs.
        self.fc = nn.Sequential(
            nn.Linear(320, 50),
            nn.ReLU(),
            nn.Linear(50, num_classes),
        )

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to (batch, num_classes) unnormalized logits."""
        x = self.conv1(x)
        x = self.conv2(x)
        # flatten keeps the batch dimension and, unlike view, also accepts non-contiguous input.
        x = torch.flatten(x, 1)
        return self.fc(x)
四、主程序
def train_cnn(model, train_loader, criterion, optimizer, epochs, log_interval=100):
    """Train *model* for *epochs* passes over *train_loader*; return per-epoch mean losses.

    Args:
        model: Network to optimize; it is switched to training mode.
        train_loader: Iterable of ``(data, target)`` batches that supports ``len()``.
        criterion: Loss function called as ``criterion(output, target)``.
        optimizer: Optimizer over ``model``'s parameters.
        epochs: Number of full passes over ``train_loader``.
        log_interval: Print the current batch loss every this many batches.

    Returns:
        List with one entry per epoch: the mean of that epoch's batch losses.

    Raises:
        ValueError: If ``train_loader`` has no batches (the mean would divide by zero).
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty")
    model.train()
    train_loss = []
    for epoch_idx in range(epochs):
        running_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            running_loss += batch_loss
            if batch_idx % log_interval == 0:
                print(f'Epoch: {epoch_idx + 1}, Batch: {batch_idx}, Loss: {batch_loss:.6f}')
        avg_loss = running_loss / num_batches
        train_loss.append(avg_loss)
        print(f'Epoch {epoch_idx + 1}/{epochs}, Average Loss: {avg_loss:.6f}')
    return train_loss


def _save_loss_curve(train_loss, path):
    """Plot the per-epoch losses and save the figure to *path*."""
    plt.figure(figsize=(12, 6))
    # Epochs are logged 1-based, so the x axis starts at 1 as well.
    plt.plot(range(1, len(train_loss) + 1), train_loss)
    plt.title("训练过程中损失函数值变化")
    plt.xlabel("Epoch")
    plt.ylabel("损失函数值")
    plt.grid()
    plt.savefig(path, dpi=300, bbox_inches='tight')
    plt.close()


def evaluate_cnn(model, test_loader):
    """Evaluate *model* on *test_loader*; print and return top-1 accuracy in percent.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    if total == 0:
        raise ValueError("test_loader contains no samples")
    accuracy = 100 * correct / total
    print(f'测试集准确率: {accuracy:.2f}%')
    return accuracy


def _save_first_prediction(model, dataset, path):
    """Predict ``dataset[0]`` and save the image titled with its true and predicted digit."""
    model.eval()  # enter evaluation mode
    with torch.no_grad():
        test_img, test_label = dataset[0]
        output = model(test_img.unsqueeze(0))  # add a batch dimension
        pred = output.argmax(dim=1)
    plt.figure()
    plt.imshow(test_img.squeeze().numpy(), cmap='gray')
    plt.title(f"真实数字: {test_label}, 预测数字: {pred.item()}")
    plt.axis('off')
    plt.savefig(path, dpi=300, bbox_inches='tight')
    plt.close()


if __name__ == '__main__':
    # Paths (machine-specific; adjust for your environment).
    base_dir = 'C:\\Users\\Administrator\\PycharmProjects\\CNN'
    train_dir = os.path.join(base_dir, "minist_train")
    test_dir = os.path.join(base_dir, "minist_test")

    # Collect image file names. os.listdir order is arbitrary, so sort to make
    # dataset order (and hence "the first test image") reproducible.
    train_files = sorted(f for f in os.listdir(train_dir) if f.endswith('.jpg'))
    test_files = sorted(f for f in os.listdir(test_dir) if f.endswith('.jpg'))

    # Preprocessing: unify size, convert to tensor, normalize to roughly [-1, 1].
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])

    # Datasets and data loaders.
    train_dataset = MNISTDataset(train_files, train_dir, transform=transform)
    test_dataset = MNISTDataset(test_files, test_dir, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    model = CNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    # Train, plot the loss curve, then evaluate.
    epoch = 10
    train_loss = train_cnn(model, train_loader, criterion, optimizer, epoch)
    _save_loss_curve(train_loss, os.path.join(base_dir, "training_loss_curve.jpg"))
    test_accuracy = evaluate_cnn(model, test_loader)

    # Save the prediction for the first test image.
    _save_first_prediction(model, test_dataset, os.path.join(base_dir, "first_test_pred.jpg"))
五、运行结果
5.1 损失函数曲线图

5.2 测试集第一张图像的预测结果
