当前位置: 首页 > news >正文

【Block总结】HiLo注意力,局部自注意力捕获细粒度的高频信息,通过全局注意力捕获低频信息|即插即用

一、论文信息

  • 标题: Fast Vision Transformers with HiLo Attention
  • GitHub链接: https://github.com/ziplab/LITv2
  • 论文链接: arXiv
    在这里插入图片描述

二、创新点

  • HiLo注意力机制: 本文提出了一种新的自注意力机制——HiLo注意力,旨在同时捕捉图像中的高频和低频特征。该机制通过将自注意力分为两个分支,分别处理高频(Hi-Fi)和低频(Lo-Fi)信息,从而提高计算效率和模型性能[1][5][16]。

  • LITv2模型: 基于HiLo注意力机制,LITv2模型在多个计算机视觉任务上表现优越,尤其是在处理高分辨率图像时,显著提升了速度和准确性[1][5][16]。

  • 相对位置编码优化: 采用3×3的深度卷积层替代传统的固定相对位置编码,进一步加快了密集预测任务的训练和推理速度[1][5][16]。

三、方法

  • 整体架构: LITv2模型分为多个阶段,生成金字塔特征图,适用于密集预测任务。模型通过局部窗口自注意力捕捉细节,同时使用全局自注意力处理低频信息,确保性能与效率的平衡[1][5][16]。

  • 特征处理: 输入图像被切分为固定大小的图像块(patch),每个patch通过线性变换映射到高维特征空间。HiLo注意力机制在每个Transformer模块中使用标准的残差连接和LayerNorm层,以稳定训练并保持特征传递[1][5][16]。
    在这里插入图片描述

四、效果

  • 性能提升: LITv2在标准基准测试中表现优于大多数现有的视觉Transformer模型,尤其在处理高分辨率图像时,HiLo机制在CPU上比传统的局部窗口注意力机制快1.6倍,比空间缩减注意力机制快1.4倍[1][5][16]。

  • 计算效率: 通过将注意力机制分为高频和低频,LITv2能够有效减少计算量,同时保持或提升模型的准确性和速度[1][5][16]。

五、实验结果

  • 基准测试: 论文中通过实际平台的速度评估,展示了LITv2在GPU和CPU上的优越性能。实验结果表明,HiLo注意力机制在多个视觉任务中均表现出色,尤其是在图像分类和物体检测任务中[1][5][16]。

  • FLOPs与吞吐量: 研究表明,HiLo机制在FLOPs、吞吐量和内存消耗方面均优于现有的注意力机制,证明了其在实际应用中的有效性[1][5][16]。

六、总结

Fast Vision Transformers with HiLo Attention通过引入HiLo注意力机制,成功地将高频和低频信息的处理分开,显著提升了视觉Transformer的性能和效率。LITv2模型在多个计算机视觉任务中表现优异,展示了其在实际应用中的潜力。该研究为未来的视觉模型设计提供了新的思路,尤其是在处理高分辨率图像时的计算效率和准确性方面[1][5][16]。

七、代码

