
UNet Improvements (4): Cross Attention for Multi-Modal / Multi-Feature Interaction

UNet is widely used for image segmentation in computer vision thanks to its strong performance. This article introduces an improved UNet architecture, UNetWithCrossAttention, which strengthens the model's feature fusion ability by introducing a cross-attention mechanism.

1. The Cross-Attention Mechanism

Cross attention is a mechanism that lets a model dynamically extract relevant information from an auxiliary feature to enhance a main feature. In our implementation, the CrossAttention class provides this functionality:

```python
class CrossAttention(nn.Module):
    def __init__(self, channels):
        super(CrossAttention, self).__init__()
        self.query_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x1, x2):
        # x1: main feature, x2: auxiliary feature; both (B, C, H, W) with matching shapes
        batch_size, C, height, width = x1.size()

        # Project to query, key and value spaces
        proj_query = self.query_conv(x1).view(batch_size, -1, height * width).permute(0, 2, 1)  # (B, N, C')
        proj_key = self.key_conv(x2).view(batch_size, -1, height * width)                       # (B, C', N)
        proj_value = self.value_conv(x2).view(batch_size, -1, height * width)                   # (B, C, N)

        # Attention map, scaled by the square root of the query/key channel dim
        energy = torch.bmm(proj_query, proj_key)  # (B, N, N)
        attention = torch.softmax(energy / math.sqrt(proj_query.size(-1)), dim=-1)

        # Apply attention to the values
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # (B, C, N)
        out = out.view(batch_size, C, height, width)

        # Residual connection with a learnable, zero-initialized gate
        out = self.gamma * out + x1
        return out
```

The module works as follows:

  1. The main feature x1 is projected to queries, and the auxiliary feature x2 is projected to keys and values

  2. The similarity between queries and keys is computed to obtain the attention weights

  3. The attention weights are used to form a weighted sum over the values

  4. A residual connection fuses the result back into the original main feature
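In step 2, the weights follow the standard scaled dot-product form softmax(QKᵀ/√d)·V, where d is the projected query/key channel dimension (channels // 8). A minimal shape check (my own illustrative snippet, not part of the original post) confirms the interface:

```python
import math

import torch
import torch.nn as nn

# Assumes the CrossAttention class above is defined in the same session.
attn = CrossAttention(channels=64)
x1 = torch.randn(2, 64, 32, 32)  # main feature
x2 = torch.randn(2, 64, 32, 32)  # auxiliary feature
out = attn(x1, x2)
print(out.shape)  # torch.Size([2, 64, 32, 32])
```

Because gamma is initialized to zero, the block starts out as an identity mapping on x1 and gradually learns how much attended context to mix in.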

2. The Double-Convolution Block

DoubleConv is the basic building block of UNet. It consists of two consecutive convolutional layers and can optionally apply cross attention:

```python
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, use_cross_attention=False):
        super(DoubleConv, self).__init__()
        self.use_cross_attention = use_cross_attention
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        if use_cross_attention:
            self.cross_attention = CrossAttention(out_channels)

    def forward(self, x, aux_feature=None):
        x = self.conv1(x)
        x = self.conv2(x)
        # Fuse the auxiliary feature only when attention is enabled and provided
        if self.use_cross_attention and aux_feature is not None:
            x = self.cross_attention(x, aux_feature)
        return x
```
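A small illustrative check (shapes are my own): when attention is enabled, aux_feature must already have out_channels channels and match the block's output spatially:

```python
block = DoubleConv(in_channels=3, out_channels=64, use_cross_attention=True)
x = torch.randn(1, 3, 64, 64)
aux = torch.randn(1, 64, 64, 64)  # must match out_channels and the output's spatial size
print(block(x, aux).shape)  # torch.Size([1, 64, 64, 64])
```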

3. Downsampling and Upsampling Blocks

The downsampling block Down combines max pooling with a double convolution:

```python
class Down(nn.Module):
    def __init__(self, in_channels, out_channels, use_cross_attention=False):
        super(Down, self).__init__()
        self.downsampling = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            DoubleConv(in_channels, out_channels, use_cross_attention)
        )

    def forward(self, x, aux_feature=None):
        # Index into the Sequential so the extra aux_feature argument can be
        # forwarded to DoubleConv (nn.Sequential itself only accepts one input)
        return self.downsampling[1](self.downsampling[0](x), aux_feature)
```
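For example (shapes mine), a Down stage halves the spatial resolution and doubles the channel count:

```python
down = Down(in_channels=64, out_channels=128)
x = torch.randn(1, 64, 64, 64)
print(down(x).shape)  # torch.Size([1, 128, 32, 32])
```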

The upsampling block Up uses a transposed convolution for upsampling, then concatenates the encoder skip feature:

```python
class Up(nn.Module):
    def __init__(self, in_channels, out_channels, use_cross_attention=False):
        super(Up, self).__init__()
        self.upsampling = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels, use_cross_attention)

    def forward(self, x1, x2, aux_feature=None):
        x1 = self.upsampling(x1)        # halve the channels, double the resolution
        x = torch.cat([x2, x1], dim=1)  # concatenate with the encoder skip feature
        x = self.conv(x, aux_feature)
        return x
```
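Similarly (shapes mine), an Up stage doubles the resolution, halves the channels, and fuses the skip feature:

```python
up = Up(in_channels=128, out_channels=64)
x1 = torch.randn(1, 128, 16, 16)  # feature from the deeper decoder stage
x2 = torch.randn(1, 64, 32, 32)   # skip connection from the encoder
print(up(x1, x2).shape)  # torch.Size([1, 64, 32, 32])
```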

4. The Complete UNetWithCrossAttention Architecture

Combining the modules above yields the complete UNetWithCrossAttention. OutConv, the usual 1×1 output convolution that maps the final 64-channel feature map to class logits, is defined in the complete code at the end:

```python
class UNetWithCrossAttention(nn.Module):
    def __init__(self, in_channels=1, num_classes=1, use_cross_attention=False):
        super(UNetWithCrossAttention, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.use_cross_attention = use_cross_attention

        # Encoder
        self.in_conv = DoubleConv(in_channels, 64, use_cross_attention)
        self.down1 = Down(64, 128, use_cross_attention)
        self.down2 = Down(128, 256, use_cross_attention)
        self.down3 = Down(256, 512, use_cross_attention)
        self.down4 = Down(512, 1024, use_cross_attention)

        # Decoder
        self.up1 = Up(1024, 512, use_cross_attention)
        self.up2 = Up(512, 256, use_cross_attention)
        self.up3 = Up(256, 128, use_cross_attention)
        self.up4 = Up(128, 64, use_cross_attention)
        self.out_conv = OutConv(64, num_classes)

    def forward(self, x, aux_feature=None):
        # Encoding path
        x1 = self.in_conv(x, aux_feature)
        x2 = self.down1(x1, aux_feature)
        x3 = self.down2(x2, aux_feature)
        x4 = self.down3(x3, aux_feature)
        x5 = self.down4(x4, aux_feature)

        # Decoding path
        x = self.up1(x5, x4, aux_feature)
        x = self.up2(x, x3, aux_feature)
        x = self.up3(x, x2, aux_feature)
        x = self.up4(x, x1, aux_feature)
        x = self.out_conv(x)
        return x
```
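A quick smoke test of the plain configuration (sizes are my own choice; the input's spatial dimensions must be divisible by 16 so that four rounds of pooling divide evenly):

```python
model = UNetWithCrossAttention(in_channels=1, num_classes=2)
x = torch.randn(1, 1, 256, 256)
print(model(x).shape)  # torch.Size([1, 2, 256, 256])
```

One caveat: when use_cross_attention=True, the same aux_feature is handed to every stage, but each stage's CrossAttention expects an input whose channel count and spatial size match that stage's own output. A single tensor cannot satisfy all stages at once, so in practice you would enable attention only at stages where a matching auxiliary feature exists (a per-stage variant is sketched in the summary below), or extend the interface to accept one auxiliary tensor per stage.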

5. Application Scenarios and Advantages

This cross-attention UNet architecture is particularly well suited to the following scenarios:

  1. Multi-modal image segmentation: when auxiliary information from other imaging modalities is available, cross attention helps the model fuse it effectively (see the sketch after this list)

  2. Temporal image analysis: for video sequences, features from the previous frame can serve as auxiliary features to strengthen segmentation of the current frame

  3. Weakly supervised learning: additional weak supervision signals can be injected into the main network through cross attention
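For scenario 1, here is a minimal, purely hypothetical sketch (the stems `ct_stem`/`mr_stem` and the modality names are my own illustration, not part of the original design): each modality gets its own convolutional stem, and CrossAttention lets the main branch pull complementary context from the auxiliary one.

```python
# Hypothetical multi-modal fusion sketch; stems and modality names are illustrative.
ct_stem = nn.Conv2d(1, 64, kernel_size=3, padding=1)  # main modality stem
mr_stem = nn.Conv2d(1, 64, kernel_size=3, padding=1)  # auxiliary modality stem
fuse = CrossAttention(64)

ct = torch.randn(1, 1, 64, 64)  # e.g. a CT slice
mr = torch.randn(1, 1, 64, 64)  # e.g. a spatially aligned MR slice
fused = fuse(ct_stem(ct), mr_stem(mr))
print(fused.shape)  # torch.Size([1, 64, 64, 64])
```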

Compared with the vanilla UNet, this architecture offers several advantages:

  • Dynamically attends to the most relevant parts of the auxiliary feature

  • Achieves finer-grained feature fusion through the attention mechanism

  • Retains UNet's original multi-scale feature extraction capability

  • Avoids information loss thanks to the residual connection

6. Summary

This article presented an enhanced UNet architecture that uses a cross-attention mechanism to exploit auxiliary features more effectively. The design keeps UNet's original strengths while adding flexible feature fusion, making it well suited to complex vision tasks that need to integrate information from multiple sources.

In practice, you can choose at which levels to enable cross attention according to the task, and adjust the attention module's complexity to balance accuracy against computational cost; one minimal way to do this is sketched below.
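As one possible realization (a hypothetical sketch, not part of the original post; the class name and the `enc_flags` parameter are mine), the constructor could take one flag per encoder stage, so attention is enabled only where an auxiliary feature of matching shape exists:

```python
class UNetStagewiseAttention(UNetWithCrossAttention):
    """Hypothetical variant: one cross-attention flag per encoder stage."""

    def __init__(self, in_channels=1, num_classes=1,
                 enc_flags=(True, False, False, False, False)):
        super().__init__(in_channels, num_classes, use_cross_attention=False)
        # Rebuild only the encoder stages with their individual flags; stages
        # whose flag is False simply ignore the aux_feature argument.
        self.in_conv = DoubleConv(in_channels, 64, enc_flags[0])
        self.down1 = Down(64, 128, enc_flags[1])
        self.down2 = Down(128, 256, enc_flags[2])
        self.down3 = Down(256, 512, enc_flags[3])
        self.down4 = Down(512, 1024, enc_flags[4])

# With attention only at the first stage, aux_feature must be a 64-channel map
# at input resolution; keep the resolution modest, since the attention map is
# quadratic in the number of spatial positions.
model = UNetStagewiseAttention(in_channels=1, num_classes=2)
x = torch.randn(1, 1, 64, 64)
aux = torch.randn(1, 64, 64, 64)
print(model(x, aux).shape)  # torch.Size([1, 2, 64, 64])
```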

I hope this article helps you understand how cross attention can be applied in UNet. If you have any questions or suggestions, feel free to leave a comment!

Complete Code

The full code is as follows:

```python
import math

import torch
import torch.nn as nn


class CrossAttention(nn.Module):
    def __init__(self, channels):
        super(CrossAttention, self).__init__()
        self.query_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(channels, channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x1, x2):
        """
        x1: main feature (batch, channels, height, width)
        x2: auxiliary feature (batch, channels, height, width)
        """
        batch_size, C, height, width = x1.size()

        # Project to query, key and value spaces
        proj_query = self.query_conv(x1).view(batch_size, -1, height * width).permute(0, 2, 1)  # (B, N, C')
        proj_key = self.key_conv(x2).view(batch_size, -1, height * width)  # (B, C', N)
        proj_value = self.value_conv(x2).view(batch_size, -1, height * width)  # (B, C, N)

        # Attention map, scaled by the square root of the query/key channel dim
        energy = torch.bmm(proj_query, proj_key)  # (B, N, N)
        attention = torch.softmax(energy / math.sqrt(proj_query.size(-1)), dim=-1)

        # Apply attention to the values
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))  # (B, C, N)
        out = out.view(batch_size, C, height, width)

        # Residual connection with a learnable, zero-initialized gate
        out = self.gamma * out + x1
        return out


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, use_cross_attention=False):
        super(DoubleConv, self).__init__()
        self.use_cross_attention = use_cross_attention
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        if use_cross_attention:
            self.cross_attention = CrossAttention(out_channels)

    def forward(self, x, aux_feature=None):
        x = self.conv1(x)
        x = self.conv2(x)
        if self.use_cross_attention and aux_feature is not None:
            x = self.cross_attention(x, aux_feature)
        return x


class Down(nn.Module):
    def __init__(self, in_channels, out_channels, use_cross_attention=False):
        super(Down, self).__init__()
        self.downsampling = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            DoubleConv(in_channels, out_channels, use_cross_attention)
        )

    def forward(self, x, aux_feature=None):
        return self.downsampling[1](self.downsampling[0](x), aux_feature)


class Up(nn.Module):
    def __init__(self, in_channels, out_channels, use_cross_attention=False):
        super(Up, self).__init__()
        self.upsampling = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels, use_cross_attention)

    def forward(self, x1, x2, aux_feature=None):
        x1 = self.upsampling(x1)
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x, aux_feature)
        return x


class OutConv(nn.Module):
    # 1x1 convolution mapping decoder features to per-class logits
    def __init__(self, in_channels, num_classes):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNetWithCrossAttention(nn.Module):
    def __init__(self, in_channels=1, num_classes=1, use_cross_attention=False):
        super(UNetWithCrossAttention, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.use_cross_attention = use_cross_attention

        # Encoder
        self.in_conv = DoubleConv(in_channels, 64, use_cross_attention)
        self.down1 = Down(64, 128, use_cross_attention)
        self.down2 = Down(128, 256, use_cross_attention)
        self.down3 = Down(256, 512, use_cross_attention)
        self.down4 = Down(512, 1024, use_cross_attention)

        # Decoder
        self.up1 = Up(1024, 512, use_cross_attention)
        self.up2 = Up(512, 256, use_cross_attention)
        self.up3 = Up(256, 128, use_cross_attention)
        self.up4 = Up(128, 64, use_cross_attention)
        self.out_conv = OutConv(64, num_classes)

    def forward(self, x, aux_feature=None):
        # Encoding path
        x1 = self.in_conv(x, aux_feature)
        x2 = self.down1(x1, aux_feature)
        x3 = self.down2(x2, aux_feature)
        x4 = self.down3(x3, aux_feature)
        x5 = self.down4(x4, aux_feature)

        # Decoding path
        x = self.up1(x5, x4, aux_feature)
        x = self.up2(x, x3, aux_feature)
        x = self.up3(x, x2, aux_feature)
        x = self.up4(x, x1, aux_feature)
        x = self.out_conv(x)
        return x
```
