当前位置: 首页 > news >正文

经典分类模型回顾16-AlexNet实现垃圾分类(Tensorflow2.0版)

AlexNet是2012年由亚历克斯·克里斯托夫(Alex Krizhevsky)等人提出的一种卷积神经网络结构,它在ImageNet图像识别比赛中获得了第一名,标志着卷积神经网络的崛起。

AlexNet的结构包括8层网络,其中前5层为卷积层,后3层为全连接层。AlexNet的主要特点是使用了ReLU作为激活函数,采用dropout技术避免过拟合问题,使用了最大池化层来降低训练参数数量,采用了LRN(局部响应归一化)来增强泛化能力,还采用了数据增强的方法来扩充训练集。

具体来说,AlexNet的结构如下:

|层|类型|卷积核大小|步长|输出大小|
|---|---|---|---|---|
|1|卷积层|11x11x3|4|55x55x96|
|2|最大池化层|3x3|2|27x27x96|
|3|卷积层|5x5x96|1|27x27x256|
|4|最大池化层|3x3|2|13x13x256|
|5|卷积层|3x3x256|1|13x13x384|
|6|卷积层|3x3x384|1|13x13x384|
|7|卷积层|3x3x384|1|13x13x256|
|8|最大池化层|3x3|2|6x6x256|
|9|全连接层|4096|1|1x1x4096|
|10|全连接层|4096|1|1x1x4096|
|11|全连接层|1000|1|1x1x1000|

其中,第1层卷积层的输入为224x224x3的图片,卷积核大小为11x11,步长为4,共96个卷积核,得到的输出大小为55x55x96。第2层为最大池化层,池化核大小为3x3,步长为2,输出大小为27x27x96。第3层为卷积层,卷积核大小为5x5x96,步长为1,共256个卷积核,得到的输出大小为27x27x256。第4层为最大池化层,池化核大小为3x3,步长为2,输出大小为13x13x256。第5、6、7层为卷积层,卷积核大小分别为3x3x256、3x3x384、3x3x384,共384个卷积核,256个卷积核,256个卷积核,得到的输出大小分别为13x13x384、13x13x384、13x13x256。第8层为最大池化层,池化核大小为3x3,步长为2,输出大小为6x6x256。最后3层为全连接层,第9、10层的输出大小都为1x1x4096,第11层的输出大小为1x1x1000,预测图片的类别。

AlexNet的创新之处在于使用ReLU激活函数来替代传统的sigmoid激活函数,ReLU的计算速度更快,同时解决了梯度消失的问题,使网络的训练更加稳定和有效。此外,AlexNet也是第一个使用dropout来避免过拟合问题的神经网络。数据增强的方法也使训练集得到了扩充,提高了网络的鲁棒性。AlexNet的成功奠定了深度学习在计算机视觉领域的地位,为后续的神经网络研究提供了启示。

import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import os# 设置数据目录
train_dir = './garbage_classification/train'
valid_dir = './garbage_classification/test'# 数据预处理
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255,      # 缩放像素值到0-1之间rotation_range=40,   # 随机旋转width_shift_range=0.2,  # 随机水平平移height_shift_range=0.2, # 随机竖直平移shear_range=0.2,     # 随机剪切zoom_range=0.2,      # 随机缩放horizontal_flip=True,   # 水平翻转fill_mode='nearest')  # 填充方式valid_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)# 加载训练集和验证集
train_generator = train_datagen.flow_from_directory(train_dir,target_size=(227, 227),      # AlexNet输入大小batch_size=32,class_mode='categorical')valid_generator = valid_datagen.flow_from_directory(valid_dir,target_size=(227, 227),batch_size=32,class_mode='categorical')# 构建AlexNet模型
model = models.Sequential()
model.add(layers.Conv2D(96, (11, 11), strides=(4, 4), activation='relu', input_shape=(227, 227, 3)))
model.add(layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))
model.add(layers.Conv2D(256, (5, 5), strides=(1, 1), activation='relu', padding='same'))
model.add(layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))
model.add(layers.Conv2D(384, (3, 3), strides=(1, 1), activation='relu', padding='same'))
model.add(layers.Conv2D(384, (3, 3), strides=(1, 1), activation='relu', padding='same'))
model.add(layers.Conv2D(256, (3, 3), strides=(1, 1), activation='relu', padding='same'))
model.add(layers.MaxPooling2D(pool_size=(3, 3), strides=(2, 2)))
model.add(layers.Flatten())
model.add(layers.Dense(4096, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(4096, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(6, activation='softmax'))  # 6种垃圾分类# 打印模型摘要
model.summary()# 配置模型训练过程
model.compile(loss='categorical_crossentropy',optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),metrics=['acc'])# 训练模型
history = model.fit(train_generator,steps_per_epoch=100,epochs=30,validation_data=valid_generator,validation_steps=50)# 保存模型和权重
model.save('garbage_classification.h5')
model.save_weights('garbage_classification_weights.h5')

http://www.lryc.cn/news/38477.html

相关文章:

  • vue3使用vuex
  • Java面向对象:抽象类的学习
  • modbus转profinet网关连接5台台达ME300变频器案例
  • 多校园SaaS运营智慧校园云平台源码 智慧校园移动小程序源码
  • 用DQN实现Atari game(Matlab代码实现)
  • 【JavaSE专栏11】Java的 if 条件语句
  • 【opensea】opensea-js 升级 Seaport v1.4 导致的问题及解决笔记
  • JS语法(扫盲)
  • 归并排序的学习过程(代码实现)
  • add_header重写的坑
  • 跑步耳机入耳好还是不入耳好,最适合运动的蓝牙耳机
  • 深度学习知识点简单概述【更新中】
  • 【编程基础】009.输入两个正整数m和n,求其最大公约数和最小公倍数。
  • Golang错误处理
  • English Learning - L2 语音作业打卡 复习对比 [ɑ:] [æ] Day18 2023.3.10 周五
  • LabVIEW中以编程方式获取VI克隆名称
  • Mysql count(*)的使用原理以及InnoDb的优化策略
  • 一文入门HTML+CSS+JS(样例后续更新)
  • 【STL】Vector剖析及模拟实现
  • 数据库建表的一些技巧
  • 线程(一)
  • [深入理解SSD系列 闪存实战2.1.8] NAND FLASH Multi Plane Program(写)操作_multi plane 为何能提高闪存速度
  • 计算机网络(第八版)——第一章知识总结
  • Linux学习笔记
  • 树与二叉树(概念篇)
  • C++回顾(二十五)—— map/multimap容器
  • 7.3 向量的数量积与向量积
  • Qt静态扫描(命令行操作)
  • 【Hadoop】配置文件
  • python进程池