当前位置: 首页 > news >正文

【Block总结】ODConv动态卷积,适用于CV任务|即插即用

一、论文信息

  • 论文标题:Omni-Dimensional Dynamic Convolution
  • 作者:Chao Li, Aojun Zhou, Anbang Yao
  • 发表会议:ICLR 2022
  • 论文链接:https://arxiv.org/pdf/2209.07947
  • GitHub链接:https://github.com/OSVAI/ODConv
    在这里插入图片描述

二、创新点

Omni-Dimensional Dynamic Convolution(ODConv)提出了一种更为通用且优雅的动态卷积设计,主要创新点包括:

  • 多维动态注意力机制:ODConv通过并行策略在卷积核的四个维度(空间大小、输入通道数、输出通道数和卷积核数量)上学习互补的注意力。这种设计使得卷积核能够根据输入特征动态调整,从而提升特征提取能力。

  • 即插即用的特性:ODConv可以作为常规卷积的替代品,轻松集成到现有的CNN架构中,增强模型的灵活性和适应性。

三、方法

ODConv的实现方法包括以下几个步骤:

  1. 注意力计算

    • ODConv计算四种类型的注意力:空间注意力、输入通道注意力、输出通道注意力和卷积核注意力。这些注意力值用于调节卷积核的输出。
  2. 并行策略

    • 在每个卷积层中,ODConv并行计算上述四种注意力,确保每个卷积核在不同维度上都能获得适当的加权。
  3. 卷积操作

    • 将计算得到的注意力应用于卷积核,进而影响最终的特征图输出。

在这里插入图片描述

ODConv的多维动态注意力机制实现

Omni-Dimensional Dynamic Convolution(ODConv)引入了一种创新的多维动态注意力机制,旨在提升卷积神经网络(CNN)的特征提取能力。该机制通过并行策略在卷积核的四个维度上学习互补的注意力,从而实现更灵活的卷积操作。以下是ODConv多维动态注意力机制的具体实现细节:

1、四个维度的注意力机制

ODConv的多维动态注意力机制主要涉及以下四个维度的注意力学习:

  1. 空间维度注意力(Spatial Attention)

    • 该注意力机制为每个卷积核的不同空间位置分配不同的权重。通过对空间特征的加权,ODConv能够更好地捕捉图像中的局部特征。
  2. 输入通道注意力(Input Channel Attention)

    • 该机制为每个卷积核的输入通道分配不同的权重,允许模型根据输入特征的重要性动态调整卷积操作。这种方式增强了模型对不同输入特征的响应能力。
  3. 输出通道注意力(Output Channel Attention)

    • 该注意力机制为每个卷积核的输出通道分配不同的权重,使得模型能够根据输出特征的重要性进行动态调整,从而优化特征表示。
  4. 卷积核数量注意力(Kernel Attention)

    • 该机制为每个卷积核分配不同的权重,允许模型在多个卷积核之间进行选择,增强了模型的灵活性和适应性。

2、并行策略

ODConv采用并行策略来计算上述四种类型的注意力。具体实现步骤如下:

  • 注意力计算

    • 在每个卷积层中,ODConv并行计算四种注意力,分别对应于卷积核的四个维度。这些注意力值通过多头注意力模块进行计算,以确保每个维度的特征都能得到充分的关注。
  • 注意力加权

    • 计算得到的注意力值被应用于卷积核的输出,进而影响最终的特征图。这种加权机制使得卷积操作能够根据输入特征的不同动态调整,从而提升特征提取的效果。

3、优势与效果

ODConv的多维动态注意力机制带来了显著的性能提升:

  • 增强特征学习能力:通过在多个维度上进行动态调整,ODConv能够更有效地捕捉图像中的重要特征。

  • 减少参数量:即使在使用单个卷积核的情况下,ODConv也能与现有的多核动态卷积方法竞争或超越,显著减少了额外的参数。

  • 广泛适用性:ODConv可以作为常规卷积的替代品,轻松集成到现有的CNN架构中,提升模型的灵活性和适应性。

