当前位置: 首页 > news >正文

基于CBAM-CNN卷积神经网络预测研究(Python代码实现)

 💥💥💞💞欢迎来到本博客❤️❤️💥💥

🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。

⛳️座右铭:行百里者,半于九十。

📋📋📋本文目录如下:🎁🎁🎁

目录

💥1 概述

📚2 运行结果

🎉3 参考文献

🌈4 Python代码及数据


💥1 概述

CBAM(CBAM-CNN)是一种用于计算机视觉领域的卷积神经网络结构,它能够有效地从图像中学习关注和调整。CBAM模型结合了通道注意力模块(Channel Attention Module)和空间注意力模块(Spatial Attention Module)两个部分,用于提升卷积神经网络的性能。

通道注意力模块(CAM)旨在通过学习不同通道之间的相关性,为每个通道分配适当的注意力权重。该模块首先通过全局平均池化获得整个通道的平均值,然后使用两个全连接层来生成一组注意力权重。这些权重用于调整每个通道的特征图。

空间注意力模块(SAM)旨在学习图像中不同空间区域的重要性。该模块通过对特征图在不同空间维度上进行最大池化和平均池化操作,然后使用一个卷积层来生成一组注意力权重。最后,这些权重被应用于原始特征图,以增强具有重要空间信息的区域。

通过结合通道注意力模块和空间注意力模块,CBAM能够动态地选择和调整特征图的通道和空间注意力,从而提取更准确和具有区分力的特征表示。这种注意力机制有助于网络更好地对图像进行感知,从而改善图像分类、目标检测、图像分割等计算机视觉任务的性能。

针对预测任务,可以使用CBAM-CNN模型进行图像分类或目标检测。在图像分类任务中,CBAM-CNN可以通过自适应地关注重要的通道和空间区域,提取图像特征并进行分类。在目标检测任务中,CBAM-CNN可以辅助检测网络对目标区域进行准确定位和分类。

需要注意的是,CBAM-CNN只是一种网络结构,具体的预测研究还需要根据具体的任务和数据集进行调整和优化。

📚2 运行结果

 部分代码:

def forward(self, x):# 1.最大池化分支max_branch = self.MaxPool(x)# 送入MLP全连接神经网络, 得到权重max_in = max_branch.view(max_branch.size(0), -1)max_weight = self.fc_MaxPool(max_in)# 2.全局池化分支avg_branch = self.AvgPool(x)# 送入MLP全连接神经网络, 得到权重avg_in = avg_branch.view(avg_branch.size(0), -1)avg_weight = self.fc_AvgPool(avg_in)# MaxPool + AvgPool 激活后得到权重weightweight = max_weight + avg_weightweight = self.sigmoid(weight)# 将维度为b, c的weight, reshape成b, c, 1, 1 与 输入x 相乘h, w = weight.shape# 通道注意力McMc = torch.reshape(weight, (h, w, 1))# 乘积获得结果x = Mc * xreturn xclass SpatialAttentionModul(nn.Module):  # 空间注意力模块def __init__(self, in_channel):super(SpatialAttentionModul, self).__init__()self.conv = nn.Conv1d(2, 1, 7, padding=3)self.sigmoid = nn.Sigmoid()def forward(self, x):# x维度为 [N, C, H, W] 沿着维度C进行操作, 所以dim=1, 结果为[N, H, W]MaxPool = torch.max(x, dim=1).values  # torch.max 返回的是索引和value, 要用.values去访问值才行!AvgPool = torch.mean(x, dim=1)# 增加维度, 变成 [N, 1, H, W]MaxPool = torch.unsqueeze(MaxPool, dim=1)AvgPool = torch.unsqueeze(AvgPool, dim=1)# 维度拼接 [N, 2, H, W]x_cat = torch.cat((MaxPool, AvgPool), dim=1)  # 获得特征图# 卷积操作得到空间注意力结果x_out = self.conv(x_cat)Ms = self.sigmoid(x_out)# 与原图通道进行乘积x = Ms * xreturn xif __name__ == '__main__':inputs = torch.randn(32, 512, 16)model = CBAM(in_channel=512)  # CBAM模块, 可以插入CNN及任意网络中, 输入特征图in_channel的维度
    def forward(self, x):# 1.最大池化分支max_branch = self.MaxPool(x)# 送入MLP全连接神经网络, 得到权重max_in = max_branch.view(max_branch.size(0), -1)max_weight = self.fc_MaxPool(max_in)# 2.全局池化分支avg_branch = self.AvgPool(x)# 送入MLP全连接神经网络, 得到权重avg_in = avg_branch.view(avg_branch.size(0), -1)avg_weight = self.fc_AvgPool(avg_in)# MaxPool + AvgPool 激活后得到权重weightweight = max_weight + avg_weightweight = self.sigmoid(weight)# 将维度为b, c的weight, reshape成b, c, 1, 1 与 输入x 相乘h, w = weight.shape# 通道注意力McMc = torch.reshape(weight, (h, w, 1))# 乘积获得结果x = Mc * xreturn xclass SpatialAttentionModul(nn.Module):  # 空间注意力模块def __init__(self, in_channel):super(SpatialAttentionModul, self).__init__()self.conv = nn.Conv1d(2, 1, 7, padding=3)self.sigmoid = nn.Sigmoid()def forward(self, x):# x维度为 [N, C, H, W] 沿着维度C进行操作, 所以dim=1, 结果为[N, H, W]MaxPool = torch.max(x, dim=1).values  # torch.max 返回的是索引和value, 要用.values去访问值才行!AvgPool = torch.mean(x, dim=1)# 增加维度, 变成 [N, 1, H, W]MaxPool = torch.unsqueeze(MaxPool, dim=1)AvgPool = torch.unsqueeze(AvgPool, dim=1)# 维度拼接 [N, 2, H, W]x_cat = torch.cat((MaxPool, AvgPool), dim=1)  # 获得特征图# 卷积操作得到空间注意力结果x_out = self.conv(x_cat)Ms = self.sigmoid(x_out)# 与原图通道进行乘积x = Ms * xreturn xif __name__ == '__main__':inputs = torch.randn(32, 512, 16)model = CBAM(in_channel=512)  # CBAM模块, 可以插入CNN及任意网络中, 输入特征图in_channel的维度

