当前位置: 首页 > news >正文

【YOLOv5进阶】——引入注意力机制-以SE为例

声明:笔记是做项目时根据B站博主视频学习时自己编写,请勿随意转载!

一、站在巨人的肩膀上

SE模块即Squeeze-and-Excitation 模块,这是一种常用于卷积神经网络中的注意力机制!!

借鉴代码的代码链接如下:

注意力机制-SEicon-default.png?t=N7T8https://github.com/ZhugeKongan/Attention-mechanism-implementation

需要model里面的SE_block.py文件

# -*- coding: UTF-8 -*-
"""
SE structure"""import torch.nn as nn  # 导入PyTorch的神经网络模块  
import torch.nn.functional as F  # 导入PyTorch的神经网络功能函数模块  class SE(nn.Module):  # 定义一个名为SE的类,该类继承自PyTorch的nn.Module,表示一个神经网络模块  def __init__(self, in_chnls, ratio):  # 初始化函数,in_chnls表示输入通道数,ratio表示压缩比率  super(SE, self).__init__()  # 调用父类nn.Module的初始化函数  # 使用AdaptiveAvgPool2d将输入的空间维度压缩为1x1,即全局平均池化  self.squeeze = nn.AdaptiveAvgPool2d((1, 1))  # 使用1x1卷积将通道数压缩为原来的1/ratio,实现特征压缩  self.compress = nn.Conv2d(in_chnls, in_chnls // ratio, 1, 1, 0)  # 使用1x1卷积将通道数扩展回原来的in_chnls,实现特征激励  self.excitation = nn.Conv2d(in_chnls // ratio, in_chnls, 1, 1, 0)  def forward(self, x):  # 定义前向传播函数  out = self.squeeze(x)  # 对输入x进行全局平均池化  out = self.compress(out)  # 对池化后的输出进行特征压缩  out = F.relu(out)  # 对压缩后的特征进行ReLU激活  out = self.excitation(out)  # 对激活后的特征进行特征激励  # 对激励后的特征应用sigmoid函数,然后与原始输入x进行逐元素相乘,实现特征重标定  return x*F.sigmoid(out)

代码后面有附注的注释(GPT解释的,很好用),理解即可。对于使用者来说,重要关注点还是它的输入通道、输出通道、需要传入的参数等!!这个函数整体传入in_chnls, ratio两个参数。


二、开始修改网络结构

与上节的C2f修改基本流程一致,但稍有不同

  • model/common.py加入新增的SE网络结构,直接复制粘贴如下,这里加在了上节的C2f之前:

上面说到这个函数整体传入in_chnls, ratio两个参数!!


  • model/yolo.py设定网络结构的传参细节

上期的C2f模块之所以可以参照原本存在的C3模块属性,是因为两者相似,但这里的SE模块就不可简单的在C3x后加SE,而是需要在下面加入一段elif代码:

         elif m is SE:c1 = ch[f]c2 = args[0]if c2 != no:  # if not outputc2 = make_divisible(c2 * gw, 8)args = [c1, args[1]]

当新引入的模块中存在输入输出维度时,需要使用gw调整输出维度!!


  • model/yolov5s.yaml设定现有模型结构配置文件

老样子,复制一份新的配置文件命名为yolov5s-se.yaml。首先需要在backbone的最后加上SE模块(相当于多了一层为第10层);其次考虑到backbone里多了一层,且在head里的输入层来源不止上一层(-1)一个,所以输入层来源大于等于第10层的都需要改为往后递推+1层。下图左边为原始的yaml配置文件,右侧为修改后的:

当yaml文件引入新的层后,需要修改模型结构的from参数(上期是将C3替换为C2f模块,所以不涉及这一点)!!


  • train.py训练时指定模型结构配置文件

这次将parse_model函数里的第二个参数cfg改为yolov5s-se.yaml即可,运行train.py开始训练!!

可见训练时第10层已经引入了SE注意力机制模块:

100次迭代后结果如下,结果保存在runs\train\exp12文件夹,文件夹里有很多指标曲线可对比分析:


 往期精彩

STM32专栏(9.9)icon-default.png?t=N7T8http://t.csdnimg.cn/A3BJ2

OpenCV-Python专栏(9.9)icon-default.png?t=N7T8http://t.csdnimg.cn/jFJWe

AI底层逻辑专栏(9.9)icon-default.png?t=N7T8http://t.csdnimg.cn/6BVhM

机器学习专栏(免费)icon-default.png?t=N7T8http://t.csdnimg.cn/ALlLlSimulink专栏(免费)icon-default.png?t=N7T8http://t.csdnimg.cn/csDO4电机控制专栏(免费)icon-default.png?t=N7T8http://t.csdnimg.cn/FNWM7

http://www.lryc.cn/news/389798.html

相关文章:

  • 【C++题解】1456. 淘淘捡西瓜
  • 用Python读取Word文件并提取标题
  • Windows编程上
  • BiTCN-Attention一键实现回归预测+8张图+特征可视化图!注意力全家桶再更新!
  • zoom缩放问题(关于ElementPlus、Echarts、Vue3draggable等组件偏移问题)
  • 【后端面试题】【中间件】【NoSQL】MongoDB的配置服务器、复制机制、写入语义和面试准备
  • 视频监控汇聚平台LntonCVS视频监控业务平台具体有哪些功能?
  • 我不小心把生产的数据改错了!同事帮我用MySQL的BinLog挽回了罚款
  • Windows系统安装NVM,实现Node.js多版本管理
  • k8s部署单节点redis
  • 云微客矩阵系统:如何利用智能策略引领营销新时代?
  • 嵌入式Linux系统编程 — 6.3 kill、raise、alarm、pause函数向进程发送信号
  • Swoole实践:如何使用协程构建高性能爬虫
  • 基于人脸68特征点识别的美颜算法(一) 大眼算法 C++
  • 算法金 | 欧氏距离算法、余弦相似度、汉明、曼哈顿、切比雪夫、闵可夫斯基、雅卡尔指数、半正矢、Sørensen-Dice
  • 项目实战--Spring Boot大数据量报表Excel优化
  • C#编程技术指南:从入门到精通的全面教程
  • Redis+定式任务实现简易版消息队列
  • 学习在 C# 中使用 Lambda 运算符
  • 数据结构和算法,单链表的实现(kotlin版)
  • Jdk17是否有可能代替 Jdk8
  • oca和 ocp有什么区别
  • 煤矿安全大模型:微调internlm2模型实现针对煤矿事故和煤矿安全知识的智能问答
  • C++中的C++中的虚析构函数的作用和重要性
  • 机器学习 - 文本特征处理之 TF 和 IDF
  • 因为自己淋过雨所以想给嵌入式撑把伞
  • 《C++20设计模式》中单例模式
  • 前端技术(说明篇)
  • 带电池监控功能的恒流直流负载组
  • 关于Disruptor监听策略