当前位置: 首页 > news >正文

【目标检测实验系列】YOLOv5高效涨点:基于NAMAttention规范化注意力模块,调整权重因子关注有效特征(文内附源码)

1. 文章主要内容

       本篇博客主要涉及规范化注意力机制,融合到YOLOv5(v6.1版本,去掉了Focus模块)模型中,通过惩罚机制,调整特征权重因子,使模型更加关注有效特征,助力模型涨点。

2. 简要概括

       论文地址:NAM论文地址
       论文Github代码:Github代码

       NAM注意力机制在2021年的时候就挂在arxiv上,博主最近逛了一逛发现其github代码的关键模块中,还是缺乏了论文当中的空间注意力模块,只提供了通道注意力模块,所以这篇论文的NAM在代码层面上只利用了通道注意特征,如下图所示。
在这里插入图片描述
       亮点在于:NAM的核心思想在于通过调整,利用稀疏的权重惩罚来降低不太显著的特征(换句话说:对显著有效特征更加关注)的权重,使得整体注意力权重在计算上保持同样性能的情况下变得更加高效,助力模型高效涨点,有兴趣的可以阅读原论文!

       分析:NAM也是一个即插即用的注意力模块,可以融合到YOLOv5网络结构中的任何地方,前提是通道等维度对齐。另外,因为论文代码只提高了通道注意力且一般情况下,高维度的通道特征比较丰富,换句话说网络深度越深,通道数越高,其高层次的语义特征也就会越丰富,所以建议将NAM放在网络更深层次,有助于提取丰富的高层次特征,助力模型涨点!下面给出NAM原论文中的一个结构图,注意只针对于通道注意力!
在这里插入图片描述

3. 详细代码改进流程

       接下来记录一下将NAM添加到YOLOv5模型中某一个地方的实验过程。注意到(在后面的yolov5-NAM.yaml中体现):本文是将NAM添加在检测大目标的检测头的前面,也就是23层 (P5/32-large)的后面,添加了一层,后面的Detect序号也得增加一,变成[[17, 20, 24], 1, Detect, [nc, anchors]]!

3.1新建一个NAM的py文件,放置源代码

       首先新建一个NAM.py存放其源代码,博主在此文件中还提供了一个main函数的测试案例,启动可以正常输出,就证明模块木有问题,通道数对得上。

import torch.nn as nn
import torchclass Channel_Att(nn.Module):def __init__(self, channels, t=16):super(Channel_Att, self).__init__()self.channels = channelsself.bn2 = nn.BatchNorm2d(self.channels, affine=True)def forward(self, x):residual = xx = self.bn2(x)weight_bn = self.bn2.weight.data.abs() / torch.sum(self.bn2.weight.data.abs())x = x.permute(0, 2, 3, 1).contiguous()x = torch.mul(weight_bn, x)x = x.permute(0, 3, 1, 2).contiguous()x = torch.sigmoid(x) * residual  #return xclass NAMAttention(nn.Module):def __init__(self, channels, out_channels=None, no_spatial=True):super(NAMAttention, self).__init__()self.Channel_Att = Channel_Att(channels)def forward(self, x):x_out1 = self.Channel_Att(x)return x_out1if __name__ == '__main__':model = NAMAttention(64)inputs = torch.randn((1, 64, 64, 64))print(model(inputs).size())

