当前位置: 首页 > news >正文

每日Attention学习19——Convolutional Multi-Focal Attention

每日Attention学习19——Convolutional Multi-Focal Attention

模块出处

[ICLR 25 Submission] [link] UltraLightUNet: Rethinking U-shaped Network with Multi-kernel Lightweight Convolutions for Medical Image Segmentation


模块名称

Convolutional Multi-Focal Attention (CMFA)


模块作用

轻量解码器


模块结构

在这里插入图片描述


模块特点
  • 使用最大池化与平均池化构建通道注意力
  • 使用Channel Max与Channel Average构建空间注意力
  • 核心思想与CBAM较为类似,串联通道注意力与空间注意力

模块代码
import torch
import torch.nn as nn
import torch.nn.functional as Fclass SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()assert kernel_size in (3, 7, 11), 'kernel size must be 3 or 7 or 11'padding = kernel_size // 2self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)x = torch.cat([avg_out, max_out], dim=1)x = self.conv(x)return self.sigmoid(x)class ChannelAttention(nn.Module):def __init__(self, in_planes, out_planes=None, ratio=16):super(ChannelAttention, self).__init__()self.in_planes = in_planesself.out_planes = out_planesif self.in_planes < ratio:ratio = self.in_planesself.reduced_channels = self.in_planes // ratioif self.out_planes == None:self.out_planes = in_planesself.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.activation = nn.ReLU(inplace=True)self.fc1 = nn.Conv2d(in_planes, self.reduced_channels, 1, bias=False)self.fc2 = nn.Conv2d(self.reduced_channels, self.out_planes, 1, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_pool_out = self.avg_pool(x) avg_out = self.fc2(self.activation(self.fc1(avg_pool_out)))max_pool_out= self.max_pool(x)max_out = self.fc2(self.activation(self.fc1(max_pool_out)))out = avg_out + max_outreturn self.sigmoid(out) class CMFA(nn.Module):def __init__(self, in_planes, out_planes=None,):super(CMFA, self).__init__()self.ca = ChannelAttention(in_planes=64, out_planes=64)self.sa = SpatialAttention()def forward(self, x):x = x*self.ca(x)x = x*self.sa(x)return xif __name__ == '__main__':x = torch.randn([1, 64, 44, 44])cmfa = CMFA(in_planes=64, out_planes=64)out = cmfa(x)print(out.shape)  # [1, 64, 44, 44]

http://www.lryc.cn/news/531880.html

相关文章:

  • LeetCode题练习与总结:三个数的最大乘积--628
  • Colorful/七彩虹 隐星P15 TA 24 原厂Win11 家庭版系统 带F9 Colorful一键恢复功能
  • 第二篇:多模态技术突破——DeepSeek如何重构AI的感知与认知边界
  • CTreeCtrl 设置图标
  • 在JAX-RS中获取请求头信息的方法
  • Java 面试之结束问答
  • 柔性数组与c/c++程序中内存区域的划分
  • mini-lsm通关笔记Week2Day7
  • Typora免费使用
  • AI驱动的无线定位:基础、标准、最新进展与挑战
  • 苹果再度砍掉AR眼镜项目?AR真的是伪风口吗?
  • 18 大量数据的异步查询方案
  • DRM系列八:Drm之DRM_IOCTL_MODE_ADDFB2
  • 软件测试用例篇
  • PopupMenuButton组件的功能和用法
  • Python进行模型优化与调参
  • vue2-组件通信
  • 20250205确认荣品RK3566开发板在Android13下可以使用命令行reboot -p关机
  • 设计模式---观察者模式
  • 初八开工!开启数字化转型新征程!
  • 文本分析NLP的常用工具和特点
  • DeepSeek 与 ChatGPT 对比分析
  • vite---依赖优化选项esbuildOptions详解
  • ElasticSearch 学习课程入门(二)
  • 使用 Redis Streams 实现高性能消息队列
  • 深度学习|表示学习|卷积神经网络|DeconvNet是什么?|18
  • (优先级队列(堆)) 【本节目标】 1. 掌握堆的概念及实现 2. 掌握 PriorityQueue 的使用
  • 优化数据库结构
  • 密云生活的初体验
  • 图像分类与目标检测算法