四、效果

ODConv在多个标准数据集上进行了实验,结果显示其在准确性和效率上均有显著提升:

  • ImageNet:在MobileNetV2和ResNet系列模型上,ODConv分别提升了3.77%至5.71%和1.86%至3.72%的Top-1准确率。

  • MS-COCO:在目标检测任务中,ODConv同样展现了优越的性能,提升了模型对小目标和被遮挡目标的检测能力。

五、实验结果

ODConv的实验结果表明,其在多个主流CNN架构上的表现均优于传统卷积方法。具体实验结果包括:

  • MobileNetV2

    • 原始模型Top-1准确率为71.65%,使用ODConv后提升至74.74%(1×核)和75.29%(4×核)。
  • ResNet系列

    • ResNet50的Top-1准确率从76.23%提升至77.87%(1×核)和78.50%(4×核)。

这些结果表明,ODConv不仅提高了模型的准确性,还在参数量上保持了较低的增长。

六、总结

Omni-Dimensional Dynamic Convolution(ODConv)通过引入多维动态注意力机制,显著提升了卷积神经网络的特征提取能力。其创新的设计使得ODConv能够在多个维度上学习卷积核的动态特性,进而提高模型的性能。实验结果证明,ODConv在多个标准数据集上均表现出色,成为现代深度学习模型中一种有效的卷积替代方案。

代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autogradclass Attention(nn.Module):def __init__(self, in_planes, out_planes, kernel_size, groups=1, reduction=0.0625, kernel_num=4, min_channel=16):super(Attention, self).__init__()attention_channel = max(int(in_planes * reduction), min_channel)self.kernel_size = kernel_sizeself.kernel_num = kernel_numself.temperature = 1.0self.avgpool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Conv2d(in_planes, attention_channel, 1, bias=False)self.bn = nn.BatchNorm2d(attention_channel)self.relu = nn.ReLU(inplace=True)self.channel_fc = nn.Conv2d(attention_channel, in_planes, 1, bias=True)self.func_channel = self.get_channel_attentionif in_planes == groups and in_planes == out_planes:  # depth-wise convolutionself.func_filter = self.skipelse:self.filter_fc = nn.Conv2d(attention_channel, out_planes, 1, bias=True)self.func_filter = self.get_filter_attentionif kernel_size == 1:  # point-wise convolutionself.func_spatial = self.skipelse:self.spatial_fc = nn.Conv2d(attention_channel, kernel_size * kernel_size, 1, bias=True)self.func_spatial = self.get_spatial_attentionif kernel_num == 1:self.func_kernel = self.skipelse:self.kernel_fc = nn.Conv2d(attention_channel, kernel_num, 1, bias=True)self.func_kernel = self.get_kernel_attentionself._initialize_weights()def _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)if isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)def update_temperature(self, temperature):self.temperature = temperature@staticmethoddef skip(_):return 1.0def get_channel_attention(self, x):channel_attention = torch.sigmoid(self.channel_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)return channel_attentiondef get_filter_attention(self, x):filter_attention = torch.sigmoid(self.filter_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)return filter_attentiondef get_spatial_attention(self, x):spatial_attention = self.spatial_fc(x).view(x.size(0), 1, 1, 1, self.kernel_size, self.kernel_size)spatial_attention = torch.sigmoid(spatial_attention / self.temperature)return spatial_attentiondef get_kernel_attention(self, x):kernel_attention = self.kernel_fc(x).view(x.size(0), -1, 1, 1, 1, 1)kernel_attention = F.softmax(kernel_attention / self.temperature, dim=1)return kernel_attentiondef forward(self, x):x = self.avgpool(x)x = self.fc(x)x = self.bn(x)x = self.relu(x)return self.func_channel(x), self.func_filter(x), self.func_spatial(x), self.func_kernel(x)class ODConv2d(nn.Module):def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1,reduction=0.0625, kernel_num=4):super(ODConv2d, self).__init__()self.in_planes = in_planesself.out_planes = out_planesself.kernel_size = kernel_sizeself.stride = strideself.padding = paddingself.dilation = dilationself.groups = groupsself.kernel_num = kernel_numself.attention = Attention(in_planes, out_planes, kernel_size, groups=groups,reduction=reduction, kernel_num=kernel_num)self.weight = nn.Parameter(torch.randn(kernel_num, out_planes, in_planes//groups, kernel_size, kernel_size),requires_grad=True)self._initialize_weights()if self.kernel_size == 1 and self.kernel_num == 1:self._forward_impl = self._forward_impl_pw1xelse:self._forward_impl = self._forward_impl_commondef _initialize_weights(self):for i in range(self.kernel_num):nn.init.kaiming_normal_(self.weight[i], mode='fan_out', nonlinearity='relu')def update_temperature(self, temperature):self.attention.update_temperature(temperature)def _forward_impl_common(self, x):# Multiplying channel attention (or filter attention) to weights and feature maps are equivalent,# while we observe that when using the latter method the models will run faster with less gpu memory cost.channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)batch_size, in_planes, height, width = x.size()x = x * channel_attentionx = x.reshape(1, -1, height, width)aggregate_weight = spatial_attention * kernel_attention * self.weight.unsqueeze(dim=0)aggregate_weight = torch.sum(aggregate_weight, dim=1).view([-1, self.in_planes // self.groups, self.kernel_size, self.kernel_size])output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,dilation=self.dilation, groups=self.groups * batch_size)output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))output = output * filter_attentionreturn outputdef _forward_impl_pw1x(self, x):channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)x = x * channel_attentionoutput = F.conv2d(x, weight=self.weight.squeeze(dim=0), bias=None, stride=self.stride, padding=self.padding,dilation=self.dilation, groups=self.groups)output = output * filter_attentionreturn outputdef forward(self, x):return self._forward_impl(x)if __name__ == "__main__":dim=256# 如果GPU可用,将模块移动到 GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 输入张量 (batch_size, height, width,channels)x = torch.randn(2,dim,40,40).to(device)# 初始化 HWD 模块block = ODConv2d(dim,dim,7,padding=3)print(block)block = block.to(device)# 前向传播output = block(x)print("输入:", x.shape)print("输出:", output.shape)

