当前位置: 首页 > news >正文

【深度学习注意力机制系列】—— SENet注意力机制(附pytorch实现)

深度学习中的注意力机制(Attention Mechanism)是一种模仿人类视觉和认知系统的方法,它允许神经网络在处理输入数据时集中注意力于相关的部分。通过引入注意力机制,神经网络能够自动地学习并选择性地关注输入中的重要信息,提高模型的性能和泛化能力。

卷积神经网络引入的注意力机制主要有以下几种方法:

  • 在空间维度上增加注意力机制
  • 在通道维度上增加注意力机制
  • 在两者的混合维度上增加注意力机制

我们将在本系列对多种注意力机制进行讲解,并使用pytorch进行实现,今天我们讲解SENet注意力机制

SENet(Squeeze-and-Excitation Networks)注意力机制通道维度上引入注意力机制,其核心思想在于通过网络根据loss去学习特征权重,使得有效的feature map权重大,无效或效果小的feature map权重小的方式训练模型达到更好的结果。SE block嵌在原有的一些分类网络中不可避免地增加了一些参数和计算量,但是在效果面前还是可以接受的 。Sequeeze-and-Excitation(SE) block并不是一个完整的网络结构,而是一个子结构,可以嵌到其他分类或检测模型中。

在这里插入图片描述

以上是SENet的结构示意图, 其关键操作为squeeze和excitation. 通过自动学习获得特征图在每个通道上的重要程度,以此为不同通道赋予不同的权重,提升有用通道的贡献程度.

实现机制:

  1. Squeeze: 通过全剧平均池化层,将每个通道大的二维特征(h*w)压缩为一个实数,维度变化: (C, H, W) -> (C, 1, 1)
  2. Excitation: 给予每个通道的一个特征权重, 然后经过两次全连接层的信息整合提取,构建通道间的自相关性,输出权重数目和特征图通道数一致, 维度变化: (C, 1, 1) -> (C, 1, 1)
  3. Scale: 将归一化后的权重加权道每个通道的特征上, 论文中使用的是相乘加权, 维度变化: (C, H, W) * (C, 1, 1) -> (C, H, W)

pytorch实现:

class SENet(nn.Module):def __init__(self, in_channels, ratio=16):super(SENet, self).__init__()self.in_channels = in_channelsself.fgp = nn.AdaptiveAvgPool2d((1, 1))self.fc1 = nn.Linear(self.in_channels, int(self.in_channels / ratio), bias=False)self.act1 = nn.ReLU()self.fc2 = nn.Linear(int(self.in_channels / ratio), self.in_channels, bias=False)self.act2 = nn.Sigmoid()def forward(self, x):b, c, h, w = x.size()output = self.fgp(x)output = output.view(b, c)output = self.fc1(output)output = self.act1(output)output = self.fc2(output)output = self.act2(output)output = output.view(b, c, 1, 1)return torch.multiply(x, output)
http://www.lryc.cn/news/116043.html

相关文章:

  • go 函数
  • python之正则表达式
  • 【LeetCode每日一题】——219.存在重复元素II
  • 篇六:适配器模式:让不兼容变兼容
  • 【云原生】Docker-compose中所有模块学习
  • 广义积分练习
  • element-ui树形表格,左边勾选,右边显示选中的数据-功能(如动图)
  • Android数字价格变化的动画效果的简单实现
  • Win10无法投影关闭3D模式
  • FFmpeg 编码详细流程
  • 05如何做微服务架构设计
  • 安卓开发问题记录:需要常量表达式
  • 回归预测 | MATLAB实现基于SVM-RFE-BP支持向量机递归特征消除特征选择算法结合BP神经网络的多输入单输出回归预测
  • 配置root账户ssh免密登录并使用docker-machine构建docker服务
  • 【力扣周赛】第357场周赛
  • 多线程案例(4)-线程池
  • 【数据结构OJ题】轮转数组
  • 现代C++中的从头开始深度学习:【4/8】梯度下降
  • Yolov5缺陷检测/目标检测 Jetson nx部署Triton server
  • MobaXterm 中文乱码, 及pojie
  • java: 程序包sun.misc不存在
  • WSL2Linux 子系统(五)
  • java 企业工程管理系统软件源码 自主研发 工程行业适用 em
  • IPO观察丨困于门店扩张的KK集团,还能讲好增长故事吗?
  • 【iOS】RunLoop
  • 数据包传输方式:单播、多播、广播、组播、泛播
  • WebRTC基础知识
  • 积累常见的有针对性的python面试题---python面试题001
  • 在springboot使用websocket时mapper无法注入
  • 前端加密与解密的几种方式