当前位置: 首页 > news >正文

卷积神经网络中的注意力机制:CBAM详解与实践

一、引言

在深度学习领域,卷积神经网络(CNN)一直是计算机视觉任务的主流架构。然而,传统的CNN对所有空间位置和通道特征一视同仁,缺乏对重要特征的聚焦能力。注意力机制的引入为解决这一问题提供了思路。

本文将重点介绍卷积注意力模块CBAM(Convolutional Block Attention Module),并详细讲解如何在PyTorch中实现和应用这一机制。

二、注意力机制概述

2.1 什么是注意力机制

注意力机制源于人类视觉系统的工作方式 - 我们不会同时处理视野中的所有信息,而是选择性地聚焦于重要部分。在深度学习中,注意力机制通过动态调整特征图中不同位置或通道的重要性,使模型能够关注更有信息量的区域。

2.2 注意力机制的分类

  1. 空间注意力:关注"在哪里"(Where)重要,在特征图的二维空间维度上分配权重

  2. 通道注意力:关注"什么"(What)重要,在不同通道维度上分配权重

  3. 混合注意力:同时考虑空间和通道注意力

CBAM就是一种典型的混合注意力机制,它依次应用通道注意力和空间注意力模块,显著提升了模型性能。

三、CBAM详解

3.1 CBAM结构

CBAM由两个顺序子模块组成:

  1. 通道注意力模块(Channel Attention Module)

  2. 空间注意力模块(Spatial Attention Module)

3.2 通道注意力模块

通道注意力关注"什么"是有意义的输入图像。它通过挤压(squeeze)操作聚合空间信息,然后通过激励(excitation)操作学习通道间的依赖关系。

数学表达:

Mc(F) = σ(MLP(AvgPool(F)) + MLP(MaxPool(F)))
= σ(W1(W0(F_avg)) + W1(W0(F_max))) 

其中:

  • F:输入特征图

  • σ:sigmoid函数

  • W0 ∈ R^(C/r×C), W1 ∈ R^(C×C/r)

  • r是缩减比率

3.3 空间注意力模块

空间注意力关注"在哪里"是信息丰富的部分。它通过在通道维度上应用平均池化和最大池化,然后 concatenate 起来生成有效的特征描述符。

数学表达:

Ms(F) = σ(f^7×7([AvgPool(F); MaxPool(F)]))
= σ(f^7×7([F_avg; F_max])) 

其中:

  • f^7×7:7×7卷积核

  • σ:sigmoid函数

  • [·; ·]:concatenate操作

四、PyTorch实现CBAM

4.1 通道注意力模块实现

import torch
import torch.nn as nn
import torch.nn.functional as Fclass ChannelAttention(nn.Module):def __init__(self, in_channels, reduction_ratio=16):"""通道注意力模块初始化参数:in_channels (int): 输入特征图的通道数reduction_ratio (int): MLP中间层的缩减比率,默认16"""super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)  # 全局平均池化self.max_pool = nn.AdaptiveMaxPool2d(1)  # 全局最大池化# 共享的两层MLPself.mlp = nn.Sequential(nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False),  # 降维nn.ReLU(inplace=True),nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False)   # 升维)self.sigmoid = nn.Sigmoid()def forward(self, x):"""前向传播参数:x (torch.Tensor): 输入特征图,形状为[B, C, H, W]返回:torch.Tensor: 通道注意力权重,形状同输入"""avg_out = self.mlp(self.avg_pool(x))  # 平均池化路径max_out = self.mlp(self.max_pool(x))  # 最大池化路径channel_weights = self.sigmoid(avg_out + max_out)  # 结合两种池化结果return x * channel_weights  # 应用注意力权重

4.2 空间注意力模块实现 