🎉3 参考文献

部分理论来源于网络,如有侵权请联系删除。

[1]黄昌顺,张金萍.基于CBAM-CNN的滚动轴承故障诊断方法[J].现代制造工程,2022(11):137-143.DOI:10.16731/j.cnki.1671-3133.2022.11.022.

[2]杜先君,巩彬,余萍等.基于CBAM-CNN的模拟电路故障诊断[J].控制与决策,2022,37(10):2609-2618.DOI:10.13195/j.kzyjc.2021.1111.

🌈4 Python代码及数据

http://www.lryc.cn/news/138742.html

相关文章:

  • iOS开发Swift-基本运算符
  • Flink java 工具类
  • 2023年你需要知道的最佳预算Wi-Fi路由器清单
  • Go语言基础之流程控制
  • Git 安装、配置并把项目托管到码云 Gitee
  • C++信息学奥赛1147:最高分数的学生姓名
  • STM32使用PID调速
  • 【UE5:CesiumForUnreal】——3DTiles数据属性查询和单体高亮
  • 无涯教程-PHP - 返回类型声明
  • DOS常见命令
  • Qt应用开发(拓展篇)——示波器/图表 QCustomPlot
  • 【精度丢失】后端接口返回的Long类型参数,不同浏览器解析出的结果不一样
  • 2023年国赛 高教社杯数学建模思路 - 案例:感知机原理剖析及实现
  • java-红黑树
  • vue2 vue中的常用指令
  • AI驱动下的智能制造:工业自动化的新纪元
  • docker 命令
  • 2023年高教社杯数学建模思路 - 复盘:光照强度计算的优化模型
  • 生成式人工智能的潜在有害影响与未来之路(二)
  • 如何自己实现一个丝滑的流程图绘制工具(三)自定义挂载vue组件
  • UNIAPP调用API接口
  • 理解 Delphi 的类(五) - 认识类的继承
  • mybatis概述及搭建
  • DNDC模型---土壤碳储量、温室气体排放、农田减排、土地变化、气候变化中的应用
  • Android studio 2022.3.1 鼠标移动时不显示快速文档
  • 五度易链最新“产业大数据服务解决方案”亮相,打造数据引擎,构建智慧产业!
  • 简述hive环境搭建
  • 小米AI音箱联网升级折腾记录(解决配网失败+升级失败等问题)
  • tensorRT安装
  • 电脑重装+提升网速