当前位置: 首页 > news >正文

七、CV_模型微调

七、模型微调

1.微调

  1. 在源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型(预训练模型)。——修改预训练模型使他适合你的任务,最重要的是修改输出层
  2. 创建一个新的神经网络模块,即目标模型。它复制了源模型上除了输出层外的所有模型设计及其参数。我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。我们还假设源模型的输出层跟源数据集的标签紧密相关,因此在目标模型中不予采用。
  3. 为目标模型添加一个输出大小为目标数据集类别个数的输出层,并随机初始化该层的模型参数。
  4. 在目标数据集上训练目标模型。我们将从头训练输出层,而其余层的参数都是基于源模型的参数微调得到的

  • 当目标数据集远小于源数据集时,微调有助于提升模型的泛化能力
  • 根据当前任务数据集的大小来确定微调的网络层
    • 在数据集较小时,隐藏层的参数可以不进行微调
    • 在数据集较大时,可以将隐藏层划开,里面的参数也可以进行变化

2.热狗识别案例

将基于一个小数据集对在ImageNet数据集上训练好的ResNet模型进行微调。该小数据集含有数千张热狗或者其他事物的图像。我们将使用微调得到的模型来识别一张图像中是否包含热狗

import tensorflow as tf
import numpy as np

(1)获取数据集

  • batch_size在读取数据和模型训练的时候均可以进行设置

通过以下方法读取图像文件,该方法以文件夹路径为参数,生成经过图像增强后的结果,并产生batch数据:

flow_from_directory(self,directory, # 目标文件夹路径,对于每一个类对应一个子文件夹,# 该文件夹中任何JPG,PNG,BNP,PPM的图片都可以读取 target_size = (256, 256), # 默认为(256,256),图像将被resize成该尺寸color_mode = 'rgb',classes = None,class_mode = 'categorical',batch_size = 32, # 默认为32shuffle = True, # 是否打乱数据,默认为Trueseed = None,save_to_dir = None)

创建两个tf.keras.preprocessing.image.ImageDataGenerator示例来分别读取训练数据集和测试数据集中的所有图像文件。将训练集图片全部处理为高宽均为224像素的输入。此外,我们对RGB三个颜色通道的数值做标准化。

注意:

class_modelabel 的形状含义
"binary"(32,)二分类(每个样本是 0 或 1)
"categorical"(32, 2)独热编码([1, 0] 或 [0, 1])
"sparse"(32,)多类整数标签(类似 binary)
None无标签仅返回图像,无监督学习时使用

(2)模型构建与训练

实例化预训练数据集(tf.keras.appilcation)------>模型调整(调整输出层,并设置层是否可训练)

  • 我们使用在ImageNet数据集上的预训练模型的ResNet-50作为源模型。这里指定weights = 'imagenet’来自动下载并加载预训练的模型参数。
  • Keras应用程序(keras.applications)是具有预先训练权值的固定框架,该类封装了很多重量级的网络架构

实现时实例化模型架构:

  • 利用tf.keras中的application实现迁移学习
tf.keras.application.ResNet50(include_top = True,  # 是否包含顶层的全连接层(默认为True)weights = 'imagenet', # None代表随机初始化,'imagenet'代表加载在ImageNet上预训练的权重input_tensor = None, # 如果你已经用 tf.keras.Input() 创建了输入层,这里可以传入它;# 一般用于自定义模型结构input_shape = None, # 可选,输入尺寸元组,仅当include_top = False时有效,否则输入形状必须是(224,224,3)(channels_last格式)# 或(3,224,224)(channels_first格式)。它必须为3个输入通道,且高宽必须不小于32pooling = None, # 当 include_top=False 时,是否添加全局池化classes = 1000,**kwargs
)
  • include_top
    • include_top = True, 模型会包含原始 ResNet50 在 ImageNet 上训练的最后三层全连接分类头(avg_poolfc1000 → softmax 输出 1000 类)
    • include_top = False, 就不会包含这些顶层结构,适合迁移学习时接上你自己的分类层。
  • pooling
    • 如果为 None:输出为卷积特征图(feature map),形状类似 (batch, 7, 7, 2048)
    • 'avg':加一层 GlobalAveragePooling2D,输出为 (batch, 2048)
    • 'max':加一层 GlobalMaxPooling2D,输出为 (batch, 2048)
  • classes(输出类别数量)
    • 只有当 **include_top=True** 时有效
    • 用于设置最终全连接层的输出维度。

在该案例中使用resNet50预训练模型架构模型:

# 加载预训练模型
ResNet50 = tf.keras.applications.ResNet50(weights = 'imagenet', input_shape = (224, 224, 3))
# 设置所有层不可训练
for layer in ResNet50.layers:layer.trainable = False# 设置模型
net = tf.keras.models.Squential()
# 预训练模型
net.add(ResNet50)
# 展开
net.add(tf.keras.layers.Flatten())
# 二分类的全连接层
net.add(tf.keras.layers.Dense(2, activation = 'softmax'))

接下来使用之前定义好的ImageGenerator将训练集图片送入ResNet50进行训练

# 模型编译:指定优化器,损失函数,评价指标
net.compile(optimizer = 'adam',loss = 'categorical_crossentropy',metrics = ['accuracy']
)# 模型训练:指定数据,每一个epoch中只运行10个迭代,指定验证数据集
history = net.fit(train_data_gen = True,steps_per_epoch = 10,epochs = 3,validation_data = test_data_gen,  # 验证集validation_step = 10
)
http://www.lryc.cn/news/616350.html

相关文章:

  • 使用快捷键将当前屏幕内容滚动到边缘@首行首列@定位到第一行第一个字符@跳转到4个角落
  • Knuth‘s TwoSum Algorithm 原理详解
  • 每日任务day0810:小小勇者成长记之武器精炼
  • 机器学习 DBScan
  • VUE+SPRINGBOOT从0-1打造前后端-前后台系统-关于我们
  • 人大地平线新国立单目具身导航新范式!MonoDream:基于全景想象的单目视觉语言导航
  • 周学会Matplotlib3 Python 数据可视化-绘制折线图(Lines)
  • python中re模块详细教程
  • 论文阅读:Aircraft Trajectory Prediction Based on Residual Recurrent Neural Networks
  • SupChains团队:化学品制造商 ChampionX 供应链需求预测案例分享(十七)
  • Speaking T2 - Dining Hall to CloseDuring Spring Break
  • 2025华数杯比赛还未完全结束!数模论文可以发表期刊会议
  • Redis一站式指南二:主从模式高效解决分布式系统“单点问题”
  • 安全引导功能及ATF的启动过程(五)
  • 【GPT入门】第44课 检查 LlamaFactory微调Llama3的效果
  • ThreadLocal有哪些内存泄露问题,如何避免?
  • 商业解决方案技术栈总结
  • 洛谷 P2404 自然数的拆分问题-普及-
  • LeetCode - 搜索插入位置 / 排序链表
  • 音视频学习(五十一):AAC编码器
  • 力扣(买卖股票的最佳时机I/II)
  • 面对信号在时频平面打结,VNCMD分割算法深度解密
  • windows的cmd命令【持续更新】
  • 数据库面试题集
  • ADB简介
  • 全面了解机器语言之kmeans
  • UE5多人MOBA+GAS 41、制作一个飞弹,添加准心索敌
  • 【走进Docker的世界】Docker环境搭建
  • Java集合框架、Collection体系的单列集合
  • OpenStack热迁移一直处于迁移中怎么办