当前位置: 首页 > news >正文

[Token]ALGM: 基于自适应局部-全局token合并的简单视觉Transformer用于高效语义分割, CVPR2024

ALGM: Adaptive Local-then-Global Token Merging for Efficient Semantic Segmentation with Plain Vision Transformers
paper|code

Background & Motivation

具有高余弦相似度的token可以合并,而不会降低分割质量。

  1. CTS表明,在早期网络阶段进行局部token共享可以提高效率,而不会影响分割质量,但它需要一个预处理网络。 因此,我们的第一个目标是在网络浅层合并冗余符元,而无需预处理,同时保持分割质量。
  2. 像ToMe这样的token合并方法表明,逐渐合并整张图像上的冗余token可以大大提高效率,但全局范围内合并损害分割质量。 因此,我们的第二个目标是应用全局token合并以进一步提高效率,同时不会损害分割质量。

Challenge

如何创造一个新方法,既能像CTS一样在早期就合并局部Token,又能像ToMe一样在全局范围内高效合并,同时没有额外的预处理网络,不损害分割质量。

沿用余弦相似度的标准,发现随着模型加深:

  1. 在早期,它足以在局部区分开不同物体
  2. 在后期,它能在全局上更清晰地区分不同物体

Method

基于这些发现,提出了Adaptive Local-then-Global Merging (ALGM) module,该模块集成了两个token合并阶段。在第一网络层中,ALGM 采用局部合并策略。 在中间层采用全局合并机制,以减少全局token冗余。 此外,不预设token的合并数量,而是根据图像内容的语义复杂度动态决定合并token的数量。
在这里插入图片描述

Token相似度分析

在何种情况下以及何时,余弦相似度能够成为一种有效的指标,用于识别代表同一类别的标记,从而使其适合进行局部和全局合并。

提取并比较了分词器生成的token与在 ADE20K训练集中训练的 ViT-S的相似性。
(1)首先,分析了第一层转置前向层中k×k窗口内的局部相似性。如图 2a 所示,窗口大小 k 越小,余弦相似度就越能准确地反映token属于同一类别。因此,在第一层中,在小局部窗口内具有高余弦相似度的token很可能可以合并,而不会导致分割质量下降。
(2)计算整张图像中所有 Transformer 层的类别间和类别内token的余弦相似度来分析全局相似度。如图 2b 所示,早期层中的全局相似度并不能准确反映类别对应关系,因此不应将其用于识别需要合并的token。然而,在网络更深的部分,余弦相似度成为一种更好的衡量标准,可以用于在全局范围内识别可以合并的标记,而不会影响分割质量。
在这里插入图片描述

Adaptive Local-then-Global Merging(ALGM)

(a)早期层中的局部token相似性以及(b)中间层中的全局token相似性很可能是衡量token合并能力的指标。提出自适应局部-然后全局合并(ALGM)方法。首先在第一层使用条件局部平均池化(CLAP)模块进行局部合并。在中间层,采用基于 BSM算法的全局二分合并(GBM)模块进行全局合并。整个过程以一个token解合并模块结束,以恢复原始的token解析。

Local token merging.

在这里插入图片描述

如果一个Token和它在一个小窗口内的邻居们高度相似,就将它们合并。CLAP模块,它被放置在第一层(L1)的MHSA和MLP模块之间,用来实现这个功能。
Step 1.
它接收来自第一层(L1)的Token T’1,并将其重新排列成一个空间网格 T’G1。然后,定义k×k大小的窗口,并将每个窗口内的Tokens分组到不同的集合W中。
Step 2.
计算小组内所有Token之间的余弦相似度,并求出这些相似度的平均值μw。然后,根据相似度代表可合并性的假设,CLAP模块只合并那些平均相似度μwμwμw大于阈值τττ的窗口。
Step 3.
被选中的窗口www内的所有Tokens,通过计算它们的平均值,合并成一个Token。这些被合并的Tokens的原始索引也会被存储起来,以备后续的“解合并”(unmerging)操作。完成后,合并产生的新Token和那些未被合并的Token被连接在一起,生成最终的输出,其数量小于或等于原始数量。

Global token merging.

