当前位置: 首页 > news >正文

YOLOv5修改注意力机制CBAM

直接上干货

CBAM注意力机制是由通道注意力机制(channel)和空间注意力机制(spatial)组成。

传统基于卷积神经网络的注意力机制更多的是关注对通道域的分析,局限于考虑特征图通道之间的作用关系。CBAM从 channel 和 spatial 两个作用域出发,引入空间注意力和通道注意力两个分析维度,实现从通道到空间的顺序注意力结构。空间注意力可使神经网络更加关注图像中对分类起决定作用的像素区域而忽略无关紧要的区域,通道注意力则用于处理特征图通道的分配关系,同时对两个维度进行注意力分配增强了注意力机制对模型性能的提升效果。
 

CBAM中的通道注意力机制模块流程图如下。先将输入特征图分别进行全局最大池化和全局平均池化,对特征映射基于两个维度压缩,获得两张不同维度的特征描述。池化后的特征图共用一个多层感知器网络,先通过一个全连接层下降通道数,再通过另一个全连接恢复通道数。将两张特征图在通道维度堆叠,经过 sigmoid 激活函数将特征图的每个通道的权重归一化到0-1之间。将归一化后的权重和输入特征图相乘。

yaml 配置文件如下

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license# Parameters
nc: 6  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple
anchors:- [10,13, 16,30, 33,23]  # P3/8- [30,61, 62,45, 59,119]  # P4/16- [116,90, 156,198, 373,326]  # P5/32# YOLOv5 v6.0 backbone
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2[-1, 1, Conv, [128, 3, 2]],  # 1-P2/4[-1, 3, C3, [128]],[-1, 1, Conv, [256, 3, 2]],  # 3-P3/8[-1, 6, C3, [256]],[-1, 1, Conv, [512, 3, 2]],  # 5-P4/16[-1, 9, C3, [512]],[-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32[-1, 3, C3, [1024]],[-1, 1, CBAM, [1024]],[-1, 1, SPPF, [1024, 5]],  # 10]# YOLOv5 v6.0 head
head:[[-1, 1, Conv, [512, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 6], 1, Concat, [1]],  # cat backbone P4[-1, 3, C3, [512, False]],  # 14[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 4], 1, Concat, [1]],  # cat backbone P3[-1, 3, C3, [256, False]],  # 18 (P3/8-small)[-1, 1, Conv, [256, 3, 2]],[[-1, 15], 1, Concat, [1]],  # cat head P4[-1, 3, C3, [512, False]],  # 21 (P4/16-medium)[-1, 1, Conv, [512, 3, 2]],[[-1, 11], 1, Concat, [1]],  # cat head P5[-1, 3, C3, [1024, False]],  # 24 (P5/32-large)[[18, 21, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)]

common加入以下代码

# CBAM
class ChannelAttention(nn.Module):def __init__(self, in_planes, ratio=16):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)self.relu = nn.ReLU()self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))max_out = self.f2(self.relu(self.f1(self.max_pool(x))))out = self.sigmoid(avg_out + max_out)return outclass SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()assert kernel_size in (3, 7), 'kernel size must be 3 or 7'padding = 3 if kernel_size == 7 else 1# (特征图的大小-算子的size+2*padding)/步长+1self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):# 1*h*wavg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)x = torch.cat([avg_out, max_out], dim=1)# 2*h*wx = self.conv(x)# 1*h*wreturn self.sigmoid(x)class CBAM(nn.Module):# CSP Bottleneck with 3 convolutionsdef __init__(self, c1, c2, ratio=16, kernel_size=7):  # ch_in, ch_out, number, shortcut, groups, expansionsuper(CBAM, self).__init__()self.channel_attention = ChannelAttention(c1, ratio)self.spatial_attention = SpatialAttention(kernel_size)def forward(self, x):out = self.channel_attention(x) * x# c*h*w# c*h*w * 1*h*wout = self.spatial_attention(out) * outreturn out

YOLO 的

parse_model 注册

到此完成

后续会给大家讲解YOLOv8怎么修改

http://www.lryc.cn/news/118125.html

相关文章:

  • 计算机网络 网络层 概述
  • 算法练习--动态规划 相关
  • JAVA volatile 关键字
  • [Leetcode] [Tutorial] 回溯
  • STM32 CubeMX USB_MSC(存储设备U盘)
  • 湘大 XTU OJ 1214 A+B IV 题解:数位移动的本质+布尔变量标记+朴素模拟
  • 以商业大数据技术助力数据合规流通体系建立,合合信息参编《数据经纪从业人员评价规范》团标
  • 【论文阅读】Deep Instance Segmentation With Automotive Radar Detection Points
  • 易服客工作室:如何创建有用的内容日历
  • Excel革命,基于电子表格开发的新工具,不是Access和Power Fx
  • “崩溃”漏洞会影响英特尔 CPU 的使用寿命,可能会泄露加密密钥等
  • 17.电话号码的字母组合(回溯)
  • Redis小例子
  • ETLCloud+MaxCompute实现云数据仓库的高效实时同步
  • HTTP代理授权方式介绍
  • 《合成孔径雷达成像算法与实现》Figure3.4
  • qt5.15.2 使用mysql8.1
  • 广州华锐互动:VR3D课程在线教育平台为职业院校提供沉浸式的虚拟现实学习体验
  • clion run qt 问题汇总
  • 深入理解spring面经
  • 2023年,App运行小游戏,可以玩出什么创意?
  • 景嘉微电子2021笔试题
  • selenium官网文档阅读总结(day 4)
  • 15.4 【Linux】可唤醒停机期间的工作任务
  • [FPGA开发]解决正点原子Xilinx下载器无法下载、灯不亮的问题
  • DP(区间DP)
  • MySQL5.7保姆级安装教程
  • Linux:getopts解析命令行选项和参数
  • c语言——三子棋
  • Android 广播阻塞、延迟问题分析方法