当前位置: 首页 > news >正文

目标检测算法改进系列之嵌入Deformable ConvNets v2 (DCNv2)

Deformable ConvNets v2

简介:由于构造卷积神经网络所用的模块中几何结构是固定的,其几何变换建模的能力本质上是有限的。在DCN v1中引入了两种新的模块来提高卷积神经网络对变换的建模能力,即可变形卷积 (deformable convolution) 和可变形兴趣区域池化 (deformable ROI pooling)。它们都是基于在模块中对空间采样的位置信息作进一步位移调整的想法,该位移可在目标任务中学习得到,并不需要额外的监督信号。新的模块可以很方便在现有的卷积神经网络 中取代它们的一般版本,并能很容易进行标准反向传播端到端的训练,从而得到可变形卷积网络 (deformable convolutional network)。但是增加偏移之后可能会将无关的信息考虑进去,影响最终的结果。所以在DCN v2中作者对DCN v1进行了提升,减小无关信息的干扰。

原文地址:Deformable ConvNets v2: More Deformable, Better Results

regular conv
DCNv1
DCNv2

pytorch代码实现

class DCNv2(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride=1,padding=1, dilation=1, groups=1, deformable_groups=1):super(DCNv2, self).__init__()self.in_channels = in_channelsself.out_channels = out_channelsself.kernel_size = (kernel_size, kernel_size)self.stride = (stride, stride)self.padding = (padding, padding)self.dilation = (dilation, dilation)self.groups = groupsself.deformable_groups = deformable_groupsself.weight = nn.Parameter(torch.empty(out_channels, in_channels, *self.kernel_size))self.bias = nn.Parameter(torch.empty(out_channels))out_channels_offset_mask = (self.deformable_groups * 3 *self.kernel_size[0] * self.kernel_size[1])self.conv_offset_mask = nn.Conv2d(self.in_channels,out_channels_offset_mask,kernel_size=self.kernel_size,stride=self.stride,padding=self.padding,bias=True,)self.bn = nn.BatchNorm2d(out_channels)self.act = Conv.default_actself.reset_parameters()def forward(self, x):offset_mask = self.conv_offset_mask(x)o1, o2, mask = torch.chunk(offset_mask, 3, dim=1)offset = torch.cat((o1, o2), dim=1)mask = torch.sigmoid(mask)x = torch.ops.torchvision.deform_conv2d(x,self.weight,offset,mask,self.bias,self.stride[0], self.stride[1],self.padding[0], self.padding[1],self.dilation[0], self.dilation[1],self.groups,self.deformable_groups,True)x = self.bn(x)x = self.act(x)return xdef reset_parameters(self):n = self.in_channelsfor k in self.kernel_size:n *= kstd = 1. / math.sqrt(n)self.weight.data.uniform_(-std, std)self.bias.data.zero_()self.conv_offset_mask.weight.data.zero_()self.conv_offset_mask.bias.data.zero_()class Bottleneck_DCN(nn.Module):# Standard bottleneck with DCNdef __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):  # ch_in, ch_out, shortcut, groups, kernels, expandsuper().__init__()c_ = int(c2 * e)  # hidden channelsif k[0] == 3:self.cv1 = DCNv2(c1, c_, k[0], 1)else:self.cv1 = Conv(c1, c_, k[0], 1)if k[1] == 3:self.cv2 = DCNv2(c_, c2, k[1], 1, groups=g)else:self.cv2 = Conv(c_, c2, k[1], 1, g=g)self.add = shortcut and c1 == c2def forward(self, x):return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))class C2f_DCN(nn.Module):# CSP Bottleneck with 2 convolutionsdef __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansionsuper().__init__()self.c = int(c2 * e)  # hidden channelsself.cv1 = Conv(c1, 2 * self.c, 1, 1)self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)self.m = nn.ModuleList(Bottleneck_DCN(self.c, self.c, shortcut, g, k=(3, 3), e=1.0) for _ in range(n))def forward(self, x):y = list(self.cv1(x).split((self.c, self.c), 1))y.extend(m(y[-1]) for m in self.m)return self.cv2(torch.cat(y, 1))

具体修改

module.py文件修改

将pytorch代码实现中的定义代码添加至module.py文件最后
修改1

task.py文件修改

导入C2f-DCN模块
在这里插入图片描述
def parse_model函数部分导入C2f-DCN
在这里插入图片描述

yolov8.yaml配置文件修改

替换原有C2f模块,最后进行训练即可。
在这里插入图片描述

http://www.lryc.cn/news/210556.html

相关文章:

  • 最新发布!阿里云卓越架构框架重磅升级
  • 如何监听/抓取两个设备/芯片之间“UART串口”通信数据--监视TXD和RXD
  • JDK项目分析的经验分享
  • Java创建一个长度为10的数组,利用Arrays.sort(), 为数组元素排序
  • python 动态加载C# 动态库的一些问题
  • 代码审计-锐捷NBR路由器 EWEB网管系统 远程命令执行
  • VBA技术资料MF75:测量所选单元格范围的高度和宽度
  • 力扣 26. 删除有序数组中的重复项
  • 【uniapp】仿微信支付界面
  • windows + ubuntu + vscode开发环境配置安装
  • 设计模式:责任链模式(C#、JAVA、JavaScript、C++、Python、Go、PHP)
  • koa搭建服务器(二)
  • LeetCode 125 验证回文串 简单
  • Android底层摸索改BUG(一):Android系统状态栏显示不下Wifi图标
  • 第十三章---枚举类型与泛型
  • shell语法大全(超级详细!!!!),非常适合入门
  • 【Python机器学习】零基础掌握ExtraTreesRegressor集成学习
  • 网络协议--TCP的交互数据流
  • IOC课程整理-13 Spring校验
  • SSM咖啡点餐管理系统开发mysql数据库web结构java编程计算机网页源码eclipse项目
  • Capacitor 打包 h5 到 Android 应用,uniapp https http net::ERR_CLEARTEXT_NOT_PERMITTED
  • 华为数通方向HCIP-DataCom H12-831题库(多选题:101-120)
  • misc学习(4)Traffic(流量分析)-
  • Less的基本语法
  • spring boot项目优雅停机
  • 链式存储方式下字符串的replace(S,T1,T2)运算
  • unity脚本_Mathf和Math c#
  • 轻量级仿 Spring Boot=嵌入式 Tomcat+Spring MVC
  • 笔记Kubernetes核心技术-之Controller
  • Azure云工作站上做Machine Learning模型开发 - 全流程演示