当前位置: 首页 > news >正文

dl转置卷积

转置卷积

转置卷积,顾名思义,通过名字我们应该就能看出来,其作用和卷积相反,它可以使得图像的像素增多
在这里插入图片描述
上图的意思是,输入是22的图像,卷积核为22的矩阵,然后变换成3*3的矩阵
代码如下

import torch
from torch import nn
from d2l import torch as d2ldef trans_conv(X, K):  #X是原始矩阵,K是转置卷积核h, w = K.shapeY = torch.zeros((X.shape[0] + h - 1, X.shape[1] + w - 1))  # 转置卷积后的大小为x.shape[0] + k.shape[0] - 1 .........for i in range(X.shape[0]):for j in range(X.shape[1]):Y[i: i+h, j: j+w] += X[i, j] * Kreturn Y
X = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
K = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
trans_conv(X, K)

在这里插入图片描述
传统输入可能都是四维,使用API一样的

# 四维的话,调用API一样的
X, K = X.reshape(1, 1, 2, 2), K.reshape(1, 1, 2, 2)
tconv = nn.ConvTranspose2d(1, 1, kernel_size=2, bias=False)
tconv.weight.data = K
tconv(X)

在这里插入图片描述
与常规卷积不同,在转置卷积中,填充被应用于的输出(常规卷积将填充应用于输入)。
例如,当将高和宽两侧的填充数指定为1时,转置卷积的输出中将删除第一和最后的行与列。
换句话说,转置卷积的padding是删除输出的一圈

X, K = X.reshape(1, 1, 2, 2), K.reshape(1, 1, 2, 2)
tconv = nn.ConvTranspose2d(1, 1, kernel_size=2, padding=1, bias=False)
tconv.weight.data = K
tconv(X)

在这里插入图片描述
如果步幅为2的话,那么就会是一个4*4的矩阵

# 步幅为2的话那就是4*4了
X, K = X.reshape(1, 1, 2, 2), K.reshape(1, 1, 2, 2)
tconv = nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2, bias=False)
tconv.weight.data = K
tconv(X)

在这里插入图片描述
对于多个输入和输出通道,转置卷积与常规卷积以相同方式运作。 假设输入有ci个通道,且转置卷积为每个输入通道分配了一个kwkh的卷积核张量。
当指定多个输出通道时,每个输出通道的卷积核shape为ci
kw*kh

接下来我们可能会想,转置卷积为何以矩阵变换命名呢?我们先来看看矩阵乘法如何实现卷积
这是传统卷积

X = torch.arange(9.0).reshape(3, 3)
K = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
Y = d2l.corr2d(X, K)
Y

在这里插入图片描述
接下来通过矩阵乘法计算

# 先将K 写成稀疏权重矩阵
def kernel2matrix(K):k, W = torch.zeros(5), torch.zeros((4, 9))  # W是4*9的k[:2], k[3:5] = K[0, :], K[1, :]W[0, :5], W[1, 1:6], W[2, 3:8], W[3, 4:] = k, k, k, kreturn W
W = kernel2matrix(K)
W

在这里插入图片描述

# 然后就是矩阵乘法
Y == torch.matmul(W, X.reshape(-1)).reshape(2, 2)

在这里插入图片描述

而如果我们用W的转置*Y,那就是原来的Y的转置卷积了

# 同样的,我们可以使用矩阵乘法来实现转置矩阵  Y 是卷积后的值
Z = trans_conv(Y, K)
Z == torch.matmul(W.T, Y.reshape(-1)).reshape(3, 3)

在这里插入图片描述

http://www.lryc.cn/news/268188.html

相关文章:

  • 详解结构体(包含结构体内存对齐,柔性数组,位段)【尊嘟很详细】
  • 我的NPI项目之Android系统升级 - 同平台多产品的OTA
  • pnpm包管理器
  • flutter websocket发送ping包?
  • 基于采样的自动驾驶规划算法 - PRM,RRT,RRT*,CL-RRT
  • CGAL的D维范围树和线段树
  • 005.HCIA 传输层
  • LLM之RAG实战(八)| 使用Neo4j和LlamaIndex实现多模态RAG
  • 【SpringCloud笔记】(10)消息总线之Bus
  • 超酷的爬虫可视化界面
  • 【kafka消息里会有乱序消费的情况吗?如果有,是怎么解决的?】
  • 【PID精讲12】基于MATLAB和Simulink的仿真教程
  • 手机无人直播:解放直播的新方式
  • ios 之 数据库、地理位置、应用内跳转、推送、制作静态库、CoreData
  • Django(三)
  • vscode括号颜色突然变成白色的了,怎么解决
  • 测试服务器带宽(ubuntu)
  • 【WPF】使用Behavior以及ValidationRule实现表单校验
  • ArcGIS渔网的多种用法
  • C++ 中使用 std::map 的一个示例
  • python虚拟环境及其在项目实践中的应用
  • 普中STM32-PZ6806L开发板(烧录方式)
  • 基于单片机设计的指纹锁(读取、录入、验证指纹)
  • HarmonyOS - 基础组件绘制
  • AR智慧校园三维主电子沙盘系统研究及应用
  • web前端项目-七彩夜空烟花【附源码】
  • 在k8s中将gitlab-runner的运行pod调度到指定节点
  • 1.解决父组件传数据给子组件太慢,导致子组件获取不到合适数据渲染出错问题2.vue中props传递异步数据,子组件用watch监听
  • SpringMVC之获取请求参数和域对象共享数据
  • IntelliJ IDEA Community(社区版)下载及安装自用版