当前位置: 首页 > news >正文

VGG16训练和测试Fashion和CIFAR10

CIFAR10 epoch=50

Test Accuracy: 91.77%

VGG网络结构

VGG16(D)网络结构 5个block+5个Maxpool+3个FC+soft-max

VGG网络参数(标准 ImageNet 版 VGG16)

第一个VGG Block层(两个卷积层,两个ReLu,一个池化)

self.block1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2)
)

第二个VGG Block层(两个卷积层,两个ReLu,一个池化)

self.block2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2))

第三个VGG Block层(三个卷积层,三个ReLu,一个池化)

        self.block3 = nn.Sequential(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2))

第四个VGG Block层(三个卷积层,三个ReLu,一个池化)

        self.block4 = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2))

第五个VGG Block层(三个卷积层,三个ReLu,一个池化)

        self.block5 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2))

线性分类器层(flatten,linear,Relu,Dropout)

        self.classifier = nn.Sequential(nn.Flatten(),nn.Linear(in_features=7*7*512, out_features=4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(in_features=4096, out_features=4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(in_features=4096, out_features=10),)

前向传播代码

    def forward(self, x):x = self.block1(x)x = self.block2(x)x = self.block3(x)x = self.block4(x)x = self.block5(x)x = self.classifier(x)return x

kaiming初始化权重函数

# ===== Kaiming 初始化函数 =====
def init_weights(m):if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')nn.init.constant_(m.bias, 0)

主函数代码

if __name__ == '__main__':device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = VGG16().to(device)print(summary(model, (1, 224, 224)))

完整代码

import torch
import torch.nn as nn
from torchsummary import summaryclass VGG16(nn.Module):def __init__(self):super().__init__()self.block1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2))self.block2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2))self.block3 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2))self.block4 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2))self.block5 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(2, 2))self.classifier = nn.Sequential(nn.Flatten(),nn.Linear(7*7*512, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(4096, 10),)def forward(self, x):x = self.block1(x)x = self.block2(x)x = self.block3(x)x = self.block4(x)x = self.block5(x)x = self.classifier(x)return x# ===== Kaiming 初始化函数 =====
def init_weights(m):if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')nn.init.constant_(m.bias, 0)if __name__ == '__main__':device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = VGG16().to(device)# 应用 Kaiming 初始化model.apply(init_weights)# 打印模型结构summary(model, (1, 224, 224))
运行结果:Total params: 134,300,362

VGG网络参数(适用于CIFAR10的 VGG16)

前面的网络参数太大,一般适用于ImageNet数据集,而CIFAR10数据集只有32*32*3,因此可以重新设计模型参数,并且可以加入BatchNorm提高收敛速度和稳定性。

http://www.lryc.cn/news/612280.html

相关文章:

  • 利用C++11和泛型编程改进原型模式
  • 学习 Android(十五)NDK进阶及性能优化
  • 功能安全和网络安全的综合保障流程
  • 分布式事务Seata、LCN的原理深度剖析
  • vue中reactive()和ref()的用法
  • selenium操作指南
  • 状态模式及优化
  • 【机器学习篇】02day.python机器学习篇Scikit-learn基础操作
  • 如何解决pip安装报错ModuleNotFoundError: No module named ‘gensim’问题
  • Session 和 JWT(JSON Web Token)
  • python:非常流行和重要的Python机器学习库scikit-learn 介绍
  • 毕业设计选题推荐之基于Spark的在线教育投融数据可视化分析系统 |爬虫|大数据|大屏|预测|深度学习|数据分析|数据挖掘
  • Packets Frames 数据包和帧
  • 大数据存储域——Hive数据仓库工具
  • 数据结构---二级指针(应用场景)、内核链表、栈(系统栈、实现方式)、队列(实现方式、应用)
  • STM32学习记录--Day8
  • 键帽(dp)
  • 【数字图像处理系列笔记】Ch03:图像的变换
  • Redis Redis 常见数据类型
  • 高等数学(工本)----00023 速记宝典
  • JAVA高级编程第八章
  • windows系统创建ubuntu系统
  • Python与自动化运维:构建智能IT基础设施的终极方案
  • 第七章课后综合练习
  • 学习日志29 python
  • 达梦数据库数据守护集群启动与关闭标准流程
  • 对接钉钉审批过程记录(C#版本)
  • 什么是逻辑外键?我们要怎么实现逻辑外键?
  • IDEA 2025下载安装教程【超详细】保姆级图文教程(附安装包)
  • 2 SpringBoot项目对接单点登录说明