当前位置: 首页 > news >正文

人工智能学习51-ResNet训练

人工智能学习概述—快手视频
人工智能学习51-ResNet训练—快手视频
人工智能学习52-ResNet训练—快手视频
人工智能学习53-ResNet训练—快手视频
人工智能学习54-ResNet训练—快手视频
人工智能学习55-ResNet预测 —快手视频
人工智能学习56-ResNet预测 —快手视频
在这里插入图片描述
在这里插入图片描述

ResNet 训练10类猕猴识别示例

#导入ResNet50类库 
from keras.applications.resnet import ResNet50,preprocess_input 
#从keras导入顺序模块Sequential 
from keras import Sequential 
#keras.preprocessing.image 导入图像增强工具 
from keras.preprocessing.image import ImageDataGenerator 
#引入numpy类库,方便矩阵操作 
import numpy as np 
#导入图形处理类库 
import matplotlib.pyplot as plt 
#导入keras.layers 模块 
import keras.layers 
#导入OS模块,方便操作文件与目录 
import os 
#避免多库依赖警告信息 
os.environ['KMP_DUPLICATE_LIB_OK']='True' 
#设置神经网络模型存储目录,当前python源文件所在目录上一级下的saved_models目录 
save_dir = os.path.join(os.getcwd(), '../saved_models') 
#如果目录saved_models不存在,新建此目录 
if not os.path.isdir(save_dir): 
os.makedirs(save_dir) 
#神经网络模块名称 
model_name = 'finetune_res50_trained_model.h5' 
#输入图像高度(单位:像素) 
height = 224 
#输入图像宽度(单位:像素) 
width = 224 
# ResNet50 类库使用已经训练好的模型进行迁移学习,增加一种新的动物,在原模型基础
上训练新模型可以识别新增加的动物,原模型可以识别10种猕猴 
num_classes = 11 
#定义Keras顺序模型Sequential 
res50Model = Sequential() 
#构建新模型,添加第一层为ResNet50层,ResNet50合并为第一层 
res50Model.add(ResNet50( 
include_top=False, 
pooling='avg', 
weights='imagenet' 
)) 
#添加一个全连接层,使其可以识别11类动物,使用激活函数softmax 
#预测输出 
res50Model.add(keras.layers.Dense(num_classes, activation='softmax')) 
#设置网络第一层不参与训练(也就是ResNet50,其已经训练完成) 
res50Model.layers[0].trainable = False 
#编译网络模型,优化器采用梯度下降法,损失函数采用交叉熵 
#统计信息设置为准确度 
res50Model.compile( 
optimizer='sgd', 
loss='categorical_crossentropy', 
metrics=['acc'] 
) 
#模型结构汇总输出 
res50Model.summary() 
#定义训练数据集增强类 
train_datagen = ImageDataGenerator( 
preprocessing_function=preprocess_input, #使用 ResNet50 定义输入函数 
rotation_range=40, # 随机旋转的度数范围,表示图像将随机旋转0到40度 
width_shift_range=0.2, #表示图像在水平上随机移动的范围 
height_shift_range=0.2, # 表示图像在垂直方向上随机移动的范围 
shear_range=0.2, # 随机剪切变换的角度范围,图像将随机剪切0到20度的角度 
zoom_range=0.2, # 随机缩放的范围,图像将随机缩放90%到110% 
horizontal_flip=True, # 是否进行水平翻转 
fill_mode='nearest' #当变换导致某些像素需要被填充时使用的填充方法,'nearest'
表示使用最近的像素进行填充 
) 
#定义验证数据集增强类 
valid_datagen = ImageDataGenerator( 
preprocessing_function=preprocess_input 
) 
#小批量训练模式下每次训练样本数量 
batch_size = 32 
#训练集数据增强生成器 
train_generator = train_datagen.flow_from_directory( 
"../monkey10_species/training/training", #训练图片所在目录 
target_size=(height, width), #图片尺寸大小 
batch_size=batch_size, #每次训练样品批量 
seed=10, # 指定随机数种子,用于洗牌操作的随机性 
shuffle=True, #是否对样品数据洗牌 
class_mode='categorical' #指定标签的类型,可以是"categorical"(多分类问题)、
"binary"(二分类问题)、"sparse"(稀疏标签问题)或"None"(无标签问题) 
) 
#训练集样品总量 
train_num = train_generator.samples 
valid_generator = valid_datagen.flow_from_directory( 
"../monkey10_species/validation/validation", #训练图片所在目录 
target_size=(height, width), #图片尺寸大小 
batch_size=batch_size, #每次训练样品批量 
seed=10, # 指定随机数种子,用于洗牌操作的随机性 
shuffle=False, #是否对样品数据洗牌 
class_mode='categorical' #指定标签的类型,可以是"categorical"(多分类问题)、
"binary"(二分类问题)、"sparse"(稀疏标签问题)或"None"(无标签问题) 
) 
#验证集样品总量 
valid_num = valid_generator.samples 
#开始训练模型,匹配训练集与标注真实数值映射关系 
history = res50Model.fit_generator( 
train_generator, #训练集生成器 
steps_per_epoch=train_num // batch_size, # 定义了一个 epoch 中应抽取的步数(批
次数量) 
epochs=5, #训练次数 
validation_data=valid_generator, #测试集生成器 
validation_steps=valid_num // batch_size 
) 
# 保存模型 
model_path = os.path.join(save_dir, model_name); 
vgg16Model.save(model_path) 
#从history 对象中获取准确度核损失统计信息 
acc = history.history['acc'] #训练集准确度 
val_acc = history.history['val_acc'] #验证集准确度 
loss = history.history['loss'] #训练集损失 
val_loss = history.history['val_loss'] #验证集损失 
Epochs = range(1, len(acc) + 1) 
#训练集准确度曲线 
plt.plot(Epochs, acc, 'bo', label='Train Accuracy') 
#验证集准确度曲线 
plt.plot(Epochs, val_acc, 'b', label='Validation Accuracy') 
plt.title('Train and Validation Accuracy') 
plt.legend() 
#显示图形窗口 
plt.show() 
#训练集损失曲线 
plt.plot(Epochs, loss, 'ro', label='Train Loss') 
#验证集损失曲线 
plt.plot(Epochs, val_loss, 'r', label='Validation Loss') 
plt.title('Train and Validation Loss') 
plt.legend() 
#显示图形窗口 
plt.show() 

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

