当前位置: 首页 > news >正文

Diffusion的unet中用到的AttentionBlock详解

AttentionBlock

  • torch.split
  • torch中的permute的用法
    • torch.transpose()
    • view()
  • torch.bmm
  • softmax(x, dim=-1)

Diffusion的unet中用到的AttentionBlock详解

class AttentionBlock(nn.Module):__doc__ = r"""Applies QKV self-attention with a residual connection.Input:x: tensor of shape (N, in_channels, H, W)norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"num_groups (int): number of groups used in group normalization. Default: 32Output:tensor of shape (N, in_channels, H, W)Args:in_channels (int): number of input channels"""def __init__(self, in_channels, norm="gn", num_groups=32):super().__init__()self.in_channels = in_channelsself.norm = get_norm(norm, in_channels, num_groups)# 为啥这里的QKV并不是一样的???而是把通道数翻了3倍self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)self.to_out = nn.Conv2d(in_channels, in_channels, 1)def forward(self, x):b, c, h, w = x.shapeq, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)q = q.permute(0, 2, 3, 1).view(b, h * w, c)k = k.view(b, c, h * w)v = v.permute(0, 2, 3, 1).view(b, h * w, c)dot_products = torch.bmm(q, k) * (c ** (-0.5))assert dot_products.shape == (b, h * w, h * w)attention = torch.softmax(dot_products, dim=-1)out = torch.bmm(attention, v)assert out.shape == (b, h * w, c)out = out.view(b, h, w, c).permute(0, 3, 1, 2)return self.to_out(out) + x

x: (batch, channel, h, w)
经过to_qkv操作,变成了(batch, channel*3, h, w)

torch.split

torch.split(tensor, split_size_or_sections, dim=0)
# 作用:将tensor分成块结构
'''
split_size_or_secctions: 即多少个为一组
dim: 对哪个维度进行划分
'''

eg:
q, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)
即对大小为(batch, channel*3, h, w)的张量,在dim=1上划分,每channel个为一组
所以,q, k, v 的形状均为(batch, channel, h, w)

torch.split详解

torch中的permute的用法

作用:permute可以对tensor进行转置

import torch
import torch.nn as nnx = torch.randn(1, 2, 3, 4)
print(x.size())    # torch.Size([1, 2, 3, 4])   
print(x.permute(2, 1, 0, 3).size())# torch.Size([3, 2, 1, 4])   

torch.transpose()

因为torch.transpose 一次只能进行两个维度的转置,如果需要多个维度的转置,那么需要多次调用transpose()。比如上述的tensor[1,2,3,4]转置为tensor[3,4,1,2],使用transpose需要做如下:

x.transpose(0,2).transpose(1,3)

view()

view()函数作用的内存必须是连续的,如果操作数不是连续存储的,必须在操作之前执行contiguous(),把tensor变成在内存中连续分布的形式;view的功能有点像reshape,可以对tensor进行重新塑型

import torch
import torch.nn as nn
import numpy as npy = np.array([[[1, 2, 3], [4, 5, 6]]]) # 1X2X3
y_tensor = torch.tensor(y)
y_tensor_trans = y_tensor.permute(2, 0, 1) # 3X1X2
print(y_tensor.size())
print(y_tensor_trans.size())print(y_tensor)
print(y_tensor_trans)
print(y_tensor.view(1, 3, 2)) 
torch.Size([1, 2, 3])
torch.Size([3, 1, 2])
tensor([[[1, 2, 3],[4, 5, 6]]])
tensor([[[1, 4]],[[2, 5]],[[3, 6]]])
tensor([[[1, 2],[3, 4],[5, 6]]])

permute参考
permute详解参考

torch.bmm

作用:
计算两个tensor的矩阵乘法,torch.bmm(a,b),tensor a 的size为(b,h,w),tensor b的size为(b,w,m) 也就是说两个tensor的第一维是相等的,然后第一个数组的第三维和第二个数组的第二维度要求一样,对于剩下的则不做要求,输出维度 (b,h,m)

torch.bmm要求a,b的维度必须是3维的,不能为2D or 4D

矩阵相乘

softmax(x, dim=-1)

import torch 
a = torch.randn(2,3)
print(a)
tensor([[-8.2976e-01,  5.8105e-04,  1.2218e+00],[ 1.9745e-01,  1.2727e+00,  5.9587e-01]])
b = torch.softmax(a, dim=-1)
print(b)
tensor([[0.0903, 0.2072, 0.7025],[0.1845, 0.5407, 0.2748]])

softmax(x, dim=-1)

http://www.lryc.cn/news/44768.html

相关文章:

  • ElasticSearch索引文档写入和近实时搜索
  • 【C语言蓝桥杯每日一题】——等差数列
  • EM7电磁铁的技术参数
  • 选择很重要,骑友,怎么挑选骑行装备?
  • 【JUC面试题】Java并发编程面试题
  • spark笔记
  • 丢失了packet.dll原因和解决方法全面指南
  • 算法练习随记(三)
  • 基于Python 进行卫星图像多种指数分析
  • (Week 15)综合复习(C++,字符串,数学)
  • 迪赛智慧数——柱状图(正负条形图):“光棍”排行榜TOP10省份
  • IDEA集成chatGTP让你编码如虎添翼
  • Python3 os.close() 方法、Python3 File readline() 方法
  • Vision Pro 自己写的一些自定义工具(c#)
  • ARM/FPGA/DSP板卡选型大全,总有一款适合您
  • 【C语言蓝桥杯每日一题】—— 既约分数
  • 【机器学习】线性回归
  • 用ChatGPT学习多传感器融合中的基础知识
  • PyCharm2020介绍
  • Le Potato + Jumbospot MMDVM热点盒子
  • 蓝桥杯第19天(Python)(疯狂刷题第2天)
  • (五)手把手带你搭建精美简洁的个人时间管理网站—基于Axure的首页原型设计
  • 阿里面试:为什么MySQL不建议使用delete删除数据?
  • 低代码开发公司:用科技强力开启产业分工新时代!
  • 参考mfa官方文档实践笔记(亲测)
  • 【 第六章 拦截器,注解配置springMVC,springMVC执行流程】
  • 一种编译器视角下的python性能优化
  • 太逼真!这个韩国虚拟女团你追不追?
  • 安全与道路测试:自动驾驶系统安全性探究
  • chatGPT学英语,真香!!!