当前位置: 首页 > news >正文

14.使用GoogleNet/Inception网络进行Fashion-Mnist分类

14.1 GoogleNet网络结构设计

在这里插入图片描述
在这里插入图片描述

import torch
from torch import nn
from torch.nn import functional as F
from torchsummary import summary
class Inception(nn.Module):def __init__(self, in_channels,c1,c2,c3,c4,**kwargs):super(Inception,self).__init__(**kwargs)#第一条路线:1*1的卷积层self.p1_1=nn.Conv2d(in_channels,c1,kernel_size=1)#第二条路线:1*1的卷积层+3*3的卷积层self.p2_1=nn.Conv2d(in_channels,c2[0],kernel_size=1)self.p2_2=nn.Conv2d(c2[0],c2[1],kernel_size=3,padding=1)#第三条路线:1*1的卷积层+5*5的卷积层self.p3_1=nn.Conv2d(in_channels,c3[0],kernel_size=1)self.p3_2=nn.Conv2d(c3[0],c3[1],kernel_size=5,padding=2)#第四条路线:3*3Maxpool+1*1 convsself.p4_1=nn.MaxPool2d(kernel_size=3,stride=1,padding=1)self.p4_2=nn.Conv2d(in_channels,c4,kernel_size=1)def forward(self,x):p1=F.relu(self.p1_1(x))#第一层p2=F.relu(self.p2_2(F.relu(self.p2_1(x))))p3=F.relu(self.p3_2(F.relu(self.p3_1(x))))p4=F.relu(self.p4_2(self.p4_1(x)))ft=torch.concat((p1,p2,p3,p4),dim=1)return ft
#组建googlenet
b1=nn.Sequential(nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3),nn.ReLU(),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b2=nn.Sequential(nn.Conv2d(64,64,kernel_size=1),nn.ReLU(),nn.Conv2d(64,192,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b3=nn.Sequential(Inception(192,64,(96,128),(16,32),32),Inception(256,128,(128,192),(32,96),64),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b4=nn.Sequential(Inception(480,192,(96,208),(16,48),64),Inception(512,160,(112,224),(24,64),64),Inception(512,128,(128,256),(24,64),64),Inception(512,112,(144,288),(32,64),64),Inception(528,256,(160,320),(32,128),128),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b5=nn.Sequential(Inception(832,64,(96,128),(16,32),32),Inception(256,128,(128,192),(32,96),64),nn.AdaptiveAvgPool2d((1,1)),nn.Flatten())
device=torch.device("cuda" if torch.cuda.is_available() else 'cpu')
model=nn.Sequential(b1,b2,b3,b4,b5,nn.Linear(480,10)).to(device)
summary(model,input_size=(1,224,224),batch_size=1)

在这里插入图片描述

14.2 GoogleNet网络实现Fashion-Mnist分类

import torch
import torchvision
from torch import nn
import matplotlib.pyplot as plt
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from torch.nn import functional as F
plt.rcParams['font.family']=['Times New Roman']
class Reshape(torch.nn.Module):def forward(self,x):return x.view(-1,1,28,28)#[bs,1,28,28]
def plot_metrics(train_loss_list, train_acc_list, test_acc_list, title='Training Curve'):epochs = range(1, len(train_loss_list) + 1)plt.figure(figsize=(4, 3))plt.plot(epochs, train_loss_list, label='Train Loss')plt.plot(epochs, train_acc_list, label='Train Acc',linestyle='--')plt.plot(epochs, test_acc_list, label='Test Acc', linestyle='--')plt.xlabel('Epoch')plt.ylabel('Value')plt.title(title)plt.legend()plt.grid(True)plt.tight_layout()plt.show()
def train_model(model,train_data,test_data,num_epochs):train_loss_list = []train_acc_list = []test_acc_list = []for epoch in range(num_epochs):total_loss=0total_acc_sample=0total_samples=0loop=tqdm(train_data,desc=f"EPOCHS[{epoch+1}/{num_epochs}]")for X,y in loop:#X=X.reshape(X.shape[0],-1)#print(X.shape)X=X.to(device)y=y.to(device)y_hat=model(X)loss=CEloss(y_hat,y)optimizer.zero_grad()loss.backward()optimizer.step()#loss累加total_loss+=loss.item()*X.shape[0]y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()y_true=y.detach().cpu().numpy()total_acc_sample+=accuracy_score(y_pred,y_true)*X.shape[0]#保存样本数total_samples+=X.shape[0]test_acc_samples=0test_samples=0for X,y in test_data:X=X.to(device)y=y.to(device)#X=X.reshape(X.shape[0],-1)y_hat=model(X)y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()y_true=y.detach().cpu().numpy()test_acc_samples+=accuracy_score(y_pred,y_true)*X.shape[0]#保存样本数test_samples+=X.shape[0]avg_train_loss=total_loss/total_samplesavg_train_acc=total_acc_sample/total_samplesavg_test_acc=test_acc_samples/test_samplestrain_loss_list.append(avg_train_loss)train_acc_list.append(avg_train_acc)test_acc_list.append(avg_test_acc)print(f"Epoch {epoch+1}: Loss: {avg_train_loss:.4f},Trian Accuracy: {avg_train_acc:.4f},test Accuracy: {avg_test_acc:.4f}")plot_metrics(train_loss_list, train_acc_list, test_acc_list)return model
def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)
class Inception(nn.Module):def __init__(self, in_channels,c1,c2,c3,c4,**kwargs):super(Inception,self).__init__(**kwargs)#第一条路线:1*1的卷积层self.p1_1=nn.Conv2d(in_channels,c1,kernel_size=1)#第二条路线:1*1的卷积层+3*3的卷积层self.p2_1=nn.Conv2d(in_channels,c2[0],kernel_size=1)self.p2_2=nn.Conv2d(c2[0],c2[1],kernel_size=3,padding=1)#第三条路线:1*1的卷积层+5*5的卷积层self.p3_1=nn.Conv2d(in_channels,c3[0],kernel_size=1)self.p3_2=nn.Conv2d(c3[0],c3[1],kernel_size=5,padding=2)#第四条路线:3*3Maxpool+1*1 convsself.p4_1=nn.MaxPool2d(kernel_size=3,stride=1,padding=1)self.p4_2=nn.Conv2d(in_channels,c4,kernel_size=1)def forward(self,x):p1=F.relu(self.p1_1(x))#第一层p2=F.relu(self.p2_2(F.relu(self.p2_1(x))))p3=F.relu(self.p3_2(F.relu(self.p3_1(x))))p4=F.relu(self.p4_2(self.p4_1(x)))ft=torch.concat((p1,p2,p3,p4),dim=1)return ft
#组建googlenet
b1=nn.Sequential(nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3),nn.ReLU(),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b2=nn.Sequential(nn.Conv2d(64,64,kernel_size=1),nn.ReLU(),nn.Conv2d(64,192,kernel_size=3,padding=1),nn.ReLU(),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b3=nn.Sequential(Inception(192,64,(96,128),(16,32),32),Inception(256,128,(128,192),(32,96),64),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b4=nn.Sequential(Inception(480,192,(96,208),(16,48),64),Inception(512,160,(112,224),(24,64),64),Inception(512,128,(128,256),(24,64),64),Inception(512,112,(144,288),(32,64),64),Inception(528,256,(160,320),(32,128),128),nn.MaxPool2d(kernel_size=3,stride=2,padding=1))
b5=nn.Sequential(Inception(832,64,(96,128),(16,32),32),Inception(256,128,(128,192),(32,96),64),nn.AdaptiveAvgPool2d((1,1)),nn.Flatten())
device=torch.device("cuda:1" if torch.cuda.is_available() else 'cpu')
model=nn.Sequential(b1,b2,b3,b4,b5,nn.Linear(480,10)).to(device)
transforms=transforms.Compose([transforms.Resize(96),transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])#第一个是mean,第二个是std
train_img=torchvision.datasets.FashionMNIST(root="./data",train=True,transform=transforms,download=True)
test_img=torchvision.datasets.FashionMNIST(root="./data",train=False,transform=transforms,download=True)
train_data=DataLoader(train_img,batch_size=128,num_workers=4,shuffle=True)
test_data=DataLoader(test_img,batch_size=128,num_workers=4,shuffle=False)
################################################################################################################
model.apply(init_weights)
optimizer=torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.9)
CEloss=nn.CrossEntropyLoss()
model=train_model(model,train_data,test_data,num_epochs=15)
################################################################################################################

在这里插入图片描述

http://www.lryc.cn/news/586638.html

相关文章:

  • 4. 观察者模式
  • Java行为型模式---观察者模式
  • Typecho分类导航栏开发指南:从基础到高级实现
  • 低代码引擎核心技术:OneCode常用动作事件速查手册及注解驱动开发详解
  • Pytorch实现感知器并实现分类动画
  • 深入理解观察者模式:构建松耦合的交互系统
  • 为什么玩游戏用UDP,看网页用TCP?
  • 【C++详解】STL-priority_queue使用与模拟实现,仿函数详解
  • 信息收集实战
  • 【读书笔记】《C++ Software Design》第九章:The Decorator Design Pattern
  • 设计模式:软件开发的高效解决方案(单例、工厂、适配器、代理)
  • 基于无人机 RTK 和 yolov8 的目标定位算法
  • 一文认识并学会c++模板(初阶)
  • AI 助力编程:Cursor Vibe Coding 场景实战演示
  • 基于 Redisson 实现分布式系统下的接口限流
  • 牛客网50题
  • 【C/C++】编译期计算能力概述
  • [Python] -实用技巧篇1-用一行Python代码搞定日常任务
  • python-range函数
  • 校园幸运抽(抽奖系统)测试报告
  • 第七章应用题
  • HT8313功放入门
  • HashMap的原理
  • 数据结构与算法之美:线索二叉树
  • 蒙特卡洛树搜索方法实践
  • 蓝牙调试抓包工具--nRF Connect移动端 使用详细总结
  • 生成式对抗网络(GAN)模型原理概述
  • Java生产带文字、带边框的二维码
  • 牛客:HJ19 简单错误记录[华为机考][字符串]
  • 009 ST表:静态区间最值的极致优化