当前位置: 首页 > news >正文

[2025CVPR]一种新颖的视觉与记忆双适配器(Visual and Memory Dual Adapter, VMDA)

引言

多模态目标跟踪(Multi-modal Object Tracking)旨在通过结合RGB模态与其他辅助模态(如热红外、深度、事件数据)来增强可见光传感器的感知能力,尤其在复杂场景下显著提升跟踪鲁棒性。然而,现有方法在频域和时间域的关键线索利用上仍存在不足,导致性能受限。本文提出了一种新颖的视觉与记忆双适配器(Visual and Memory Dual Adapter, VMDA),通过联合建模频域、空间和通道特征,构建更鲁棒的多模态表示,并引入基于人类记忆机制的记忆适配器,有效捕捉全局时间线索。

Fig.1: Framework comparisons between the existing prompt-learning-based tracker and our tracker.(a) Existing trackers propagate temporal cues from adjacent frames and fuse multi-modal features in channel and spatial dimensions.(b) The proposed method integrates a memory adapter to propagate cues adaptively and merge features in channel,spatial, and frequency dimensions.

模型原理

整体框架

本文提出的VMDA框架主要包括四个组件:ViT骨干网络、视觉适配器、记忆适配器和预测头。具体流程如下:

  1. 输入嵌入​:将RGB和辅助模态的模板和搜索区域通过补丁嵌入层转换为令牌。
  2. 浅层特征融合​:使用频率引导多模态融合模块(FMFM)进行初步特征融合。
  3. 时间线索传播​:从多级记忆池中检索时间跟踪线索令牌,并通过记忆滤波器处理后输入ViT块。
  4. 多模态增强与融合​:在每个ViT块后,输出特征通过多模态融合模块(MFM)进行增强和融合,时间跟踪线索则通过记忆滤波器处理。
  5. 最终预测​:经过L层ViT块后,最终令牌用于头部操作生成跟踪结果,并将时间跟踪线索存储在多级记忆池中。

Fig. 3: The framework of the proposed method. We first transform the templates and search region of each modality into tokens, then concatenate them with temporal cue tokens and feed them into the L-layer ViT block. The visual adapter and memory adapter are paralleled with the ViT block. The memory adapter is used to propagate the valuable temporal cues across frames, and the visual adapter is used for modality interaction and fusion. The output features are fed into the prediction head to produce the tracking results.

视觉适配器

视觉适配器的核心在于频率引导多模态融合模块,其设计如下:

频率选择器

频率选择器通过分离高频和低频成分来提取丰富的纹理细节和边缘信息:

其中,Fori​表示输入特征,Fhigh​和Flow​分别表示高频和低频特征。随后,通过全局平均池化和线性层选择和融合不同频率特征:

最终通过元素级加法组合高频和低频成分。

多模态融合模块

多模态融合模块从空间和通道视角整合多模态信息:

通过元素级加法组合三个分支的输出,并通过卷积层生成最终输出。

记忆适配器

记忆适配器由短期、长期和永久记忆组成,通过记忆更新和检索操作实现全局时间线索的传播:

在检索操作中,使用最新时间跟踪线索作为查询选择各层记忆:

并通过元素级加法组合结果,再通过内存滤波器调整。

创新点

  1. 频率引导多模态融合模块​:首次在多模态跟踪中联合建模频域、空间和通道特征,显著提升了跨模态特征融合的效果。
  2. 多级记忆适配器​:借鉴人类记忆机制,设计多级记忆池存储全局时间线索,并通过动态更新和检索操作确保可靠的时间信息传播。
  3. 轻量化适配器设计​:仅微调少量参数,显著降低了训练成本和计算复杂度。

实验结果

数据集与评估指标

实验在RGB-T、RGB-D和RGB-E三个主流多模态跟踪数据集上进行评估:

  • RGB-T跟踪​:RGBT234和LasHeR数据集,使用精度率(PR)和成功率(SR)作为主要指标。
  • RGB-D跟踪​:DepthTrack和VOT22-RGBD数据集,使用精度(Pre)、召回率(Re)、F-score和EAO等指标。
  • RGB-E跟踪​:VisEvent数据集,使用PR和SR作为评估指标。
