
Deep Learning Super-Resolution Algorithms: SRGAN

  • Updated version

  • Implements a generative adversarial network (GAN) for super-resolution

  • Updates the loss function, adding a prior (perceptual) term; a training-loss sketch follows the VGG truncation section below

  • SRResNet implementation (the generator backbone)

```python
import torch
import torchvision
from torch import nn


class ConvBlock(nn.Module):
    """Residual block: two conv-BN-PReLU stages plus an identity skip connection."""

    def __init__(self, kernel_size=3, stride=1, n_inchannels=64):
        super(ConvBlock, self).__init__()
        self.sequential = nn.Sequential(
            nn.Conv2d(in_channels=n_inchannels, out_channels=n_inchannels, kernel_size=(kernel_size, kernel_size),
                      stride=(stride, stride), bias=False, padding=(1, 1)),
            nn.BatchNorm2d(n_inchannels),
            nn.PReLU(),
            nn.Conv2d(in_channels=n_inchannels, out_channels=n_inchannels, kernel_size=(kernel_size, kernel_size),
                      stride=(stride, stride), bias=False, padding=(1, 1)),
            nn.BatchNorm2d(n_inchannels),
            nn.PReLU(),
        )

    def forward(self, x):
        residual = x
        out = self.sequential(x)
        return residual + out


class Head_Conv(nn.Module):
    """Input stage: 9x9 convolution followed by PReLU."""

    def __init__(self):
        super(Head_Conv, self).__init__()
        self.sequential = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(9, 9), stride=(1, 1), padding=(9 // 2, 9 // 2)),
            nn.PReLU(),
        )

    def forward(self, x):
        return self.sequential(x)


class PixelShuffle(nn.Module):
    """Sub-pixel upsampling: conv to (upscale_factor^2)x channels, then nn.PixelShuffle for a 2x spatial upscale."""

    def __init__(self, n_channels=64, upscale_factor=2):
        super(PixelShuffle, self).__init__()
        self.sequential = nn.Sequential(
            nn.Conv2d(in_channels=n_channels, out_channels=n_channels * (upscale_factor ** 2), kernel_size=(3, 3),
                      stride=(1, 1), padding=(3 // 2, 3 // 2)),
            nn.BatchNorm2d(n_channels * (upscale_factor ** 2)),
            nn.PixelShuffle(upscale_factor=upscale_factor),
        )

    def forward(self, x):
        return self.sequential(x)


class Hidden_block(nn.Module):
    """3x3 conv + BN placed after the residual blocks, before the global skip connection."""

    def __init__(self):
        super(Hidden_block, self).__init__()
        self.sequential = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(3 // 2, 3 // 2)),
            nn.BatchNorm2d(64),
        )

    def forward(self, x):
        return self.sequential(x)


class TailConv(nn.Module):
    """Output stage: 9x9 convolution back to 3 channels, with Tanh activation."""

    def __init__(self):
        super(TailConv, self).__init__()
        self.sequential = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=(9, 9), stride=(1, 1), padding=(9 // 2, 9 // 2)),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.sequential(x)


class SRResNet(nn.Module):
    """SRResNet: head -> n_blocks residual blocks -> conv+BN -> global skip -> two 2x PixelShuffle stages -> tail."""

    def __init__(self, n_blocks=16):
        super(SRResNet, self).__init__()
        self.head = Head_Conv()
        self.resnet = list()
        for _ in range(n_blocks):
            self.resnet.append(ConvBlock(kernel_size=3, stride=1, n_inchannels=64))
        self.resnet = nn.Sequential(*self.resnet)
        self.hidden = Hidden_block()
        self.pixel_shuffle = []
        for _ in range(2):
            self.pixel_shuffle.append(PixelShuffle(n_channels=64, upscale_factor=2))
        self.pixel_shuffle = nn.Sequential(*self.pixel_shuffle)
        self.tail_conv = TailConv()

    def forward(self, x):
        head_out = self.head(x)
        resnet_out = self.resnet(head_out)
        # apply the conv+BN block after the residual stack, then the global skip connection
        out = head_out + self.hidden(resnet_out)
        result = self.pixel_shuffle(out)
        out = self.tail_conv(result)
        return out
```
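As a quick sanity check, the backbone can be run on a dummy low-resolution batch: with two 2x PixelShuffle stages the output should be 4x the input resolution. This is only an illustrative sketch (the batch and patch sizes are arbitrary), not part of the original post; it assumes the classes above are defined in the same file.

```python
# Minimal sanity check (illustrative): two 2x PixelShuffle stages give a 4x upscale,
# so a 24x24 LR patch should come out as a 96x96 SR patch.
if __name__ == "__main__":
    model = SRResNet(n_blocks=16)
    lr_batch = torch.randn(2, 3, 24, 24)   # (N, C, H, W) dummy LR input
    with torch.no_grad():
        sr_batch = model(lr_batch)
    print(sr_batch.shape)                  # expected: torch.Size([2, 3, 96, 96])
```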


Implementation of the SRGAN generator and discriminator


```python
class Generator(nn.Module):
    """SRGAN generator: wraps SRResNet and maps a low-resolution image to a super-resolved image."""

    def __init__(self):
        super(Generator, self).__init__()
        self.model = SRResNet()

    def forward(self, x):
        """
        :param x: lr_img
        :return: sr_img
        """
        return self.model(x)


class Discriminator(nn.Module):
    """SRGAN discriminator: a stack of strided convolutions followed by a fully connected classifier."""

    def __init__(self):
        super(Discriminator, self).__init__()
        self.hidden = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(3 // 2, 3 // 2)),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(1, 1), padding=(0, 0)),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),
            nn.AdaptiveAvgPool2d((6, 6)),
        )
        self.out_layer = nn.Sequential(
            nn.Linear(512 * 6 * 6, 1024),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        result = self.hidden(x)
        # flatten the pooled feature map before the fully connected classifier
        result = result.reshape(result.shape[0], -1)
        out = self.out_layer(result)
        return out
```
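A short, illustrative forward pass (not from the original post) shows how the two networks fit together: the generator upscales a low-resolution batch and the discriminator outputs one realness probability per image. The input sizes below are dummy values chosen only for the demonstration.

```python
# Illustrative sketch: wire the generator and discriminator together.
generator = Generator()
discriminator = Discriminator()

lr_imgs = torch.randn(2, 3, 24, 24)   # dummy LR batch
sr_imgs = generator(lr_imgs)          # (2, 3, 96, 96) super-resolved batch
realness = discriminator(sr_imgs)     # (2, 1) probabilities in (0, 1) from the Sigmoid
print(sr_imgs.shape, realness.shape)
```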
- Truncating VGG19 at the required layer (for the VGG feature-space content loss)
```python
class TruncatedVGG19(nn.Module):
    """Truncated VGG19 network, used to compute the MSE loss in VGG feature space."""

    def __init__(self, i, j):
        """
        :param i: the i-th max-pooling layer
        :param j: the j-th convolutional layer
        """
        super(TruncatedVGG19, self).__init__()

        # Load the pretrained VGG19 model
        vgg19 = torchvision.models.vgg19(pretrained=True)

        maxpool_counter = 0
        conv_count = 0
        truncate_at = 0

        # Iterate through the convolutional part ("features") of VGG19
        for layer in vgg19.features.children():
            truncate_at += 1

            # Count convolutional and max-pooling layers
            if isinstance(layer, nn.Conv2d):
                conv_count += 1
            if isinstance(layer, nn.MaxPool2d):
                maxpool_counter += 1
                conv_count = 0  # restart the conv count after each pooling layer

            # Truncate at the j-th convolution after the (i - 1)-th max-pool
            # (i.e. before the i-th max-pool)
            if maxpool_counter == i - 1 and conv_count == j:
                break

        # Check that i and j match the VGG19 architecture
        assert maxpool_counter == i - 1 and conv_count == j, \
            "i=%d and j=%d do not match the VGG19 architecture" % (i, j)

        # Truncate the network (keep the activation that follows the j-th conv)
        self.truncated_vgg19 = nn.Sequential(*list(vgg19.features.children())[:truncate_at + 1])

    def forward(self, input):
        output = self.truncated_vgg19(input)  # (N, channels, w, h)
        return output
```
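The updated loss mentioned at the top combines a content term, computed as an MSE in the truncated VGG19 feature space, with an adversarial term from the discriminator. The sketch below shows one way to assemble it; the VGG cut (i=5, j=4) and the 1e-3 adversarial weight follow the SRGAN paper, but the snippet itself is an assumption-laden illustration rather than the original post's training code, and it assumes the classes defined above are in scope.

```python
# Illustrative SRGAN training-loss sketch (assumed hyperparameters: i=5, j=4, weight 1e-3).
import torch
from torch import nn

generator = Generator()
discriminator = Discriminator()
truncated_vgg19 = TruncatedVGG19(i=5, j=4)
truncated_vgg19.eval()                      # VGG features are fixed, not trained

content_criterion = nn.MSELoss()
adversarial_criterion = nn.BCELoss()

lr_imgs = torch.randn(2, 3, 24, 24)         # dummy LR batch
hr_imgs = torch.randn(2, 3, 96, 96)         # dummy matching HR batch

# --- generator (perceptual) loss ---
# Note: in practice SR/HR images should be converted to VGG's expected input
# normalization before being fed to truncated_vgg19.
sr_imgs = generator(lr_imgs)
sr_features = truncated_vgg19(sr_imgs)
hr_features = truncated_vgg19(hr_imgs).detach()
content_loss = content_criterion(sr_features, hr_features)

sr_probs = discriminator(sr_imgs)
adversarial_loss = adversarial_criterion(sr_probs, torch.ones_like(sr_probs))
perceptual_loss = content_loss + 1e-3 * adversarial_loss

# --- discriminator loss: real HR images -> 1, generated SR images -> 0 ---
d_loss = adversarial_criterion(discriminator(sr_imgs.detach()), torch.zeros_like(sr_probs)) + \
         adversarial_criterion(discriminator(hr_imgs), torch.ones_like(sr_probs))
```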
