
13. Fashion-MNIST Classification with the NiN Network

13.1 NiN Network Architecture Design
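A NiN block pairs an ordinary convolution with two 1x1 convolutions, each followed by ReLU; the 1x1 layers act as per-pixel fully connected layers across channels. Instead of ending with large dense layers, the final NiN block emits one channel per class and a global average pooling layer collapses each channel to a single logit, which keeps the parameter count small. The block and the full network are defined below.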


import torch
from torch import nn
import matplotlib.pyplot as plt
from torchsummary import summary

def nin_block(in_channels, out_channels, kernel_size, strides, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU()
    )
model = nn.Sequential(
    nin_block(1, 96, kernel_size=11, strides=4, padding=0),
    nn.MaxPool2d(3, stride=2),
    nin_block(96, 256, kernel_size=5, strides=1, padding=2),
    nn.MaxPool2d(3, stride=2),
    nin_block(256, 384, kernel_size=3, strides=1, padding=1),
    nn.MaxPool2d(3, stride=2),
    nn.Dropout(0.5),
    # the number of label classes is 10
    nin_block(384, 10, kernel_size=3, strides=1, padding=1),
    nn.AdaptiveAvgPool2d((1, 1)),
    # flatten the 4-D output into a 2-D output of shape (batch size, 10)
    nn.Flatten()
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
summary(model, input_size=(1, 224, 224), batch_size=1)
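If torchsummary is not installed, a similar sanity check can be done without it: push a dummy input through the top-level modules of the nn.Sequential model and print each output shape. This is a minimal sketch that assumes the model and device defined above.

X = torch.rand(size=(1, 1, 224, 224), device=device)
for layer in model:
    X = layer(X)
    # print the module name and the shape of its output
    print(layer.__class__.__name__, 'output shape:', tuple(X.shape))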


13.2 Fashion-MNIST Classification with the NiN Network

import torch
import torchvision
from torch import nn
import matplotlib.pyplot as plt
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score
plt.rcParams['font.family']=['Times New Roman']
def nin_block(in_channels, out_channels, kernel_size, strides, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU()
    )
def plot_metrics(train_loss_list, train_acc_list, test_acc_list, title='Training Curve'):
    epochs = range(1, len(train_loss_list) + 1)
    plt.figure(figsize=(4, 3))
    plt.plot(epochs, train_loss_list, label='Train Loss')
    plt.plot(epochs, train_acc_list, label='Train Acc', linestyle='--')
    plt.plot(epochs, test_acc_list, label='Test Acc', linestyle='--')
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()
def train_model(model, train_data, test_data, num_epochs):
    train_loss_list = []
    train_acc_list = []
    test_acc_list = []
    for epoch in range(num_epochs):
        total_loss = 0
        total_acc_sample = 0
        total_samples = 0
        loop1 = tqdm(train_data, desc=f"EPOCHS[{epoch+1}/{num_epochs}]")
        loop2 = tqdm(test_data, desc=f"EPOCHS[{epoch+1}/{num_epochs}]")
        model.train()
        for X, y in loop1:
            X = X.to(device)
            y = y.to(device)
            y_hat = model(X)
            loss = CEloss(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # accumulate the loss weighted by batch size
            total_loss += loss.item() * X.shape[0]
            y_pred = y_hat.argmax(dim=1).detach().cpu().numpy()
            y_true = y.detach().cpu().numpy()
            total_acc_sample += accuracy_score(y_true, y_pred) * X.shape[0]
            # keep track of the number of samples
            total_samples += X.shape[0]
        # evaluate on the test set
        test_acc_samples = 0
        test_samples = 0
        model.eval()
        with torch.no_grad():
            for X, y in loop2:
                X = X.to(device)
                y = y.to(device)
                y_hat = model(X)
                y_pred = y_hat.argmax(dim=1).detach().cpu().numpy()
                y_true = y.detach().cpu().numpy()
                test_acc_samples += accuracy_score(y_true, y_pred) * X.shape[0]
                # keep track of the number of test samples
                test_samples += X.shape[0]
        avg_train_loss = total_loss / total_samples
        avg_train_acc = total_acc_sample / total_samples
        avg_test_acc = test_acc_samples / test_samples
        train_loss_list.append(avg_train_loss)
        train_acc_list.append(avg_train_acc)
        test_acc_list.append(avg_test_acc)
        print(f"Epoch {epoch+1}: Loss: {avg_train_loss:.4f}, Train Accuracy: {avg_train_acc:.4f}, Test Accuracy: {avg_test_acc:.4f}")
    plot_metrics(train_loss_list, train_acc_list, test_acc_list)
    return model
def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)
################################################################################################################
# Note: the images are resized from 28*28 to 224*224 here
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # first argument is the mean, second is the std
])
train_img = torchvision.datasets.FashionMNIST(root="./data", train=True, transform=transform, download=True)
test_img = torchvision.datasets.FashionMNIST(root="./data", train=False, transform=transform, download=True)
train_data = DataLoader(train_img, batch_size=256, num_workers=4, shuffle=True)
test_data = DataLoader(test_img, batch_size=256, num_workers=4, shuffle=False)
################################################################################################################
model = nn.Sequential(
    nin_block(1, 32, kernel_size=5, strides=1, padding=2),
    nn.MaxPool2d(2),
    nin_block(32, 64, kernel_size=3, strides=1, padding=1),
    nn.MaxPool2d(2),
    nin_block(64, 10, kernel_size=3, strides=1, padding=1),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten()
)
model.apply(init_weights)
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
# print(device)
model = model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
CEloss = nn.CrossEntropyLoss()
num_epochs = 20
model = train_model(model, train_data, test_data, num_epochs)
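After training, a quick qualitative check is to run a single test batch through the network and compare the predicted class names with the ground truth. This is a minimal sketch, assuming the trained model, the test_data loader, and device from above; the class-name list follows the standard Fashion-MNIST label order.

# Qualitative check on one test batch (assumes `model`, `test_data`, `device` from above)
class_names = ['t-shirt/top', 'trouser', 'pullover', 'dress', 'coat',
               'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
model.eval()
with torch.no_grad():
    X, y = next(iter(test_data))
    preds = model(X.to(device)).argmax(dim=1).cpu()
print('predicted:', [class_names[int(i)] for i in preds[:8]])
print('actual:   ', [class_names[int(i)] for i in y[:8]])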