在这里插入图片描述
Step 1.
token分组与图构建,分成两组,构造二分图
Step 2.
寻找最佳匹配,找到唯一一个最合适的合并对象。
Step 3.
应用相似度阈值,保证足够相似的token对才被允许进入最后的合并阶段。
Step 4.
对于所有经过前两轮筛选后仍然保留下来的边,其连接的token对将被合并,并且存储索引 。所有未参与合并的token和那些合并后更新了的token被拼接在一起,形成一个新的、数量更少的token集合,作为下一层的输入。

Token unmerging.

利用合并时记录的索引信息,通过“复制粘贴”的方式,将被合并的Token还原到其原始位置,从而恢复出与输入图像同样尺寸的特征表示。
这个过程的执行时机取决于下游的解码器:
如果解码器是Transformer(不怕乱序),就先解码,后还原,效率更高。
如果解码器是CNN(要求整齐),就先还原,后解码,以满足其输入要求。

Adaptive token merging.

在训练之前,使用想要应用 ALGM 的基础分割模型,并在训练集中进行对比测试。然后,在每一层 Ll 中提取 MHSA 块之后的token,计算所有token对之间的余弦相似度,并计算整个训练集的平均相似度 µsimµsimµsim 和标准差 σsimσsimσsim。根据这些统计数据,设置阈值 τ=µsim+σsimτ = µsim + σsimτ=µsim+σsim。使用此阈值,经过 CLAP 和 GBM 模块后的剩余令牌数量 N’ 和 N’’ 会因图像而异。在训练过程中,为了便于对图像和标记进行分组处理,确定每次分组的最大剩余token数量 N’ 和 N’',然后将这些数值应用于该批次中的所有图像。

算法实现

local_merge.py

code