输出结果:
在这里插入图片描述

http://www.lryc.cn/news/529696.html

相关文章:

  • RK3568 opencv播放视频
  • 《LLM大语言模型+RAG实战+Langchain+ChatGLM-4+Transformer》
  • 【搜索回溯算法篇】:拓宽算法视野--BFS如何解决拓扑排序问题
  • 计算机网络 (61)移动IP
  • Elasticsearch+kibana安装(简单易上手)
  • 音视频多媒体编解码器基础-codec
  • 【算法与数据结构】动态规划
  • DeepSeekMoE:迈向混合专家语言模型的终极专业化
  • 什么是Maxscript?为什么要学习Maxscript?
  • HyperLogLog 近似累计去重技术解析:大数据场景下的高效基数统计
  • LabVIEW透镜多参数自动检测系统
  • MySQL数据库(二)- SQL
  • 【Block总结】HiLo注意力,局部自注意力捕获细粒度的高频信息,通过全局注意力捕获低频信息|即插即用
  • python 使用Whisper模型进行语音翻译
  • C# Winform enter键怎么去关联button
  • Github 2025-01-30 Go开源项目日报 Top10
  • 电路研究9.2.6——合宙Air780EP中HTTP——HTTP GET 相关命令使用方法研究
  • Java手写简单Merkle树
  • DeepSeek的使用技巧介绍
  • 19 压测和常用的接口优化方案
  • AI应用部署——streamlit
  • NLP自然语言处理通识
  • C++ 6
  • 使用QSqlQueryModel创建交替背景色的表格模型
  • jinfo命令详解
  • 如何在 ACP 中建模复合罐
  • 【Java】微服务找不到问题记录can not find user-service
  • 基于Hutool的Merkle树hash值生成工具
  • Windows系统本地部署deepseek 更改目录
  • 深度学习篇---数据存储类型