class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):"""空间注意力模块初始化参数:kernel_size (int): 卷积核大小,必须是奇数,默认7"""super(SpatialAttention, self).__init__()assert kernel_size % 2 == 1, "kernel_size必须是奇数"padding = kernel_size // 2  # 保持特征图尺寸不变self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):"""前向传播参数:x (torch.Tensor): 输入特征图,形状为[B, C, H, W]返回:torch.Tensor: 空间注意力权重,形状为[B, 1, H, W]"""# 在通道维度上同时应用平均池化和最大池化avg_out = torch.mean(x, dim=1, keepdim=True)  # 平均池化 [B, 1, H, W]max_out, _ = torch.max(x, dim=1, keepdim=True)  # 最大池化 [B, 1, H, W]# 拼接两种池化结果combined = torch.cat([avg_out, max_out], dim=1)  # [B, 2, H, W]# 通过卷积层生成空间注意力图spatial_weights = self.sigmoid(self.conv(combined))  # [B, 1, H, W]return x * spatial_weights  # 应用注意力权重

4.3 完整CBAM模块 

class CBAM(nn.Module):def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):"""CBAM模块初始化参数:in_channels (int): 输入特征图的通道数reduction_ratio (int): 通道注意力中的缩减比率,默认16kernel_size (int): 空间注意力中的卷积核大小,必须是奇数,默认7"""super(CBAM, self).__init__()self.channel_attention = ChannelAttention(in_channels, reduction_ratio)self.spatial_attention = SpatialAttention(kernel_size)def forward(self, x):"""前向传播参数:x (torch.Tensor): 输入特征图,形状为[B, C, H, W]返回:torch.Tensor: 经过CBAM处理后的特征图,形状同输入"""# 先应用通道注意力,再应用空间注意力x = self.channel_attention(x)x = self.spatial_attention(x)return x

五、将CBAM集成到CNN中

5.1 基本残差块集成CBAM

class BasicBlockWithCBAM(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1, downsample=None):"""带有CBAM的残差块初始化参数:in_channels (int): 输入通道数out_channels (int): 输出通道数stride (int): 卷积步长,默认1downsample (nn.Module): 下采样模块,默认None"""super(BasicBlockWithCBAM, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.downsample = downsampleself.stride = stride# 添加CBAM模块self.cbam = CBAM(out_channels)def forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)# 应用CBAMout = self.cbam(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return out

5.2 完整ResNet集成CBAM示例 

class ResNetWithCBAM(nn.Module):def __init__(self, block, layers, num_classes=1000):"""ResNet集成CBAM的完整实现参数:block (nn.Module): 基础块类型,如BasicBlock或Bottlenecklayers (list): 每个阶段的块数量,如[2, 2, 2, 2]对应ResNet18num_classes (int): 分类类别数,默认1000"""super(ResNetWithCBAM, self).__init__()self.in_channels = 64# 初始卷积层self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# 四个残差阶段self.layer1 = self._make_layer(block, 64, layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2)self.layer3 = self._make_layer(block, 256, layers[2], stride=2)self.layer4 = self._make_layer(block, 512, layers[3], stride=2)# 分类头self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)def _make_layer(self, block, out_channels, blocks, stride=1):"""创建残差阶段参数:block (nn.Module): 基础块类型out_channels (int): 输出通道数blocks (int): 块数量stride (int): 第一个块的步长返回:nn.Sequential: 残差阶段"""downsample = Noneif stride != 1 or self.in_channels != out_channels * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.in_channels, out_channels * block.expansion,kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * block.expansion),)layers = []layers.append(block(self.in_channels, out_channels, stride, downsample))self.in_channels = out_channels * block.expansionfor _ in range(1, blocks):layers.append(block(self.in_channels, out_channels))return nn.Sequential(*layers)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return x

六、CBAM在不同任务中的应用

6.1 图像分类任务

# 创建带有CBAM的ResNet18模型
def resnet18_cbam(num_classes=1000):return ResNetWithCBAM(BasicBlockWithCBAM, [2, 2, 2, 2], num_classes)model = resnet18_cbam(num_classes=10)# 训练配置示例
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

6.2 目标检测任务

在Faster R-CNN中集成CBAM:

from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone# 创建带有CBAM的ResNet-FPN骨干网络
def resnet_fpn_cbam(pretrained=False):backbone = resnet_fpn_backbone('resnet50', pretrained)# 在骨干网络的特定层添加CBAMdef add_cbam(layer):return nn.Sequential(layer, CBAM(layer[-1].out_channels))backbone.body.layer1 = add_cbam(backbone.body.layer1)backbone.body.layer2 = add_cbam(backbone.body.layer2)backbone.body.layer3 = add_cbam(backbone.body.layer3)backbone.body.layer4 = add_cbam(backbone.body.layer4)return backbone# 创建Faster R-CNN模型
backbone = resnet_fpn_cbam(pretrained=True)
model = FasterRCNN(backbone, num_classes=91)  # COCO数据集有90类+背景

6.3 语义分割任务

在U-Net中集成CBAM:

class DoubleConvWithCBAM(nn.Module):"""(convolution => [BN] => ReLU) * 2 + CBAM"""def __init__(self, in_channels, out_channels):super().__init__()self.double_conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))self.cbam = CBAM(out_channels)def forward(self, x):x = self.double_conv(x)x = self.cbam(x)return xclass UNetWithCBAM(nn.Module):def __init__(self, n_classes):super(UNetWithCBAM, self).__init__()# 编码器部分self.inc = DoubleConvWithCBAM(3, 64)self.down1 = DownWithCBAM(64, 128)self.down2 = DownWithCBAM(128, 256)self.down3 = DownWithCBAM(256, 512)self.down4 = DownWithCBAM(512, 1024)# 解码器部分self.up1 = UpWithCBAM(1024, 512)self.up2 = UpWithCBAM(512, 256)self.up3 = UpWithCBAM(256, 128)self.up4 = UpWithCBAM(128, 64)self.outc = OutConv(64, n_classes)def forward(self, x):x1 = self.inc(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)logits = self.outc(x)return logits