import math
from typing import Callable, Tupleimport torch
import torch.nn
import torch.nn.functional as F
from einops import rearrange
import numpy as npdef conditional_pooling(feat: torch.Tensor,threshold:float,window_size: Tuple[int, int],
) -> Tuple[Callable, Callable]:with torch.no_grad():ws_h, ws_w = int(window_size[0]), int(window_size[1])stride_h, stride_w = ws_h, ws_wnum_token_window = stride_h * stride_wx_cls, feat = feat[:, :1, :], feat[:, 1:, :]B, N, D = feat.size()base_grid_H = int(math.sqrt(N))base_grid_W = base_grid_Hassert base_grid_H * base_grid_W == N and base_grid_H % ws_h == 0 and base_grid_W % ws_w == 0feat = rearrange(feat, "b (h w) c -> b c h w", h=base_grid_H)feat = rearrange(feat, 'b c (gh ps_h) (gw ps_w) -> b gh gw c ps_h ps_w', gh=base_grid_H//ws_h, gw=base_grid_W//ws_w)b, gh, gw, c, ps_h, ps_w = feat.shape# Flatten mxm window for pairwise operationstensor_flattened = feat.reshape(b, gh, gw, c, -1)# Expand dims for pairwise operationstensor_1 = tensor_flattened.unsqueeze(-1)tensor_2 = tensor_flattened.unsqueeze(-2)# Compute cosine similaritiessims = F.cosine_similarity(tensor_1, tensor_2, dim=3)# Exclude the self-similarity (i.e., similarity with oneself will be 1)sims_mask = 1 - torch.eye(ps_h * ps_w).to(sims.device)sims = sims * sims_mask# Average similarities (excluding the self-similarity)similarity_map = sims.sum(-1).sum(-1) / ((ps_h * ps_w) * (ps_h * ps_w - 1))similarity_map = rearrange(similarity_map.unsqueeze(1), 'b c h w-> b (c h w)')#--- adaptive section ---#n_B, n_H = similarity_map.shapenode_mean = torch.tensor(threshold).cuda(sims.device)node_mean=node_mean.repeat(1,n_H)r = torch.ge(similarity_map, node_mean).sum(dim=1).min()# -------------# #   get top k similar super patches _, sim_super_patch_idxs = similarity_map.topk(r,dim=-1)# --- creating the mergabel and unmergable super  pathestensor = torch.arange(base_grid_H * base_grid_W).reshape(base_grid_H, base_grid_W).to(feat.device)# Repeat the tensor to create a batch of size 2tensor = tensor.unsqueeze(0).repeat(B, 1, 1)# Apply unfold operation on last two dimensions to create the sliding windowwindowed_tensor = tensor.unfold(1, ws_h, stride_h).unfold(2, ws_w, stride_w)# Reshape the tensor to the desired shape windowed_tensor = windowed_tensor.reshape(B, -1, num_token_window)# Use torch.gather to collect the desired elementsgathered_tensor = torch.gather(windowed_tensor, 1, sim_super_patch_idxs.unsqueeze(-1).expand(-1, -1, num_token_window))# Create a mask for all indices, for each batchmask = torch.ones((B, windowed_tensor.shape[1]), dtype=bool).to(feat.device)# Create a tensor that matches the shape of indices and fill it with Falsemask_values = torch.zeros_like(sim_super_patch_idxs, dtype=torch.bool).to(feat.device)# Use scatter_ to update the mask. This will set mask[b, indices[b]] = False for all bmask.scatter_(1, sim_super_patch_idxs, mask_values)# Get the remaining tensorremaining_tensor = windowed_tensor[mask.unsqueeze(-1).expand(-1, -1, num_token_window)].reshape(B, -1, num_token_window)unm_idx = remaining_tensor.reshape(B, -1).sort(dim=-1).values.unsqueeze(-1)dim_index = (num_token_window)- 1 src_idx= gathered_tensor[:, :, :dim_index].reshape(B, -1).unsqueeze(-1)dst_idx= gathered_tensor[:, :, dim_index].reshape(B, -1).unsqueeze(-1)merge_idx = torch.arange(src_idx.shape[1]//dim_index).repeat_interleave(dim_index).repeat(B, 1).unsqueeze(-1).to(feat.device)def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:# TODO: num_token_window can be undefinedx_cls , x_feat =  x[:, :1, :], x[:, 1:, :]n, t1, c = x_feat.shapesrc = x_feat.gather(dim=-2, index=src_idx.expand(n, r*dim_index, c))dst = x_feat.gather(dim=-2, index=dst_idx.expand(n, r, c))unm = x_feat.gather(dim=-2, index=unm_idx.expand(n, t1 - (r*num_token_window), c))dst = dst.scatter_reduce(-2, merge_idx.expand(n,r*dim_index, c), src, reduce=mode)x = torch.cat([dst, unm], dim=1)x = torch.cat((x_cls, x), dim=1)return xreturn mergedef merge_wavg(merge: Callable, x: torch.Tensor, size: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:if size is None:size = torch.ones_like(x[..., 0, None])x = merge(x * size, mode="sum")size = merge(size, mode="sum")    x = x / sizereturn x, sizedef merge_source(merge: Callable, x: torch.Tensor, source: torch.Tensor = None
) -> torch.Tensor:if source is None:n, t, _ = x.shapesource = torch.eye(t, device=x.device)[None, ...].expand(n, t, t)source = merge(source, mode="amax")return source

global_merge.py

code

import math
from typing import Callable, Tuple
import torchdef do_nothing(x, mode=None):return xdef turbo_matching(metric: torch.Tensor,layer_idx:int,source: torch.Tensor,class_token: bool = False,distill_token: bool = False,
) -> Tuple[Callable, Callable]:protected = 0if class_token:protected += 1if distill_token:protected += 1t = metric.shape[1]r = (t - protected) // 2if r <= 0:return do_nothing, do_nothingwith torch.no_grad():B,m_t,um_t = source.shapemetric = metric / metric.norm(dim=-1, keepdim=True)a, b = metric[..., ::2, :], metric[..., 1::2, :]scores = a @ b.transpose(-1, -2)if class_token:scores[..., 0, :] = -math.infif distill_token:scores[..., :, 0] = -math.infnode_max, node_idx = scores.max(dim=-1)edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]# ------------------ start  addaptive section --------- i = layer_idxn_B, n_H = node_max.shapenode_mean= torch.add(node_max[:,1:].mean(dim=1).mean(),node_max[:,1:].std(dim=1).mean()/i)node_mean=node_mean.repeat(1,n_H)r = torch.ge(node_max, node_mean).sum(dim=1).min()# ------------------ end addaptive section --------- unm_idx = edge_idx[..., r:, :]  # Unmerged Tokenssrc_idx = edge_idx[..., :r, :]  # Merged Tokensdst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)if class_token:# Sort to ensure the class token is at the startunm_idx = unm_idx.sort(dim=1)[0]def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:src, dst = x[..., ::2, :], x[..., 1::2, :]n, t1, c = src.shapeunm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))src = src.gather(dim=-2, index=src_idx.expand(n, r, c))dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)if distill_token:return torch.cat([unm[:, :1], dst[:, :1], unm[:, 1:], dst[:, 1:]], dim=1)else:return torch.cat([unm, dst], dim=1)return merge

