当前位置: 首页 > news >正文

【Block总结】CoT,上下文Transformer注意力|即插即用

一. 论文信息

  • 标题: Contextual Transformer Networks for Visual Recognition
  • 论文链接: arXiv
  • GitHub链接: https://github.com/JDAI-CV/CoTNet
    在这里插入图片描述

二. 创新点

  • 上下文Transformer模块(CoT): 提出了CoT模块,能够有效利用输入键之间的上下文信息,指导动态注意力矩阵的学习,从而增强视觉表示能力。

  • 静态与动态上下文结合: CoT模块通过3×3卷积生成静态上下文表示,并结合动态注意力机制,提升了模型的特征提取能力。

三. 方法

CoT模块的设计流程如下:

  1. 上下文编码: 使用3×3卷积对输入的键进行上下文编码,生成静态上下文表示。

  2. 动态注意力矩阵学习: 将静态上下文与输入查询拼接,通过两个1×1卷积学习动态多头注意力矩阵。

  3. 动态上下文表示生成: 将学习到的注意力矩阵与输入值相乘,生成动态上下文表示。

  4. 输出融合: 最后,将静态和动态上下文表示融合,作为CoT模块的输出。

这种设计使得CoT模块可以替代ResNet架构中的每个3×3卷积,形成一种新的Transformer样式主干网络,称为上下文Transformer网络(CoTNet)。
在这里插入图片描述

CoT模块

CoT(Contextual Transformer)模块是一种新颖的Transformer风格模块,旨在增强视觉识别能力。它通过充分利用输入键之间的上下文信息,指导动态注意力矩阵的学习,从而提升模型的特征表示能力。CoT模块可以直接替换传统卷积网络中的3×3卷积,形成一种新的上下文Transformer网络(CoTNet)。

1. 工作原理

CoT模块的工作流程如下:

  1. 静态上下文表示生成:

    • 输入特征通过3×3卷积进行处理,生成静态上下文表示。
  2. 动态注意力矩阵生成:

    • 将静态上下文与输入查询拼接,经过两个1×1卷积生成动态注意力矩阵。
  3. 动态上下文表示生成:

    • 使用学习到的注意力矩阵对输入值进行加权,生成动态上下文表示。
  4. 输出融合:

    • 将静态上下文表示和动态上下文表示相加,形成最终输出。

2. 创新点

  • 上下文编码: CoT模块首先通过3×3卷积对输入的键进行上下文编码,生成静态上下文表示。这一过程确保了模型能够捕捉到局部邻域内的特征信息。

  • 动态注意力学习: 将静态上下文与输入查询拼接后,通过两个1×1卷积学习动态多头注意力矩阵。这个动态矩阵能够根据输入的变化调整注意力分配,从而更好地捕捉特征之间的关系。

  • 融合静态与动态上下文: 最终,CoT模块将静态和动态上下文表示融合,作为输出。这种设计使得模型能够同时利用静态信息和动态信息,增强了特征提取的能力。

CoT模块通过创新的上下文编码和动态注意力学习机制,显著提升了视觉识别模型的性能。其设计不仅增强了模型的特征提取能力,还为未来的计算机视觉研究提供了新的思路和方法。CoT模块的灵活性使其能够轻松集成到现有的卷积神经网络架构中,推动了视觉识别技术的发展。

四. 效果

CoTNet在多个计算机视觉任务中表现出色,尤其是在图像识别、目标检测和实例分割等任务中,展现了其作为主干网络的强大能力。

五. 实验结果

  • 在ImageNet数据集上的表现: CoTNet模型在Top-1准确率和Top-5准确率上均超过了传统的卷积神经网络(CNN)架构,展示了更好的推理时间与准确率的平衡。

  • 在开放世界图像分类挑战中的表现: CoTNet在CVPR 2021的开放世界图像分类挑战中获得了第一名,证明了其在实际应用中的有效性。

六. 总结

上下文Transformer网络(CoTNet)通过创新的CoT模块,成功地将上下文信息的动态聚合与静态聚合结合,显著提升了视觉识别任务的性能。实验结果表明,CoTNet在多个基准数据集上均表现优异,为计算机视觉领域提供了一种新的有效方法。该模块的设计不仅提升了模型的准确性,还为未来的研究提供了新的思路。

