当前位置: 首页 > news >正文

yolov8添加ca注意力机制

创建文件 coordAtt.py

位置:ultralytics/nn/modules/coordAtt.py

######################  CoordAtt  ####     start   by  AI&CV  ###############################
# https://zhuanlan.zhihu.com/p/655475515
import torch
import torch.nn as nn
import torch.nn.functional as Fclass h_sigmoid(nn.Module):def __init__(self, inplace=True):super(h_sigmoid, self).__init__()self.relu = nn.ReLU6(inplace=inplace)def forward(self, x):return self.relu(x + 3) / 6class h_swish(nn.Module):def __init__(self, inplace=True):super(h_swish, self).__init__()self.sigmoid = h_sigmoid(inplace=inplace)def forward(self, x):return x * self.sigmoid(x)class CoordAtt(nn.Module):def __init__(self, inp, reduction=32):super(CoordAtt, self).__init__()self.pool_h = nn.AdaptiveAvgPool2d((None, 1))self.pool_w = nn.AdaptiveAvgPool2d((1, None))mip = max(8, inp // reduction)self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(mip)self.act = h_swish()self.conv_h = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)self.conv_w = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)def forward(self, x):identity = xn, c, h, w = x.size()x_h = self.pool_h(x)x_w = self.pool_w(x).permute(0, 1, 3, 2)y = torch.cat([x_h, x_w], dim=2)y = self.conv1(y)y = self.bn1(y)y = self.act(y)x_h, x_w = torch.split(y, [h, w], dim=2)x_w = x_w.permute(0, 1, 3, 2)a_h = self.conv_h(x_h).sigmoid()a_w = self.conv_w(x_w).sigmoid()out = identity * a_w * a_hreturn out
######################  CoordAtt  ####     end   by  AI&CV  ###############################

conv.py中添加头文件

位置:ultralytics/nn/modules/conv.py
在这里插入图片描述

init.py中添加头文件

位置 :ultralytics/nn/modules/init.py
在这里插入图片描述

tasks.py 文件

位置:ultralytics/nn/tasks.py

task.py文件 添加头文件

在这里插入图片描述

task.py文件的 方法中添加代码

        elif m is CoordAtt: # todo 源码修改 ~4"""ch[f]:上一层的args[0]:第0个参数c1:输入通道数c2:输出通道数"""c1, c2 = ch[f], args[0]# print("ch[f]:",ch[f])# print("args[0]:",args[0])# print("args:",args)# print("c1:",c1)# print("c2:",c2)if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)c2 = make_divisible(c2 * width, 8)args = [c1, *args[1:]]

在这里插入图片描述

运行效果

在这里插入图片描述

对比图(左:未添加cbam,右上:添加cbam,右下:添加ca)

在这里插入图片描述

yolov8添加ca注意力机制-出现bug

ImportError: cannot import name ‘CoordAtt’ from ‘ultralytics.nn.modules’ (D:\anaconda3\envs\torch\lib\site-packages\ultralytics\nn\modules_init_.py)

在这里插入图片描述

解决方法:拷贝项目中左图文件,到环境配置的右图目录中

在这里插入图片描述

ImportError: cannot import name ‘CoordAtt’ from ‘ultralytics.nn.modules.conv’ (D:\anaconda3\envs\torch\lib\site-packages\ultralytics\nn\modules\conv.py)

在这里插入图片描述

解决方法:拷贝项目中左图文件,到环境配置的右图目录中

在这里插入图片描述

ModuleNotFoundError: No module named ‘ultralytics.nn.modules.coordAtt’

在这里插入图片描述
解决方法:拷贝项目中左图文件,到环境配置的右图目录中
在这里插入图片描述

http://www.lryc.cn/news/253115.html

相关文章:

  • linux java后台启动的几种方式
  • selinux-policy-default(2:2.20231119-2)软件包内容详细介绍(5)
  • 代码随想录二刷 |栈与队列 |理论基础
  • java--接口概述
  • 出海风潮:中国母婴品牌征服国际市场的机遇与挑战!
  • 一文读懂MongoDB的知识点(3),惊呆面试官。
  • ssm的“魅力”西安宣传网站(有报告)。Javaee项目。
  • 怎么让SecureCRT不自动断开连接
  • 介绍几种Go语言开发的IDE
  • 1、设计模式简介(7大原则,3大类)
  • 华为鲲鹏+银河麒麟V10编译FreeSWITCH1.10.9
  • CFS三层靶机内网渗透
  • 软件分享--智能照片识别分类软件
  • Leetcode—409.最长回文串【简单】
  • 计算机网络入侵检测技术研究
  • 深入学习锁--Synchronized各种使用方法
  • pycharm中绘制一个3D曲线
  • 人工智能_AI服务器安装清华开源_CHATGLM大语言模型_GLM-6B安装部署_人工智能工作笔记0092
  • 用户反馈组件实现(Vue3+ElementPlus)含图片拖拽上传
  • K8S部署nginx并且使用NFS存储数据
  • Homework 3: Higher-Order Functions, Self Reference, Recursion, Tree Recursion
  • (C++)有效三角形的个数--双指针法
  • 11.30BST理解,AVL树操作,定义;快速幂,二分求矩阵幂(未完)
  • 深入理解Java核心技术:Java工程师的实用干货笔记
  • 大学里面转专业介绍
  • MySQL_1. mysql数据库介绍
  • TimeGPT:时间序列预测模型实例
  • 【JavaEE】多线程 (1)
  • linux 应用层同步与互斥机制之条件变量
  • 3.5毫米音频连接器接线方式