Apply ALGM between the attention and mlp blocks

code

class TurboBlock(Block):"""Modifications:- Apply ALGM between the attention and mlp blocks"""def _drop_path1(self, x):return self.drop_path1(x) if hasattr(self, "drop_path1") else self.drop_path(x)def _drop_path2(self, x):return self.drop_path2(x) if hasattr(self, "drop_path2") else self.drop_path(x)def forward(self, x: torch.Tensor ) -> torch.Tensor:attn_size = self._turbo_info["size"] if self._turbo_info["prop_attn"] else Nonex_attn, metric  = self.attn(self.norm1(x),attn_size)x =  x + self._drop_path1(x_attn)layer_idx = self._turbo_info["selected_layers"].pop(0)if self._turbo_info["source"] is None: # if layer_idx == 1:merge  = conditional_pooling(x,self._turbo_info["threshold"],self._turbo_info["window_size"],)if self._turbo_info["trace_source"]:self._turbo_info["source"] = merge_source(merge, x, self._turbo_info["source"])x, self._turbo_info["size"] = merge_wavg(merge, x, self._turbo_info["size"])else:merge = turbo_matching(x,layer_idx,self._turbo_info["source"],self._turbo_info["class_token"],self._turbo_info["distill_token"],)if self._turbo_info["trace_source"]:self._turbo_info["source"] = merge_source(merge, x, self._turbo_info["source"])x, self._turbo_info["size"] = merge_wavg(merge, x, self._turbo_info["size"])x = x + self._drop_path2(self.mlp(self.norm2(x)))return x 

实验结果

在这里插入图片描述
在这里插入图片描述

Inspire

  1. local划分、合并的策略是否在low-level像素级任务上是有效的,替代window attention(复杂度)
http://www.lryc.cn/news/603795.html

相关文章:

  • docker docker与swarm入门笔记
  • Python中的决策树机器学习模型简要介绍和代码示例(基于sklearn)
  • Unity_SRP Batcher
  • 谷歌采用 Ligero 构建其 ZK 技术栈
  • 【密码学】4. 分组密码
  • ftp加ssl,升级ftps
  • WebRTC(十四):WebRTC源码编译与管理
  • 7月29日星期二今日早报简报微语报早读
  • TCPDump实战手册:协议/端口/IP过滤与组合分析指南
  • Kruskal算法
  • 《林景媚与命运共创者》
  • 暑期算法训练.10
  • Spring Boot中的this::语法糖详解
  • 解锁全球数据:Bright Data MCP 智能解决代理访问难题
  • pnpm 入门与实践指南
  • Element Plus常见基础组件(二)
  • React 图标库发布到 npm 仓库
  • Linux -- 文件【中】
  • 基于深度学习的医学图像分析:使用CycleGAN实现图像到图像的转换
  • tcp通讯学习数据传输
  • DETR 下 Transformer 应用探讨
  • 准大一GIS专业新生,如何挑选电脑?
  • 站点到站点-主模式
  • Java 11 新特性详解与代码示例
  • JAVA中集合的遍历方式
  • 【C++】1. C++基础知识
  • 编辑距离:理论基础、算法演进与跨领域应用
  • taro+react重新给userInfo赋值后,获取的用户信息还是老用户信息
  • ERROR c.a.c.n.c.NacosPropertySourceBuilder
  • react 的 useTransition 、useDeferredValue