代码

import torch
import torch.nn.functional
import torch.nn.functional as F
from torch import nn
import mathclass CoTAttention(nn.Module):def __init__(self, dim=512, kernel_size=3):super().__init__()self.dim = dimself.kernel_size = kernel_sizeself.key_embed = nn.Sequential(nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=4, bias=False),nn.BatchNorm2d(dim),nn.ReLU())self.value_embed = nn.Sequential(nn.Conv2d(dim, dim, 1, bias=False),nn.BatchNorm2d(dim))factor = 4self.attention_embed = nn.Sequential(nn.Conv2d(2 * dim, 2 * dim // factor, 1, bias=False),nn.BatchNorm2d(2 * dim // factor),nn.ReLU(),nn.Conv2d(2 * dim // factor, kernel_size * kernel_size * dim, 1))def forward(self, x):bs, c, h, w = x.shapek1 = self.key_embed(x)  # bs,c,h,wv = self.value_embed(x).view(bs, c, -1)  # bs,c,h,wy = torch.cat([k1, x], dim=1)  # bs,2c,h,watt = self.attention_embed(y)  # bs,c*k*k,h,watt = att.reshape(bs, c, self.kernel_size * self.kernel_size, h, w)att = att.mean(2, keepdim=False).view(bs, c, -1)  # bs,c,h*wk2 = F.softmax(att, dim=-1) * vk2 = k2.view(bs, c, h, w)return k1 + k2if __name__ == "__main__":dim=64# 如果GPU可用,将模块移动到 GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 输入张量 (batch_size, channels,height, width)x = torch.randn(2,dim,40,40).to(device)# 初始化 CoTAttention模块block = CoTAttention(dim,3) # kernel_size为height或者widthprint(block)block = block.to(device)# 前向传播output = block(x)print("输入:", x.shape)print("输出:", output.shape)

输出结果:
在这里插入图片描述

http://www.lryc.cn/news/530994.html

相关文章:

  • linux库函数 gettimeofday() localtime的概念和使用案例
  • 编程题-电话号码的字母组合(中等)
  • EasyExcel使用详解
  • 基于“蘑菇书”的强化学习知识点(二):强化学习中基于策略(Policy-Based)和基于价值(Value-Based)方法的区别
  • 民法学学习笔记(个人向) Part.2
  • 物业管理系统源码驱动社区管理革新提升用户满意度与服务效率
  • 租房管理系统助力数字化转型提升租赁服务质量与用户体验
  • Ollama教程:轻松上手本地大语言模型部署
  • Baklib推动数字化内容管理解决方案助力企业数字化转型
  • DeepSeek-R1 论文. Reinforcement Learning 通过强化学习激励大型语言模型的推理能力
  • DOM 操作入门:HTML 元素操作与页面事件处理
  • 使用 HTTP::Server::Simple 实现轻量级 HTTP 服务器
  • C++滑动窗口技术深度解析:核心原理、高效实现与高阶应用实践
  • 基于构件的软件开发方法
  • 网站快速收录:如何设置robots.txt文件?
  • OpenGL学习笔记(六):Transformations 变换(变换矩阵、坐标系统、GLM库应用)
  • 8.攻防世界Web_php_wrong_nginx_config
  • 【优先算法】专题——位运算
  • qt.qpa.plugin: Could not find the Qt platform plugin “dxcb“ in ““
  • 1-刷力扣问题记录
  • 物联网 STM32【源代码形式-使用以太网】连接OneNet IOT从云产品开发到底层MQTT实现,APP控制 【保姆级零基础搭建】
  • 【单层神经网络】基于MXNet的线性回归实现(底层实现)
  • unity中的动画混合树
  • 《基于deepseek R1开源大模型的电子数据取证技术发展研究》
  • Potplayer常用快捷键
  • C++ Primer 自定义数据结构
  • 35.Word:公积金管理中心文员小谢【37】
  • 北京钟鼓楼:立春“鞭春牛”,钟鼓迎春来
  • 股票入门知识
  • Java自定义IO密集型和CPU密集型线程池