对比结果

实验结果表明,本文方法在所有数据集上均显著优于现有方法:

  • RGB-T跟踪​:在RGBT234数据集上,PR达到0.919,SR达到0.689;在LasHeR数据集上,PR达到0.726,SR达到0.571。
  • RGB-D跟踪​:在DepthTrack数据集上,Pre、Re和F-score分别为0.636、0.663和0.649;在VOT22-RGBD数据集上,EAO达到0.773,A达到0.821,R达到0.933。
  • RGB-E跟踪​:在VisEvent数据集上,PR达到0.803,SR达到0.626。

Fig. 6: Precision scores of different attributes on the VisEvent test set.

代码

import torch
import torch.nn as nn
import torch.nn.functional as Fclass FrequencySelector(nn.Module):def __init__(self, in_channels, hidden_dim=64):super().__init__()self.conv = nn.Conv2d(in_channels, hidden_dim, 1)self.bn = nn.BatchNorm2d(hidden_dim)self.fc_global = nn.Linear(hidden_dim, hidden_dim)self.fc_high = nn.Linear(hidden_dim, hidden_dim)self.fc_low = nn.Linear(hidden_dim, hidden_dim)def forward(self, x):# 分离高频和低频特征ap_x = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)conv_x = self.conv(ap_x)bn_x = self.bn(conv_x)softmax_x = F.softmax(bn_x, dim=1)f_high = x * softmax_xf_low = x - f_high# 全局特征融合f_global = torch.cat([f_high, f_low], dim=1)f_global = F.adaptive_avg_pool2d(f_global, 1).view(f_global.size(0), -1)f_global = self.fc_global(f_global)# 动态加权f_high_gate = torch.sigmoid(self.fc_high(f_global)).unsqueeze(-1).unsqueeze(-1)f_low_gate = torch.sigmoid(self.fc_low(f_global)).unsqueeze(-1).unsqueeze(-1)f_high = f_high * f_high_gatef_low = f_low * f_low_gatereturn f_high + f_lowclass MultiModalFusion(nn.Module):def __init__(self, in_channels):super().__init__()self.conv_rgb = nn.Conv2d(in_channels, in_channels, 1)self.conv_aux = nn.Conv2d(in_channels, in_channels, 1)self.spatial_att = nn.Sequential(nn.Conv2d(in_channels*2, in_channels, 1),nn.Softmax(dim=1))self.channel_att = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(in_channels, in_channels//8, 1),nn.ReLU(),nn.Conv2d(in_channels//8, in_channels, 1),nn.Sigmoid())def forward(self, x_rgb, x_aux):# 空间注意力x_rgb_s = self.conv_rgb(x_rgb)x_aux_s = self.conv_aux(x_aux)concat_s = torch.cat([x_rgb_s, x_aux_s], dim=1)spatial_weight = self.spatial_att(concat_s)x_s_fused = x_rgb_s * spatial_weight[:, :x_rgb_s.size(1)] + \x_aux_s * spatial_weight[:, x_rgb_s.size(1):]# 通道注意力x_concat = torch.cat([x_rgb, x_aux], dim=1)channel_weight = self.channel_att(x_concat)x_c_fused = x_concat * channel_weight# 合并return x_s_fused + x_c_fusedclass MemoryAdapter(nn.Module):def __init__(self, mem_slots=8, token_dim=768):super().__init__()self.mem_slots = mem_slotsself.token_dim = token_dimself.query_proj = nn.Linear(token_dim, token_dim)self.key_proj = nn.Linear(token_dim, token_dim)self.value_proj = nn.Linear(token_dim, token_dim)def forward(self, query, memory_bank):# 计算注意力权重Q = self.query_proj(query).unsqueeze(1)  # [B, 1, D]K = self.key_proj(memory_bank)          # [B, S, D]V = self.value_proj(memory_bank)          # [B, S, D]attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.token_dim))attn_weights = F.softmax(attn_scores, dim=-1)# 加权求和retrieved = torch.matmul(attn_weights, V)  # [B, 1, D]return retrieved.squeeze(1)class VMDATracker(nn.Module):def __init__(self, num_classes=2):super().__init__()# 假设ViT骨干网络已预训练self.vit = VisionTransformer()  # 用户需自行实现或调用预训练模型self.visual_adapter = nn.Sequential(FrequencySelector(in_channels=3),MultiModalFusion(in_channels=3))self.memory_adapter = MemoryAdapter(mem_slots=3, token_dim=768)self.prediction_head = nn.Sequential(nn.Linear(768, 256),nn.ReLU(),nn.Linear(256, num_classes))def forward(self, x_rgb, x_aux, template_tokens):# 多模态特征提取x_rgb = self.visual_adapter(x_rgb)x_aux = self.visual_adapter(x_aux)# 时间线索融合memory_tokens = self.memory_adapter(x_aux, template_tokens)# ViT主干网络fused_tokens = torch.cat([x_rgb, x_aux, memory_tokens], dim=1)vit_output = self.vit(fused_tokens)# 预测头bbox_pred = self.prediction_head(vit_output)return bbox_pred# 使用示例
model = VMDATracker()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()# 假设输入数据格式为 (B, C, H, W)
inputs_rgb = torch.randn(2, 3, 256, 256)
inputs_aux = torch.randn(2, 3, 256, 256)
templates = torch.randn(2, 3, 64, 64)# 前向传播
outputs = model(inputs_rgb, inputs_aux, templates)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()

