当前位置: 首页 > news >正文

YOLOv9中加入SCConv模块!

 


专栏介绍:YOLOv9改进系列 | 包含深度学习最新创新,主力高效涨点!!!


一、本文介绍

        本文将一步步演示如何在YOLOv9中添加 / 替换新模块,寻找模型上的创新!

适用检测目标:   YOLOv9模块通用改进


二、改进步骤

《YOLOv9: Learning What You Want to Learn Using Programmable Gradient Information》

        论文地址:   https://arxiv.org/abs/2402.13616

        代码地址:   https://github.com/WongKinYiu/yolov9

 2.1 创建一个脚本存放新模块

        为方便调用,这里我将脚本放在models包下,命名为extra.py。

 2.2 将模块复制到脚本中,并导入需要的包(以SCConv为例)

        我们将SCConv的代码复制到刚刚创建的extra.py脚本中。

import torch
import torch.nn as nn
import torch.nn.functional as Ffrom models.common import Convclass SCConv(nn.Module):"""https://github.com/MCG-NKU/SCNet/blob/master/scnet.py"""def __init__(self, inplanes, planes, stride=1, padding=1, dilation=1, groups=1, pooling_r=4):super(SCConv, self).__init__()self.k2 = nn.Sequential(nn.AvgPool2d(kernel_size=pooling_r, stride=pooling_r),Conv(inplanes, planes, k=3, s=1, p=padding, d=dilation, g=groups, act=False))self.k3 = Conv(inplanes, planes, k=3, s=1, p=padding, d=dilation, g=groups, act=False)self.k4 = Conv(inplanes, planes, k=3, s=1, p=padding, d=dilation, g=groups, act=False)def forward(self, x):identity = xout = torch.sigmoid(torch.add(identity, F.interpolate(self.k2(x), identity.size()[2:]))) # sigmoid(identity + k2)out = torch.mul(self.k3(x), out)    # k3 * sigmoid(identity + k2)out = self.k4(out)  # k4return out

2.3 对yolo.py操作

        打开models包下的yolo.py文件夹,将刚才创建的脚本导入。并在下方第700行的位置(位置可能因v9版本更新变动)加入下方代码。

2.4 运行配置文件

        创建模型配置文件(yaml文件),将我们所作改进加入到配置文件中(这一步的配置文件可以复制models  - > detect 下的yaml修改。)。对YOLO系列yaml文件不熟悉的同学可以看我往期的yaml详解教学!

YOLO系列 “.yaml“文件解读-CSDN博客

# YOLOv9# parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple
#activation: nn.LeakyReLU(0.1)
#activation: nn.ReLU()# anchors
anchors: 3# YOLOv9 backbone
backbone:[[-1, 1, Silence, []],  # conv down[-1, 1, Conv, [64, 3, 2]],  # 1-P1/2# conv down[-1, 1, Conv, [128, 3, 2]],  # 2-P2/4# elan-1 block[-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 3# avg-conv down[-1, 1, ADown, [256]],  # 4-P3/8# elan-2 block[-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 5# avg-conv down[-1, 1, ADown, [512]],  # 6-P4/16# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 7# avg-conv down[-1, 1, ADown, [512]],  # 8-P5/32# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 9]# YOLOv9 head
head:[# elan-spp block[-1, 1, SPPELAN, [512, 256]],  # 10# up-concat merge[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 7], 1, Concat, [1]],  # cat backbone P4# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 13# up-concat merge[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 5], 1, Concat, [1]],  # cat backbone P3# elan-2 block[-1, 1, RepNCSPELAN4, [256, 256, 128, 1]],  # 16 (P3/8-small)# avg-conv-down merge[-1, 1, ADown, [256]],[[-1, 13], 1, Concat, [1]],  # cat head P4# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 19 (P4/16-medium)# avg-conv-down merge[-1, 1, ADown, [512]],[[-1, 10], 1, Concat, [1]],  # cat head P5# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 22 (P5/32-large)# multi-level reversible auxiliary branch# routing[5, 1, CBLinear, [[256]]], # 23[7, 1, CBLinear, [[256, 512]]], # 24[9, 1, CBLinear, [[256, 512, 512]]], # 25# conv down[0, 1, Conv, [64, 3, 2]],  # 26-P1/2# conv down[-1, 1, Conv, [128, 3, 2]],  # 27-P2/4# elan-1 block[-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 28# avg-conv down fuse[-1, 1, ADown, [256]],  # 29-P3/8[[23, 24, 25, -1], 1, CBFuse, [[0, 0, 0]]], # 30  # elan-2 block[-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 31# avg-conv down fuse[-1, 1, ADown, [512]],  # 32-P4/16[[24, 25, -1], 1, CBFuse, [[1, 1]]], # 33 # elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 34# avg-conv down fuse[-1, 1, ADown, [512]],  # 35-P5/32[[25, -1], 1, CBFuse, [[2]]], # 36# elan-2 block[-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 37[-1, 1, SCConv, []],  # 38# detection head# detect[[31, 34, 38, 16, 19, 22], 1, DualDDetect, [nc]],  # DualDDetect(A3, A4, A5, P3, P4, P5)]

3.4 训练过程

        最后,复制我们创建的模型配置,填入训练脚本(train_dual)中(不会训练的同学可以参考我之前的文章。),运行即可。

YOLOv9 最简训练教学!-CSDN博客


如果觉得本文章有用的话给博主点个关注吧!


http://www.lryc.cn/news/308488.html

相关文章:

  • 代码随想录算法训练营第四十七天丨198. 打家劫舍、​ 213. 打家劫舍 II​、337. 打家劫舍 III
  • 龙蜥Anolis 8.4 anck 安装mysql5.7
  • 【踩坑】修复xrdp无法关闭Authentication Required验证窗口
  • python学习笔记 - 标准库常量
  • 视频和音频使用ffmpeg进行合并和分离(MP4)
  • 02| JVM堆中垃圾回收的大致过程
  • R语言数据可视化之美专业图表绘制指南(增强版):第1章 R语言编程与绘图基础
  • 网站添加pwa操作和配置manifest.json后,没有效果排查问题
  • MongoDB聚合运算符:$cosh
  • Jenkins配置在远程服务器上执行shell脚本(两种方式)
  • Java+SpringBoot,打造社区疫情信息新生态
  • js ES6判断字符串是否以某个字符串开头或者结尾startsWith、endsWith
  • 预研项目完成后小批量验证(技术变更流程)
  • Bert-as-service 实战
  • 微信小程序(四十七)多个token存储
  • 机器学习(II)--样本不平衡
  • 几个好用的 VUE Table
  • Vue源码系列讲解——实例方法篇【三】(生命周期相关方法)
  • 百度SEO工具,自动更新网站的工具
  • 供应链|NUS覃含章MS论文解读:数据驱动下联合定价和库存控制的近似方法 (二)
  • 删除有序数组中的重复项Ⅱ
  • Java底层自学大纲_数据结构和算法篇
  • 群晖NAS配置WebDav结合内网穿透实现公网访问本地影视资源
  • Vue3报错Promise executor functions should not be async.
  • (正规api接口代发布权限)短视频账号矩阵系统实现开发--技术全自动化saas营销链路生态
  • 【Redis】redis通用命令
  • mysql服务治理
  • opencv--使用直方图找谷底进行确定分割阈值
  • dolphinscheduler海豚调度(四)钉钉告警
  • Java-Safe Point(安全点)