ResNet 预测10类猕猴示例

#导入ResNet50类库 
from keras.applications.resnet import preprocess_input 
#导入装载模型方法load_model 
from keras.saving.save import load_model 
#导入图像处理库 
from keras.preprocessing import image 
#引入numpy类库,方便矩阵操作 
import numpy as np 
#引入json类库 
import json 
#引入sys类库 
import sys 
#导入OS模块,方便操作文件与目录 
import os 
#避免多库依赖警告信息 
os.environ['KMP_DUPLICATE_LIB_OK']='True' 
#设置神经网络模型存储目录,当前python源文件所在目录上一级下的saved_models目录 
save_dir = os.path.join(os.getcwd(), '../saved_models') 
#如果目录saved_models不存在,新建此目录 
if not os.path.isdir(save_dir): 
os.makedirs(save_dir) 
#神经网络模块名称 
model_name = 'finetune_res50_trained_model.h5' 
#神经网络模块所在目录 
model_path = os.path.join(save_dir, model_name) 
#装载神经网络模型 
model = load_model(model_path) 
#定义输入图片变量 
img_path = None 
#命令行输入参数数组 
arguments = sys.argv[1:2] 
if len(arguments)==0: 
img_path = '../hourse.png' 
else: 
img_path = arguments[0] #第一个参数为图片文件 
#由图片文件名称转载图片数据 
img = image.image_utils.load_img(img_path, target_size=(224, 224)) 
#图片数据转化为数组 
x = image.image_utils.img_to_array(img) 
#扩展图片数组维度,第一维扩展维图片样本数量 
x = np.expand_dims(x, axis=0) 
#由ResNet50 提供数据载入函数装载图片数据 
x = preprocess_input(x) 
#模型预测输入图片的动物分类 
pred = model.predict(x) 
list = pred[0] 
pos = 0 
#显示所有预测分类概率 
for i in list: 
print(pos, '=', i) 
pos = pos + 1 
#显示预测概率最大的分类 
print('argmax()=', pred.argmax()) 

在这里插入图片描述

http://www.lryc.cn/news/573526.html

相关文章:

  • Spring AOP全面详讲
  • Python 爬虫案例(不定期更新)
  • 一,python语法教程.内置API
  • 【知识图谱提取】【阶段总结】【LLM4KGC】LLM4KGC项目提取知识图谱推理部分
  • Linux 内核中 TCP 协议栈的输出实现:tcp_output.c 文件解析
  • 【JAVA】数组的使用
  • 电子电气架构 --- 实时系统评价的概述
  • 基于YOLO的智能车辆检测与记录系统
  • Transformer架构每层详解【代码实现】
  • LangGraph--基础学习(工具调用)
  • 2025zbrush雕刻笔记
  • NW849NX721美光固态闪存NX745NX751
  • 微处理器原理与应用篇---计算机系统的结构、组织与实现
  • 给交叉工具链增加libelf.so
  • 操作系统内核态和用户态--2-系统调用是什么?
  • 嵌入式开发之嵌入式系统架构如何搭建?
  • 【软考高级系统架构论文】论面向服务架构设计及其应用
  • modelscope设置默认模型路径
  • python的校园兼职系统
  • Taro 跨端开发:从调试到发布的完整指南
  • 基于正点原子阿波罗F429开发板的LWIP应用(7)——MQTT
  • 华为OD机试-云短信平台优惠活动-完全背包(JAVA 2024E卷)
  • TodoList 案例(Vue3): 使用Composition API
  • 嵌入式开发之嵌入式系统硬件架构设计时,如何选择合适的微处理器/微控制器?
  • 腾讯云IM即时通讯:开启实时通信新时代
  • 一文详解归并分治算法
  • Python:.py文件如何变成双击可执行的windows程序?(版本1)
  • 深入Java面试:从Spring Boot到微服务
  • Django数据库迁移
  • P1220 关路灯