当前位置: 首页 > news >正文

【Block总结】高效多尺度注意力EMA,超越SE、CBAM、SA、CA等注意力|即插即用

论文信息

标题: Efficient Multi-Scale Attention Module with Cross-Spatial Learning

作者: Daliang Ouyang, Su He, Guozhong Zhang, Mingzhu Luo, Huaiyong Guo, Jian Zhan, Zhijie Huang

论文链接: https://arxiv.org/pdf/2305.13563v2

GitHub链接: https://github.com/YOLOonMe/EMA-attention-module
在这里插入图片描述

创新点

该论文提出了一种新颖的高效多尺度注意力模块(EMA),旨在通过跨空间学习来提升特征表示的效果,同时降低计算开销。EMA模块的设计重点在于:

  • 信息保留: 在每个通道上保留信息,确保特征的完整性。
  • 计算效率: 通过重塑部分通道为批处理维度,减少计算负担。
  • 多尺度学习: 结合多尺度特征,增强模型对不同尺度信息的捕捉能力。

方法

EMA模块的核心方法包括:

  1. 通道重塑: 将部分通道重塑为批处理维度,并将通道维度分组为多个子特征,以实现更高效的信息处理。

  2. 跨维度交互: 通过跨维度交互,聚合两个并行分支的输出特征,捕获像素级的成对关系。

  3. 并行子网络: 设计多尺度并行子网络,以建立短期和长期依赖关系,从而增强特征表示能力。

在这里插入图片描述

EMA模块的信息保留与计算效率平衡

信息保留机制

EMA(Efficient Multi-Scale Attention)模块通过以下几种方式实现信息的有效保留:

  1. 通道重塑: EMA模块将部分通道重塑为批处理维度,并将通道维度分组为多个子特征。这种设计确保了每个通道的信息能够被有效保留,同时避免了通道维度的削减,从而增强了特征的表达能力[1][3]。

  2. 跨维度交互: 在EMA模块中,两个并行分支的输出特征通过跨维度交互进行聚合。这种交互机制能够捕捉到像素级的成对关系,从而进一步提升特征的丰富性和准确性[2][3]。

  3. 多尺度并行子网络: EMA模块采用了多尺度并行子网络结构,结合了1x1和3x3卷积核的特征处理。这种结构能够有效捕获不同尺度的信息,确保在特征提取过程中不会丢失重要信息[2][3]。

计算效率提升

在计算效率方面,EMA模块通过以下方式优化了计算过程:

  1. 减少计算开销: 通过将部分通道重塑为批处理维度,EMA模块能够在不显著增加计算成本的情况下,保持高效的信息处理。这种方法使得模型在处理大规模数据时更加高效[1][2]。

  2. 并行处理: EMA模块的设计允许多个子网络并行处理特征,这不仅提高了计算效率,还减少了模型的顺序处理需求,从而加快了整体计算速度[3]。

  3. 适度的模型尺寸: EMA模块的设计确保了模型的尺寸适中,适合在移动终端等资源受限的环境中部署。这种设计使得EMA模块在保持性能的同时,能够有效降低计算资源的消耗[3][2]。

EMA模块通过创新的设计实现了信息保留与计算效率的平衡。其通道重塑、跨维度交互和多尺度并行处理的策略,不仅确保了特征信息的完整性,还显著提高了计算效率。这使得EMA模块在计算机视觉任务中表现出色,尤其是在小目标检测和图像分类等应用中,展现了其广泛的应用潜力和实际意义。

效果

实验结果表明,EMA模块在多个计算机视觉任务中表现优异,尤其是在小目标检测和图像分类任务中,相较于传统的注意力机制(如ECA、CBAM、CA),EMA模块显著提高了特征表示的清晰度和准确性。

实验结果

在广泛的消融研究和实验中,EMA模块在以下数据集上进行了评估:

  • CIFAR-100
  • ImageNet-1k
  • MS COCO
  • VisDrone2019

实验结果显示,EMA模块在这些基准测试中均取得了优于现有方法的性能,尤其在小目标检测任务中,表现出明显的优势。

总结

Efficient Multi-Scale Attention Module with Cross-Spatial Learning通过创新的设计和有效的实现,成功地提升了计算机视觉任务中的特征表示能力,同时降低了计算复杂度。该模块的提出为未来的研究提供了新的思路,尤其是在需要高效处理大规模数据的应用场景中,EMA模块展现了其广泛的应用潜力。

代码

import torch
from torch import nnclass EMA(nn.Module):def __init__(self, channels, c2=None, factor=32):super(EMA, self).__init__()self.groups = factorassert channels // self.groups > 0self.softmax = nn.Softmax(-1)self.agp = nn.AdaptiveAvgPool2d((1, 1))self.pool_h = nn.AdaptiveAvgPool2d((None, 1))self.pool_w = nn.AdaptiveAvgPool2d((1, None))self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)def forward(self, x):b, c, h, w = x.size()group_x = x.reshape(b * self.groups, -1, h, w)  # b*g,c//g,h,wx_h = self.pool_h(group_x)x_w = self.pool_w(group_x).permute(0, 1, 3, 2)hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))x_h, x_w = torch.split(hw, [h, w], dim=2)x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())x2 = self.conv3x3(group_x)x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hwx21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hwweights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)return (group_x * weights.sigmoid()).reshape(b, c, h, w)if __name__ == "__main__":# 如果GPU可用,将模块移动到 GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 输入张量 (batch_size, channels, height, width)x = torch.randn(1,32,40,40).to(device)# 初始化 pconv 模块dim=32block = EMA(dim,factor=8)print(block)block = block.to(device)# 前向传播output = block(x)print("输入:", x.shape)print("输出:", output.shape)

输出结果:

在这里插入图片描述

http://www.lryc.cn/news/528921.html

相关文章:

  • Pwn 入门核心工具和命令大全
  • 探索AI(chatgpt、文心一言、kimi等)提示词的奥秘
  • 利用飞书机器人进行 - ArXiv自动化检索推荐
  • 小白爬虫冒险之反“反爬”:无限debugger、禁用开发者工具、干扰控制台...(持续更新)
  • Ubuntu中MySQL安装-02
  • 大数据相关职位介绍之一(数据分析,数据开发,数据产品经理,数据运营)
  • 使用DeepSeek API生成Markdown文件
  • java多线程学习笔记
  • Manticore Search,新一代搜索引擎之王
  • 【MySQL】数据类型与表约束
  • CAG技术:提升LLM响应速度与质量
  • 上位机知识篇---Linux源码编译安装链接命令
  • 科研绘图系列:R语言绘制线性回归连线图(line chart)
  • 将ollama迁移到其他盘(eg:F盘)
  • Oracle 创建用户和表空间
  • cursor ide配置远程ssh qt c++开发环境过程记录
  • yolov5错误更改与相关参数详解(train.py)
  • Python设计模式 - 组合模式
  • css粘性定位超出指定宽度失效问题
  • Windows 程序设计6:错误码的查看
  • doris: CSV导入数据
  • FastStone Image Viewer图像处理软件安装步骤(百度网盘链接)
  • Kafka 深入服务端 — 时间轮
  • 网络爬虫学习:应用selenium获取Edge浏览器版本号,自动下载对应版本msedgedriver,确保Edge浏览器顺利打开。
  • 【go语言】结构体
  • Spring Boot是什么及其优点
  • 谷氨酸:大脑功能的多面手
  • SpringCloudGateWay和Sentinel结合做黑白名单来源控制
  • HTML新春烟花
  • 【Elasticsearch】中数据流需要配置索引模板吗?