七、CBAM的变体与改进

7.1 轻量级CBAM

class LightweightCBAM(nn.Module):def __init__(self, in_channels, reduction_ratio=8, kernel_size=7):super(LightweightCBAM, self).__init__()# 简化通道注意力self.channel_attention = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(in_channels, in_channels // reduction_ratio, 1),nn.ReLU(inplace=True),nn.Conv2d(in_channels // reduction_ratio, in_channels, 1),nn.Sigmoid())# 简化空间注意力self.spatial_attention = nn.Sequential(nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2),nn.Sigmoid())def forward(self, x):# 通道注意力channel = self.channel_attention(x)x = x * channel# 空间注意力avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)spatial = self.spatial_attention(torch.cat([avg_out, max_out], dim=1))x = x * spatialreturn x

7.2 并行CBAM 

class ParallelCBAM(nn.Module):def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):super(ParallelCBAM, self).__init__()self.channel_attention = ChannelAttention(in_channels, reduction_ratio)self.spatial_attention = SpatialAttention(kernel_size)def forward(self, x):channel = self.channel_attention(x)spatial = self.spatial_attention(x)# 并行结合方式return x * (channel + spatial) / 2  # 平均结合

八、实验与性能分析

8.1 在CIFAR-10上的对比实验

模型参数量(M)准确率(%)训练时间(秒/epoch)
ResNet1811.294.545
ResNet18+CBAM11.395.848
ResNet3421.395.168
ResNet34+CBAM21.596.372

8.2 可视化分析

CBAM的注意力图可以可视化,帮助我们理解模型关注的重点区域:

import matplotlib.pyplot as pltdef visualize_attention(model, image_tensor):# 前向传播获取中间特征features = []def hook_fn(module, input, output):features.append(output)# 注册钩子hook = model.layer4[-1].cbam.register_forward_hook(hook_fn)# 前向传播model.eval()with torch.no_grad():_ = model(image_tensor.unsqueeze(0))# 移除钩子hook.remove()# 获取注意力图feature = features[0]channel_weights = model.layer4[-1].cbam.channel_attention(feature)spatial_weights = model.layer4[-1].cbam.spatial_attention(feature * channel_weights)# 可视化plt.figure(figsize=(12, 4))plt.subplot(1, 3, 1)plt.imshow(image_tensor.permute(1, 2, 0))plt.title("Original Image")plt.axis('off')plt.subplot(1, 3, 2)plt.imshow(channel_weights[0, 0].cpu().numpy(), cmap='hot')plt.title("Channel Attention")plt.axis('off')plt.subplot(1, 3, 3)plt.imshow(spatial_weights[0, 0].cpu().numpy(), cmap='hot')plt.title("Spatial Attention")plt.axis('off')plt.show()

九、总结与展望

CBAM作为一种简单有效的注意力机制,通过顺序应用通道和空间注意力模块,显著提升了CNN模型的性能。本文详细介绍了CBAM的原理、PyTorch实现方法以及在不同任务中的应用方式。实验表明,CBAM能够以较小的计算代价带来明显的性能提升。

未来发展方向:

  1. 更高效的注意力计算方式

  2. 动态调整注意力模块的数量和位置

  3. 与其他注意力机制(如self-attention)的结合

  4. 在轻量化网络中的应用优化

十、参考文献

  1. Woo, S., Park, J., Lee, J. Y., & Kweon, I. S. (2018). "CBAM: Convolutional Block Attention Module". ECCV.

  2. Hu, J., Shen, L., & Sun, G. (2018). "Squeeze-and-Excitation Networks". CVPR.

  3. Wang, X., Girshick, R., Gupta, A., & He, K. (2018). "Non-local Neural Networks". CVPR.

希望这篇详细的教程能够帮助你理解和应用CBAM注意力机制!在实际项目中,可以根据具体任务需求调整CBAM的位置和参数,以达到最佳效果。

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

http://www.lryc.cn/news/594728.html

相关文章:

  • Go-通俗易懂垃圾回收及运行过程
  • WPF——自定义ListBox
  • C++ - 仿 RabbitMQ 实现消息队列--服务端核心模块实现(二)
  • 学习秒杀系统-异步下单(包含RabbitMQ基础知识)
  • ASP.NET Core Web API 中集成 DeveloperSharp.RabbitMQ
  • 关于校准 ARM 开发板时间的步骤和常见问题:我应该是RTC电池没电了才导致我设置了重启开发板又变回去2025年的时间
  • Android NDK ffmpeg 音视频开发实战
  • 什么是“差分“?
  • 包装类简单了解泛型
  • 图片转 PDF三个免费方法总结
  • 支持不限制大小,大文件分段批量上传功能(不受nginx /apache 上传大小限制)
  • 网络设备功能对照表
  • 【Spark征服之路-3.6-Spark-SQL核心编程(五)】
  • Linux 文件操作详解:结构、系统调用、权限与实践
  • 第二阶段-第二章—8天Python从入门到精通【itheima】-134节(SQL——DQL——分组聚合)
  • leetcode-sql-627变更性别
  • 深入解析IP协议:组成、地址管理与路由选择
  • Tomato靶机通关教程
  • 安装docker可视化工具 Portainer中文版(ubuntu上演示,所有docker通用) 支持控制各种容器,容器操作简单化 降低容器门槛
  • 板凳-------Mysql cookbook学习 (十二--------4)
  • 技能学习PostgreSQL中级专家
  • 借助AI学习开源代码git0.7之六write-cache
  • 基于 STM32 的数字闹钟系统 Proteus 仿真设计与实现
  • 从一开始的网络攻防(六):php反序列化
  • 金仓数据库:融合进化,智领未来——2025年数据库技术革命的深度解析
  • STM32 USB键盘实现指南
  • 最严电动自行车新规,即将实施!
  • FreeSwitch通过Websocket(流式双向语音)对接AI实时语音大模型技术方案(mod_ppy_aduio_stream)
  • 朝歌智慧盘古信息:以IMS MOM V6重构国产化智能终端新生态
  • 【初识数据结构】CS61B中的最小生成树问题