
MDTA Module (Restormer)

From a layer normalized tensor $\mathbf{Y} \in \mathbb{R}^{\hat{H} \times \hat{W} \times \hat{C}}$, our MDTA first generates query ($\mathbf{Q}$), key ($\mathbf{K}$) and value ($\mathbf{V}$) projections, enriched with local context. It is achieved by applying $1 \times 1$ convolutions to aggregate pixel-wise cross-channel context followed by $3 \times 3$ depth-wise convolutions to encode channel-wise spatial context, yielding $\mathbf{Q}=W_d^Q W_p^Q \mathbf{Y}$, $\mathbf{K}=W_d^K W_p^K \mathbf{Y}$ and $\mathbf{V}=W_d^V W_p^V \mathbf{Y}$, where $W_p^{(\cdot)}$ is the $1 \times 1$ point-wise convolution and $W_d^{(\cdot)}$ is the $3 \times 3$ depth-wise convolution. We use bias-free convolutional layers in the network. Next, we reshape query and key projections such that their dot-product interaction generates a transposed-attention map $\mathbf{A}$ of size $\mathbb{R}^{\hat{C} \times \hat{C}}$, instead of the huge regular attention map of size $\mathbb{R}^{\hat{H}\hat{W} \times \hat{H}\hat{W}}$. Overall, the MDTA process is defined as:

$$
\hat{\mathbf{X}} = W_p \operatorname{Attention}(\hat{\mathbf{Q}}, \hat{\mathbf{K}}, \hat{\mathbf{V}}) + \mathbf{X} \\
\operatorname{Attention}(\hat{\mathbf{Q}}, \hat{\mathbf{K}}, \hat{\mathbf{V}}) = \hat{\mathbf{V}} \cdot \operatorname{Softmax}(\hat{\mathbf{K}} \cdot \hat{\mathbf{Q}} / \alpha)
$$

where $\mathbf{X}$ and $\hat{\mathbf{X}}$ are the input and output feature maps; the matrices $\hat{\mathbf{Q}} \in \mathbb{R}^{\hat{H}\hat{W} \times \hat{C}}$, $\hat{\mathbf{K}} \in \mathbb{R}^{\hat{C} \times \hat{H}\hat{W}}$ and $\hat{\mathbf{V}} \in \mathbb{R}^{\hat{H}\hat{W} \times \hat{C}}$ are obtained after reshaping tensors from the original size $\mathbb{R}^{\hat{H} \times \hat{W} \times \hat{C}}$. Here, $\alpha$ is a learnable scaling parameter to control the magnitude of the dot product of $\hat{\mathbf{K}}$ and $\hat{\mathbf{Q}}$ before applying the softmax function. Similar to the conventional multi-head SA, we divide the number of channels into 'heads' and learn separate attention maps in parallel.
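To make the shape bookkeeping concrete, here is a small numerical sketch of the transposed attention for a single head. It is my own illustration, not code from the paper: the sizes $\hat{H}=\hat{W}=32$, $\hat{C}=48$ and the stand-in projections are hypothetical, and it only shows how the $\hat{C} \times \hat{C}$ attention map arises and why it is far smaller than an $\hat{H}\hat{W} \times \hat{H}\hat{W}$ map.

import torch

H, W, C = 32, 32, 48            # hypothetical spatial size and channel count
alpha = 1.0                     # stands in for the learnable scaling parameter

# Stand-in projections; in MDTA these come from 1x1 point-wise + 3x3 depth-wise convs
Q = torch.randn(H * W, C)       # Q_hat: HW x C
K = torch.randn(C, H * W)       # K_hat: C x HW
V = torch.randn(H * W, C)       # V_hat: HW x C

A = torch.softmax(K @ Q / alpha, dim=-1)    # transposed-attention map: C x C
out = V @ A                                 # HW x C, same size as the reshaped input

print(A.shape)      # torch.Size([48, 48])   -> C x C, not 1024 x 1024
print(out.shape)    # torch.Size([1024, 48])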

## Multi-DConv Head Transposed Self-Attention (MDTA)
import torch
import torch.nn as nn
from einops import rearrange


class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        # Learnable per-head scaling factor (alpha in the paper)
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        # 1x1 point-wise conv: pixel-wise cross-channel context
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        # 3x3 depth-wise conv: channel-wise spatial context
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1,
                                    padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape

        # Generate Q, K, V projections enriched with local context
        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        # Reshape to (batch, head, channels-per-head, H*W)
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        # L2-normalize along the spatial dimension before the dot product
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        # Transposed attention: a (C x C)-per-head map instead of HW x HW
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = attn @ v

        # Restore the spatial layout (batch, C, H, W)
        out = rearrange(out, 'b head c (h w) -> b (head c) h w',
                        head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out
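As a quick sanity check (my own usage example with arbitrary values dim=48 and num_heads=8), the module preserves the (B, C, H, W) layout of its input:

x = torch.randn(1, 48, 64, 64)                      # (B, C, H, W)
mdta = Attention(dim=48, num_heads=8, bias=False)
print(mdta(x).shape)                                # torch.Size([1, 48, 64, 64])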

This code does not implement the Norm module shown in the figure; for that, refer to Layer Normalization (a simplified sketch is given at the end of this section). Let's look at how the Transformer Block wraps everything:

class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super(TransformerBlock, self).__init__()
        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))  # MDTA branch: pre-norm + residual
        x = x + self.ffn(self.norm2(x))   # feed-forward branch: pre-norm + residual
        return x

As you can see, the implementation first applies Norm, then Attention, and finally adds the residual connection; this entire flow is exactly what the figure above depicts.
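The LayerNorm used by TransformerBlock is not defined in these snippets. Below is a minimal sketch of a channel-wise layer normalization over (B, C, H, W) feature maps, covering both a bias-free and a with-bias variant; it is my own simplification for illustration, not the repository's exact implementation.

class LayerNorm(nn.Module):
    # Simplified channel-wise LayerNorm for (B, C, H, W) feature maps
    def __init__(self, dim, LayerNorm_type='BiasFree'):
        super(LayerNorm, self).__init__()
        self.bias_free = (LayerNorm_type == 'BiasFree')
        self.weight = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.bias = None if self.bias_free else nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        # Statistics are computed per pixel across the channel dimension
        var = x.var(dim=1, keepdim=True, unbiased=False)
        if self.bias_free:
            # Bias-free variant: rescale only, no mean subtraction or shift
            return x / torch.sqrt(var + 1e-5) * self.weight
        mu = x.mean(dim=1, keepdim=True)
        return (x - mu) / torch.sqrt(var + 1e-5) * self.weight + self.bias

With a LayerNorm and a FeedForward module (Restormer's gated-Dconv feed-forward network, which is not covered here) in place, the TransformerBlock above can be instantiated directly.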