总结

本文提出了一种基于视觉与记忆双适配器的多模态目标跟踪方法,通过频率引导的多模态融合和多级记忆适配器,显著提升了多模态跟踪的性能。实验结果表明,该方法在RGB-T、RGB-D和RGB-E等多个任务上均达到了最先进的性能。

http://www.lryc.cn/news/581706.html

相关文章:

  • SSL 终结(SSL Termination)深度解析:从原理到实践的全维度指南
  • Python Bcrypt详解:从原理到实战的安全密码存储方案
  • 用户中心Vue3项目开发2.0
  • 2048小游戏实现
  • 线性代数--AI数学基础复习
  • 深度学习6(多分类+交叉熵损失原理+手写数字识别案例TensorFlow)
  • Chunking-free RAG
  • Web-API-day2 间歇函数setInterval与事件监听addEvenListener
  • 【Note】《Kafka: The Definitive Guide》第四章:Kafka 消费者全面解析:如何从 Kafka 高效读取消息
  • Apache Spark 4.0:将大数据分析提升到新的水平
  • A O P
  • 金融级B端页面风控设计:操作留痕与异常预警的可视化方案
  • 深度学习篇---深度学习常见的应用场景
  • 容声W60以光水离子科技实现食材“主动养鲜”
  • [Qt] visual studio code 安装 Qt插件
  • FastAPI + Tortoise-ORM + Aerich 实现数据库迁移管理(MySQL 实践)
  • 深度学习 必然用到的 线性代数知识
  • 嵌入式 数据结构学习(五) 栈与队列的实现与应用
  • React Ref 指南:原理、实现与实践
  • 【PyTorch】PyTorch中torch.nn模块的卷积层
  • 零基础,使用Idea工具写一个邮件报警程序
  • Solidity——什么是状态变量
  • 计算机网络:(七)网络层(上)网络层中重要的概念与网际协议 IP
  • Kafka “假死“现象深度解析与解决方案
  • UI前端大数据可视化进阶:交互式仪表盘的设计与应用
  • 数据驱动实时市场动态监测:让商业决策跑赢时间
  • 【LeetCode 热题 100】240. 搜索二维矩阵 II——排除法
  • 黑马点评系列问题之实战篇02短信登录 利用资料中的mysql语句创建数据表时报错
  • 关于 栈帧变化完整流程图(函数嵌套)
  • Java 双亲委派机制笔记