当前位置: 首页 > news >正文

05 CNN 猴子类别检测

一、数据集下载

kaggle数据集[10 monkey]

二、数据集准备

2.1 指定路径

from tensorflow import keras
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plttrain_dir = '/newdisk/darren_pty/CNN/ten_monkey/training/'
valid_dir = '/newdisk/darren_pty/CNN/ten_monkey/validation/'
label_file = '/newdisk/darren_pty/CNN/ten_monkey/monkey_labels.txt'labels = pd.read_csv(label_file, header=0)
print(labels)

 

2.2 数据增强

# 图片数据生成器  数据增加
train_datagen = keras.preprocessing.image.ImageDataGenerator(rescale = 1. / 255,  #jpg 0-255转变为 0-1rotation_range = 40,  #图片翻转width_shift_range = 0.2,  # 移动height_shift_range = 0.2, # 移动shear_range = 0.2, #裁剪zoom_range = 0.2, #缩放比例horizontal_flip = True,  #翻转vertical_flip = True,fill_mode = 'nearest' #填充模式
)

三、从数据集中生成数据

height = 128
width = 128
channels = 3
batch_size = 32
num_classes = 10train_generator = train_datagen.flow_from_directory(train_dir,target_size = (height, width),batch_size = batch_size,shuffle = True,seed = 7,class_mode = 'categorical')valid_datagen = keras.preprocessing.image.ImageDataGenerator(rescale = 1. / 255
)
valid_generator = valid_datagen.flow_from_directory(valid_dir,target_size = (height, width),batch_size = batch_size,shuffle = True,seed = 7,class_mode = 'categorical')
print(train_generator.samples)
print(valid_generator.samples)

Found 1098 images belonging to 10 classes.
Found 272 images belonging to 10 classes.
1098
272

四、模型

train_num = train_generator.samples
valid_num = valid_generator.samplesx, y = train_generator.next()
print(x.shape, y.shape)
print(y)model = keras.models.Sequential()
# 卷积
model.add(keras.layers.Conv2D(filters = 32,kernel_size = 3,padding = 'same',activation='relu',# batch_size, height, width, channelsinput_shape=(128, 128, 3)))model.add(keras.layers.Conv2D(filters = 32,kernel_size = 3,padding = 'same',activation='relu'))
# 池化
model.add(keras.layers.MaxPool2D()) #model.add(keras.layers.Conv2D(filters = 64,kernel_size = 3,padding = 'same',activation='relu'))
model.add(keras.layers.Conv2D(filters = 64,kernel_size = 3,padding = 'same',activation='relu'))
# 池化
model.add(keras.layers.MaxPool2D())
model.add(keras.layers.Conv2D(filters = 128,kernel_size = 3,padding = 'same',activation='relu'))
model.add(keras.layers.Conv2D(filters = 128,kernel_size = 3,padding = 'same',activation='relu'))
# 池化, 向下取整
model.add(keras.layers.MaxPooling2D())model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(512, activation='relu'))
model.add(keras.layers.Dense(256, activation='relu'))
model.add(keras.layers.Dense(10, activation='softmax'))model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])print(model.summary())

五、训练

history = model.fit(train_generator,steps_per_epoch = train_num // batch_size,epochs = 10,validation_data = valid_generator,validation_steps = valid_num // batch_size)

http://www.lryc.cn/news/162288.html

相关文章:

  • 【C#】关于Array.Copy 和 GC
  • Vue前端框架08 Vue框架简介、VueAPI风格、模板语法、事件处理、数组变化侦测
  • WebStorm使用PlantUML
  • Python做批处理,给安卓设备安装应用和传输图片
  • 如何获取springboot中所有的bean
  • 大数据技术之Hadoop:HDFS存储原理篇(五)
  • 用C语言实现牛顿摆控制台动画
  • 如何自己开发一个前端监控SDK
  • node.js笔记
  • mysql 增量备份与恢复使用详解
  • 9月5日上课内容 第一章 NoSQL之Redis配置与优化
  • QT 第四天
  • nrf52832 GPIO输入输出设置
  • MyBatis 动态 SQL 实践教程
  • CSS 斜条纹进度条
  • JavaScript(1)每天10个小知识点
  • scanf和scanf_s函数详解
  • 基于SSM的在线购物系统
  • 认识JVM的内存模型
  • Java8实战-总结19
  • 论文浅尝 | 训练语言模型遵循人类反馈的指令
  • 【云计算网络安全】解析DDoS攻击:工作原理、识别和防御策略 | 文末送书
  • 64位Linux系统上安装64位Oracle10gR2及Oracle11g所需的依赖包
  • Unity InputSystem 基础使用之鼠标交互
  • 《算法竞赛·快冲300题》每日一题:“二进制数独”
  • CnosDB 签约京清能源,助力分布式光伏发电解决监测系统难题。
  • 汇编:lea 需要注意的一点
  • SQL语言的分类:DDL(数据库、表的增、删、改)、DML(数据的增、删、改)
  • 微信小程序精准扶贫数据收集小程序平台设计与实现
  • PostgreSQL 流复制搭建