当前位置: 首页 > news >正文

YOLOv8如何添加注意力模块?

分为两种:有参注意力和无参注意力。
eg:
有参:

import torch
from torch import nnclass EMA(nn.Module):def __init__(self, channels, factor=8):super(EMA, self).__init__()self.groups = factorassert channels // self.groups > 0self.softmax = nn.Softmax(-1)self.agp = nn.AdaptiveAvgPool2d((1, 1))self.pool_h = nn.AdaptiveAvgPool2d((None, 1))self.pool_w = nn.AdaptiveAvgPool2d((1, None))self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)def forward(self, x):b, c, h, w = x.size()group_x = x.reshape(b * self.groups, -1, h, w)  # b*g,c//g,h,wx_h = self.pool_h(group_x)x_w = self.pool_w(group_x).permute(0, 1, 3, 2)hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))x_h, x_w = torch.split(hw, [h, w], dim=2)x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())x2 = self.conv3x3(group_x)x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hwx21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hwweights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)return (group_x * weights.sigmoid()).reshape(b, c, h, w)

无参:

import torch
import torch.nn as nnclass SimAM(torch.nn.Module):def __init__(self, e_lambda=1e-4):super(SimAM, self).__init__()self.activaton = nn.Sigmoid()self.e_lambda = e_lambdadef __repr__(self):s = self.__class__.__name__ + '('s += ('lambda=%f)' % self.e_lambda)return s@staticmethoddef get_module_name():return "simam"def forward(self, x):b, c, h, w = x.size()n = w * h - 1x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5return x * self.activaton(y)

1、在nn文件夹下新建attention.py文件,把上面俩代码放进去
在这里插入图片描述

2、在tasks.py文件里面导入俩函数
在这里插入图片描述
3、在解析函数里面添加解析代码
在这里插入图片描述
c1:上一层的输出通道数,也是这一层的输入通道数
C2:该层的输出通道数,即将成为下一层的输入通道数
args[]:每个带参数的模块,都要指定这个东西,这个包括[c1,c2,剩下的参数],然后传给该层的模块,有些模块不需要额外参数,就只传一个输出通道数给这一层就行
切记!!!C2是这一层的输出通道数,而args[]里的输入输出通道数是给模块的
4、新建模型配置文件
在这里插入图片描述
4、快速验证配置文件,新建main.py文件,然后运行

from ultralytics import YOLOif __name__=='__main__':print('11111111111')model=YOLO('/home/xxxxxxxx/v8/yolov8-main/ultralytics/models/v8/yolov8-att.yaml')

在这里插入图片描述
5、如果想修改这个参数,传进来
在这里插入图片描述
6、配置文件改也行,传进去
在这里插入图片描述
7、总结:放进attention.py,接着在tasks.py里注册,接着解析函数添加(有通道无通道),模型配置文件替换

8、第二种:在4、6、9后面加
在这里插入图片描述
在这里插入图片描述

http://www.lryc.cn/news/212140.html

相关文章:

  • 用LibreOffice在excel中画折线图
  • RabbitMQ 链接管理-发布者-消费者
  • JAVA中的垃圾回收器(3)----ZGC
  • IDEA 如何运行 SpringBoot 项目
  • Linux MeterSphere测试平台远程访问你不会?来试试这篇文章
  • 15.k8s集群防火墙配置
  • Python beautifulsoup网络抓取和解析cnblog首页帖子数据
  • Java集成腾讯云OCR身份证识别接口
  • C++之C++11引入enum class与传统enum关键字总结(二百五十一)
  • 如何将word格式的文档转换成markdown格式的文档
  • Leetcode—2558.从数量最多的堆取走礼物【简单】
  • 【如何写论文】硕博学位论文的结构框架、过程与大纲分析
  • 砷化镓(GaAs)纳米线 砷化镓纳米线 GaAs纳米线 瑞禧
  • PostGreSQL:JSON|JSONB数据类型
  • 树----数据结构
  • GitLab定时备份
  • SQL IN 运算符
  • 虚拟机构建单体项目及前后端分离项目
  • 代码浅析DLIO(一)---整体框架梳理
  • Springboot的Container Images,docker加springboot
  • c 从avi 视频中提取图片
  • Jtti:Apache服务的反向代理及负载均衡怎么配置
  • 82.二分查找
  • 线程是如何创建的
  • owl_vit安装步骤
  • 运行real.exe时出现NUM_METGRID_SOIL_LEVELS=0
  • 【数值计算方法】Gauss消元法及其Python/C实现
  • ins老被封禁?快来看看这些雷区你踩了没!
  • 《Effective Java》读书笔记(1-2章)
  • C++版split(‘_‘)函数