当前位置: 首页 > news >正文

Python day49.

@浙大疏锦行 Python day49

CBAM注意力,同时结合了通道注意力和空间注意力,通道注意力关注了通道的重要性,空间注意力关注了图像中不同位置的重要性,相当于同时提升通道和空间维度的特征质量;

import torch
import torch.nn as nn# 定义通道注意力
class ChannelAttention(nn.Module):def __init__(self, in_channels, ratio=16):"""通道注意力机制初始化参数:in_channels: 输入特征图的通道数ratio: 降维比例,用于减少参数量,默认为16"""super().__init__()# 全局平均池化,将每个通道的特征图压缩为1x1,保留通道间的平均值信息self.avg_pool = nn.AdaptiveAvgPool2d(1)# 全局最大池化,将每个通道的特征图压缩为1x1,保留通道间的最显著特征self.max_pool = nn.AdaptiveMaxPool2d(1)# 共享全连接层,用于学习通道间的关系# 先降维(除以ratio),再通过ReLU激活,最后升维回原始通道数self.fc = nn.Sequential(nn.Linear(in_channels, in_channels // ratio, bias=False),  # 降维层nn.ReLU(),  # 非线性激活函数nn.Linear(in_channels // ratio, in_channels, bias=False)   # 升维层)# Sigmoid函数将输出映射到0-1之间,作为各通道的权重self.sigmoid = nn.Sigmoid()def forward(self, x):"""前向传播函数参数:x: 输入特征图,形状为 [batch_size, channels, height, width]返回:调整后的特征图,通道权重已应用"""# 获取输入特征图的维度信息,这是一种元组的解包写法b, c, h, w = x.shape# 对平均池化结果进行处理:展平后通过全连接网络avg_out = self.fc(self.avg_pool(x).view(b, c))# 对最大池化结果进行处理:展平后通过全连接网络max_out = self.fc(self.max_pool(x).view(b, c))# 将平均池化和最大池化的结果相加并通过sigmoid函数得到通道权重attention = self.sigmoid(avg_out + max_out).view(b, c, 1, 1)# 将注意力权重与原始特征相乘,增强重要通道,抑制不重要通道return x * attention #这个运算是pytorch的广播机制## 空间注意力模块
class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super().__init__()self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):# 通道维度池化avg_out = torch.mean(x, dim=1, keepdim=True)  # 平均池化:(B,1,H,W)max_out, _ = torch.max(x, dim=1, keepdim=True)  # 最大池化:(B,1,H,W)pool_out = torch.cat([avg_out, max_out], dim=1)  # 拼接:(B,2,H,W)attention = self.conv(pool_out)  # 卷积提取空间特征return x * self.sigmoid(attention)  # 特征与空间权重相乘## CBAM模块
class CBAM(nn.Module):def __init__(self, in_channels, ratio=16, kernel_size=7):super().__init__()self.channel_attn = ChannelAttention(in_channels, ratio)self.spatial_attn = SpatialAttention(kernel_size)def forward(self, x):x = self.channel_attn(x)x = self.spatial_attn(x)return x

http://www.lryc.cn/news/626164.html

相关文章:

  • 嵌入式第三十二天(信号,共享内存)
  • 机器学习概念(面试题库)
  • 8.19笔记
  • Python + 淘宝 API 开发:自动化采集商品数据的完整流程​
  • python新工具-uv包管理工具
  • RPC高频问题与底层原理剖析
  • Chrome插件开发【windows】
  • 【最新版】CRMEB Pro版v3.4系统源码全开源+PC端+uniapp前端+搭建教程
  • LLM(大语言模型)的工作原理 图文讲解
  • 网络间的通用语言TCP/IP-网络中的通用规则4
  • 大模型+RPA:如何用AI实现企业流程自动化的“降本增效”?
  • 基于SpringBoot+Vue的养老院管理系统的设计与实现 智能养老系统 养老架构管理 养老小程序
  • Linux系统部署python程序
  • SConscript 脚本入门教程
  • InfoNES模拟器HarmonyOS移植指南
  • Redis缓存加速测试数据交互:从前缀键清理到前沿性能革命
  • 图形化监控用数据动态刷新方法
  • Transformer入门到精通(附高清文档)
  • 内网后渗透攻击--隐藏通信隧道技术(压缩、上传,下载)
  • 常见的软件图片缩放,算法如何选择?
  • 【开源工具】基于社会工程学的智能密码字典生成器开发(附完整源码)
  • 字节开源了一款具备长期记忆能力的多模态智能体:M3-Agent
  • 洛谷 P2834 纸币问题 3-普及-
  • Flink原理与实践 · 第三章总结
  • 第5.6节:awk字符串运算
  • 【驱动】RK3576:桌面操作系统基本概念
  • L2TP虚拟局域网
  • 快速傅里叶变换:数字信号处理的基石算法
  • Orange的运维学习日记--47.Ansible进阶之异步处理
  • 数据库-MYSQL配置下载