import os
import torch
import torch.nn as nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import mathclass HiLo(nn.Module):"""HiLo AttentionLink: https://arxiv.org/abs/2205.13213"""def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=2,alpha=0.5):super().__init__()assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."head_dim = int(dim / num_heads)self.dim = dim# self-attention heads in Lo-Fiself.l_heads = int(num_heads * alpha)# token dimension in Lo-Fiself.l_dim = self.l_heads * head_dim# self-attention heads in Hi-Fiself.h_heads = num_heads - self.l_heads# token dimension in Hi-Fiself.h_dim = self.h_heads * head_dim# local window size. The `s` in our paper.self.ws = window_sizeif self.ws == 1:# ws == 1 is equal to a standard multi-head self-attentionself.h_heads = 0self.h_dim = 0self.l_heads = num_headsself.l_dim = dimself.scale = qk_scale or head_dim ** -0.5# Low frequence attention (Lo-Fi)if self.l_heads > 0:if self.ws != 1:self.sr = nn.AvgPool2d(kernel_size=window_size, stride=window_size)self.l_q = nn.Linear(self.dim, self.l_dim, bias=qkv_bias)self.l_kv = nn.Linear(self.dim, self.l_dim * 2, bias=qkv_bias)self.l_proj = nn.Linear(self.l_dim, self.l_dim)# High frequence attention (Hi-Fi)if self.h_heads > 0:self.h_qkv = nn.Linear(self.dim, self.h_dim * 3, bias=qkv_bias)self.h_proj = nn.Linear(self.h_dim, self.h_dim)def hifi(self, x):B, H, W, C = x.shapeh_group, w_group = H // self.ws, W // self.wstotal_groups = h_group * w_groupx = x.reshape(B, h_group, self.ws, w_group, self.ws, C).transpose(2, 3)qkv = self.h_qkv(x).reshape(B, total_groups, -1, 3, self.h_heads, self.h_dim // self.h_heads).permute(3, 0, 1,4, 2, 5)q, k, v = qkv[0], qkv[1], qkv[2]  # B, hw, n_head, ws*ws, head_dimattn = (q @ k.transpose(-2, -1)) * self.scale  # B, hw, n_head, ws*ws, ws*wsattn = attn.softmax(dim=-1)attn = (attn @ v).transpose(2, 3).reshape(B, h_group, w_group, self.ws, self.ws, self.h_dim)x = attn.transpose(2, 3).reshape(B, h_group * self.ws, w_group * self.ws, self.h_dim)x = self.h_proj(x)return xdef lofi(self, x):B, H, W, C = x.shapeq = self.l_q(x).reshape(B, H * W, self.l_heads, self.l_dim // self.l_heads).permute(0, 2, 1, 3)if self.ws > 1:x_ = x.permute(0, 3, 1, 2)x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)kv = self.l_kv(x_).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)else:kv = self.l_kv(x).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)k, v = kv[0], kv[1]attn = (q @ k.transpose(-2, -1)) * self.scaleattn = attn.softmax(dim=-1)x = (attn @ v).transpose(1, 2).reshape(B, H, W, self.l_dim)x = self.l_proj(x)return xdef forward(self, x):B, N, C = x.shapeH = W = int(N ** 0.5)x = x.reshape(B, H, W, C)if self.h_heads == 0:x = self.lofi(x)return x.reshape(B, N, C)if self.l_heads == 0:x = self.hifi(x)return x.reshape(B, N, C)hifi_out = self.hifi(x)lofi_out = self.lofi(x)x = torch.cat((hifi_out, lofi_out), dim=-1)x = x.reshape(B, N, C)return xdef flops(self, N):H = int(N ** 0.5)# when the height and width cannot be divided by ws, we pad the feature map in the same way as Swin Transformer for object detection/segmentationHp = Wp = self.ws * math.ceil(H / self.ws)Np = Hp * Wp# For Hi-Fi# qkvhifi_flops = Np * self.dim * self.h_dim * 3nW = Np / self.ws / self.wswindow_len = self.ws * self.ws# q @ k and attn @ vwindow_flops = window_len * window_len * self.h_dim * 2hifi_flops += nW * window_flops# projectionhifi_flops += Np * self.h_dim * self.h_dim# for Lo-Fi# qlofi_flops = Np * self.dim * self.l_dim# H = int(Np ** 0.5)kv_len = (Hp // self.ws) ** 2# k, vlofi_flops += kv_len * self.dim * self.l_dim * 2# q @ k and attn @ vlofi_flops += Np * self.l_dim * kv_len * 2# projectionlofi_flops += Np * self.l_dim * self.l_dimreturn hifi_flops + lofi_flopsif __name__ == "__main__":dim=256# 如果GPU可用,将模块移动到 GPUdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 输入张量 (batch_size, height, width,channels)x = torch.randn(1,40*40,dim).to(device)# 初始化 HWD 模块block = HiLo(dim)print(block)block = block.to(device)# 前向传播output = block(x)print("输入:", x.shape)print("输出:", output.shape)

输出结果:
在这里插入图片描述

http://www.lryc.cn/news/529683.html

相关文章:

  • python 使用Whisper模型进行语音翻译
  • C# Winform enter键怎么去关联button
  • Github 2025-01-30 Go开源项目日报 Top10
  • 电路研究9.2.6——合宙Air780EP中HTTP——HTTP GET 相关命令使用方法研究
  • Java手写简单Merkle树
  • DeepSeek的使用技巧介绍
  • 19 压测和常用的接口优化方案
  • AI应用部署——streamlit
  • NLP自然语言处理通识
  • C++ 6
  • 使用QSqlQueryModel创建交替背景色的表格模型
  • jinfo命令详解
  • 如何在 ACP 中建模复合罐
  • 【Java】微服务找不到问题记录can not find user-service
  • 基于Hutool的Merkle树hash值生成工具
  • Windows系统本地部署deepseek 更改目录
  • 深度学习篇---数据存储类型
  • 可被electron等调用的Qt截图-录屏工具【源码开放】
  • electron 应用开发实践
  • openssl 生成证书 windows导入证书
  • 程序员学英文之At the Airport Customs
  • 字节iOS面试经验分享:HTTP与网络编程
  • 游戏引擎 Unity - Unity 启动(下载 Unity Editor、生成 Unity Personal Edition 许可证)
  • 前端八股CSS:盒模型、CSS权重、+与~选择器、z-index、水平垂直居中、左侧固定,右侧自适应、三栏均分布局
  • Linux网络 | 网络层IP报文解析、认识网段划分与IP地址
  • 服务器虚拟化实战:架构、技术与最佳实践
  • (leetcode 213 打家劫舍ii)
  • [C语言日寄] <stdio.h> 头文件功能介绍
  • 一文读懂 Faiss:开启高维向量高效检索的大门
  • 【二叉搜索树】