3.2新建一个yolov5-NAM.yaml文件

       然后,新建一个yolov5-NAM.yaml文件,同时 注意nc改为自己数据集的类别数另外,yaml文件中NAMAttention的位置其实可以放置在任何地方,只需要调试好通道数输入输出即可。

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license# Parameters
nc: 10  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:- [10,13, 16,30, 33,23]  # P3/8  小目标- [30,61, 62,45, 59,119]  # P4/16 中目标- [116,90, 156,198, 373,326]  # P5/32  大目标# YOLOv5 v6.0 backbone
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2  output_channel, kernel_size, stride, padding[-1, 1, Conv, [128, 3, 2]],  # 1-P2/4[-1, 3, C3, [128]],[-1, 1, Conv, [256, 3, 2]],  # 3-P3/8[-1, 6, C3, [256]],[-1, 1, Conv, [512, 3, 2]],  # 5-P4/16[-1, 9, C3, [512]],[-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32[-1, 3, C3, [1024]],[-1, 1, SPPF, [1024, 5]],  # 9]# YOLOv5 v6.0 head
head:[[-1, 1, Conv, [512, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 6], 1, Concat, [1]],  # cat backbone P4[-1, 3, C3, [512, False]],  # 13[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 4], 1, Concat, [1]],  # cat backbone P3[-1, 3, C3, [256, False]],  # 17 (P3/8-small)[-1, 1, Conv, [256, 3, 2]],[[-1, 14], 1, Concat, [1]],  # cat head P4[-1, 3, C3, [512, False]],  # 20 (P4/16-medium)[-1, 1, Conv, [512, 3, 2]],[[-1, 10], 1, Concat, [1]],  # cat head P5[-1, 3, C3, [1024, False]],  # 23 (P5/32-large)[-1, 1, NAMAttention, [1024]],# 修改[[17, 20, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)]

3.3 将NAM引入到yolo.py文件中

       在下图的红色圈内位置处,引入NAMAttention,并手动导入相应的包即可。代码和示意图如下:

        elif m is NAMAttention:c1, c2 = ch[f], args[0]if c2 != no:c2 = make_divisible(c2 * gw, 8)args = [c1, *args[1:]]

在这里插入图片描述

3.4 修改train.py启动文件

       修改配置文件为yolov5-NAM.yaml即可,如下图所示:
在这里插入图片描述

4. 总结

       本篇博客主要介绍了规范化注意力机制NAM,通过惩罚机制,降低不显著特征,助力YOLOv5模型涨点。另外,在修改过程中,要是有任何问题,评论区交流;如果博客对您有帮助,请帮忙点个赞,收藏一下;后续会持续更新本人实验当中觉得有用的点子,如果很感兴趣的话,可以关注一下,谢谢大家啦!

http://www.lryc.cn/news/418334.html

相关文章:

  • LSPatch制作内置模块应用软件无需root 教你制作内置应用
  • Java设计模式七大原则
  • Copy as cURL 字段含义
  • mysql更改密码后,若依 后端启动不了解决方案
  • Redis--缓存击穿、缓存穿透、缓存雪崩
  • 10个理由告诉你,为什么鸿蒙是下一个职业风口!
  • Gitlab仓库的权限分配以及如何查看自己的权限
  • 职业本科大数据实训室
  • 【密码学】网络攻击类型:窃听攻击、假冒攻击、欺骗攻击和重放攻击
  • 探索WebKit的奥秘:塑造高效、兼容的现代网页应用
  • 2-52 基于matlab局部信息的模糊C均值聚类算法(FLICM)
  • JAVASE
  • SpringBoot学习之EasyExcel解析合并单元格(三十九)
  • 【Kimi学习笔记】C/C++、C#、Java 和 Python
  • 基于贪心算法的路径优化
  • 谷粒商城实战笔记-140-商城业务-nginx-搭建域名访问环境二(负载均衡到网关)
  • 【Android Studio】 创建第一个Android应用HelloWorld
  • C++中的错误处理机制:异常
  • 概率论原理精解【9】
  • Pytorch添加自定义算子之(11)-C++应用程序将onnx模型编译并转成tensorrt可执行模型
  • C++笔记1•C++入门基础•
  • Linux查看系统线程数
  • 【Python基础】Python六种标准数据类型中哪些是可变数据,哪些是不可变数据
  • android13去掉安全模式 删除安全模式
  • LeetCode239 滑动窗口最大值
  • 文件解析漏洞—IIS解析漏洞—IIS7.X
  • vue中子传父之间通信(this.$emit触发父组件方法和.sync修饰符与$emit(update:xxx))
  • SocketIO 的 html 代码示例
  • Vercel Error: (Azure) OpenAI API key not found
  • SPSS、Python员工满意度问卷调查激励保健理论研究:决策树、随机森林和AdaBoost|附代码数据