
Image Inpainting: Restoring Scratched Old Photos with the GLCIC Deep Neural Network

Step 1: Introduction to GLCIC

        GLCIC-PyTorch is an open-source project built on PyTorch that implements the "Globally and Locally Consistent Image Completion" method. Proposed by Iizuka et al., the method targets image inpainting: it restores occluded or damaged regions of an image in a visually plausible way. The project is written in Python and depends on the PyTorch deep learning framework.
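        Only a Python environment with PyTorch is needed to try it. A minimal setup sketch follows; it assumes the project is the otenim/GLCIC-PyTorch repository on GitHub (the usual home of this implementation), so substitute your own copy if you obtained the files from the video link in Step 5.

git clone https://github.com/otenim/GLCIC-PyTorch.git
cd GLCIC-PyTorch
pip install torch torchvision pillow numpy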

Step 2: GLCIC Network Architecture

        The core function of the project is image completion, achieved by training a generation network (the Completion Network) against a discrimination network (the Context Discriminator). The completion network performs the actual repair, while the discriminator raises repair quality by checking that the result stays consistent with the original image both globally and locally. Its main features are as follows (a sketch of the combined training loss follows this list):

        Image completion: a generative network fills in the missing regions of an image.
        Global and local consistency: the repaired image matches the original at the global level while staying coherent in local detail.
        Discriminator guidance: the discriminator scores generated images, pushing completion quality up.
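        To make this division of labor concrete, the sketch below shows how the two signals are typically combined into the generator objective in this scheme: an MSE reconstruction term over the hole region plus a weighted adversarial term from the discriminator. All tensors are random stand-ins, and the weight alpha follows the value reported by Iizuka et al.; the project's actual train.py organizes training into phases, so treat this as an illustration rather than the repository's exact code.

import torch
import torch.nn as nn

# Random stand-ins for one training example (illustrative names and sizes)
original = torch.rand(1, 3, 160, 160)           # ground-truth image
mask = torch.zeros(1, 1, 160, 160)
mask[:, :, 48:112, 48:112] = 1.0                # 1 marks the hole to repair
completed = torch.rand(1, 3, 160, 160)          # CompletionNetwork output
pred_fake = torch.rand(1, 1)                    # ContextDiscriminator score in (0, 1)

mse, bce = nn.MSELoss(), nn.BCELoss()
alpha = 4e-4                                    # adversarial weight from the paper

loss_rec = mse(completed * mask, original * mask)       # reconstruct the hole
loss_adv = bce(pred_fake, torch.ones_like(pred_fake))   # generator wants "real"
loss_g = loss_rec + alpha * loss_adv
print(loss_g.item())

        The small alpha keeps the adversarial gradient from overwhelming the reconstruction signal, which is what lets the discriminator sharpen detail without destabilizing training.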

Step 3: Model Code

import torch
import torch.nn as nn
import torch.nn.functional as F
from layers import Flatten, Concatenate


class CompletionNetwork(nn.Module):
    def __init__(self):
        super(CompletionNetwork, self).__init__()
        # input_shape: (None, 4, img_h, img_w)
        self.conv1 = nn.Conv2d(4, 64, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(64)
        self.act1 = nn.ReLU()
        # input_shape: (None, 64, img_h, img_w)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.act2 = nn.ReLU()
        # input_shape: (None, 128, img_h//2, img_w//2)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.act3 = nn.ReLU()
        # input_shape: (None, 128, img_h//2, img_w//2)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.act4 = nn.ReLU()
        # input_shape: (None, 256, img_h//4, img_w//4)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.act5 = nn.ReLU()
        # input_shape: (None, 256, img_h//4, img_w//4)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn6 = nn.BatchNorm2d(256)
        self.act6 = nn.ReLU()
        # input_shape: (None, 256, img_h//4, img_w//4)
        self.conv7 = nn.Conv2d(256, 256, kernel_size=3, stride=1, dilation=2, padding=2)
        self.bn7 = nn.BatchNorm2d(256)
        self.act7 = nn.ReLU()
        # input_shape: (None, 256, img_h//4, img_w//4)
        self.conv8 = nn.Conv2d(256, 256, kernel_size=3, stride=1, dilation=4, padding=4)
        self.bn8 = nn.BatchNorm2d(256)
        self.act8 = nn.ReLU()
        # input_shape: (None, 256, img_h//4, img_w//4)
        self.conv9 = nn.Conv2d(256, 256, kernel_size=3, stride=1, dilation=8, padding=8)
        self.bn9 = nn.BatchNorm2d(256)
        self.act9 = nn.ReLU()
        # input_shape: (None, 256, img_h//4, img_w//4)
        self.conv10 = nn.Conv2d(256, 256, kernel_size=3, stride=1, dilation=16, padding=16)
        self.bn10 = nn.BatchNorm2d(256)
        self.act10 = nn.ReLU()
        # input_shape: (None, 256, img_h//4, img_w//4)
        self.conv11 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn11 = nn.BatchNorm2d(256)
        self.act11 = nn.ReLU()
        # input_shape: (None, 256, img_h//4, img_w//4)
        self.conv12 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn12 = nn.BatchNorm2d(256)
        self.act12 = nn.ReLU()
        # input_shape: (None, 256, img_h//4, img_w//4)
        self.deconv13 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.bn13 = nn.BatchNorm2d(128)
        self.act13 = nn.ReLU()
        # input_shape: (None, 128, img_h//2, img_w//2)
        self.conv14 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.bn14 = nn.BatchNorm2d(128)
        self.act14 = nn.ReLU()
        # input_shape: (None, 128, img_h//2, img_w//2)
        self.deconv15 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.bn15 = nn.BatchNorm2d(64)
        self.act15 = nn.ReLU()
        # input_shape: (None, 64, img_h, img_w)
        self.conv16 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.bn16 = nn.BatchNorm2d(32)
        self.act16 = nn.ReLU()
        # input_shape: (None, 32, img_h, img_w)
        self.conv17 = nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1)
        self.act17 = nn.Sigmoid()
        # output_shape: (None, 3, img_h, img_w)

    def forward(self, x):
        x = self.bn1(self.act1(self.conv1(x)))
        x = self.bn2(self.act2(self.conv2(x)))
        x = self.bn3(self.act3(self.conv3(x)))
        x = self.bn4(self.act4(self.conv4(x)))
        x = self.bn5(self.act5(self.conv5(x)))
        x = self.bn6(self.act6(self.conv6(x)))
        x = self.bn7(self.act7(self.conv7(x)))
        x = self.bn8(self.act8(self.conv8(x)))
        x = self.bn9(self.act9(self.conv9(x)))
        x = self.bn10(self.act10(self.conv10(x)))
        x = self.bn11(self.act11(self.conv11(x)))
        x = self.bn12(self.act12(self.conv12(x)))
        x = self.bn13(self.act13(self.deconv13(x)))
        x = self.bn14(self.act14(self.conv14(x)))
        x = self.bn15(self.act15(self.deconv15(x)))
        x = self.bn16(self.act16(self.conv16(x)))
        x = self.act17(self.conv17(x))
        return x


class LocalDiscriminator(nn.Module):
    def __init__(self, input_shape):
        super(LocalDiscriminator, self).__init__()
        self.input_shape = input_shape
        self.output_shape = (1024,)
        self.img_c = input_shape[0]
        self.img_h = input_shape[1]
        self.img_w = input_shape[2]
        # input_shape: (None, img_c, img_h, img_w)
        self.conv1 = nn.Conv2d(self.img_c, 64, kernel_size=5, stride=2, padding=2)
        self.bn1 = nn.BatchNorm2d(64)
        self.act1 = nn.ReLU()
        # input_shape: (None, 64, img_h//2, img_w//2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2)
        self.bn2 = nn.BatchNorm2d(128)
        self.act2 = nn.ReLU()
        # input_shape: (None, 128, img_h//4, img_w//4)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2)
        self.bn3 = nn.BatchNorm2d(256)
        self.act3 = nn.ReLU()
        # input_shape: (None, 256, img_h//8, img_w//8)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=5, stride=2, padding=2)
        self.bn4 = nn.BatchNorm2d(512)
        self.act4 = nn.ReLU()
        # input_shape: (None, 512, img_h//16, img_w//16)
        self.conv5 = nn.Conv2d(512, 512, kernel_size=5, stride=2, padding=2)
        self.bn5 = nn.BatchNorm2d(512)
        self.act5 = nn.ReLU()
        # input_shape: (None, 512, img_h//32, img_w//32)
        in_features = 512 * (self.img_h//32) * (self.img_w//32)
        self.flatten6 = Flatten()
        # input_shape: (None, 512 * img_h//32 * img_w//32)
        self.linear6 = nn.Linear(in_features, 1024)
        self.act6 = nn.ReLU()
        # output_shape: (None, 1024)

    def forward(self, x):
        x = self.bn1(self.act1(self.conv1(x)))
        x = self.bn2(self.act2(self.conv2(x)))
        x = self.bn3(self.act3(self.conv3(x)))
        x = self.bn4(self.act4(self.conv4(x)))
        x = self.bn5(self.act5(self.conv5(x)))
        x = self.act6(self.linear6(self.flatten6(x)))
        return x


class GlobalDiscriminator(nn.Module):
    def __init__(self, input_shape, arc='celeba'):
        super(GlobalDiscriminator, self).__init__()
        self.arc = arc
        self.input_shape = input_shape
        self.output_shape = (1024,)
        self.img_c = input_shape[0]
        self.img_h = input_shape[1]
        self.img_w = input_shape[2]
        # input_shape: (None, img_c, img_h, img_w)
        self.conv1 = nn.Conv2d(self.img_c, 64, kernel_size=5, stride=2, padding=2)
        self.bn1 = nn.BatchNorm2d(64)
        self.act1 = nn.ReLU()
        # input_shape: (None, 64, img_h//2, img_w//2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2)
        self.bn2 = nn.BatchNorm2d(128)
        self.act2 = nn.ReLU()
        # input_shape: (None, 128, img_h//4, img_w//4)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2)
        self.bn3 = nn.BatchNorm2d(256)
        self.act3 = nn.ReLU()
        # input_shape: (None, 256, img_h//8, img_w//8)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=5, stride=2, padding=2)
        self.bn4 = nn.BatchNorm2d(512)
        self.act4 = nn.ReLU()
        # input_shape: (None, 512, img_h//16, img_w//16)
        self.conv5 = nn.Conv2d(512, 512, kernel_size=5, stride=2, padding=2)
        self.bn5 = nn.BatchNorm2d(512)
        self.act5 = nn.ReLU()
        # input_shape: (None, 512, img_h//32, img_w//32)
        if arc == 'celeba':
            in_features = 512 * (self.img_h//32) * (self.img_w//32)
            self.flatten6 = Flatten()
            self.linear6 = nn.Linear(in_features, 1024)
            self.act6 = nn.ReLU()
        elif arc == 'places2':
            self.conv6 = nn.Conv2d(512, 512, kernel_size=5, stride=2, padding=2)
            self.bn6 = nn.BatchNorm2d(512)
            self.act6 = nn.ReLU()
            # input_shape: (None, 512, img_h//64, img_w//64)
            in_features = 512 * (self.img_h//64) * (self.img_w//64)
            self.flatten7 = Flatten()
            self.linear7 = nn.Linear(in_features, 1024)
            self.act7 = nn.ReLU()
        else:
            raise ValueError('Unsupported architecture \'%s\'.' % self.arc)
        # output_shape: (None, 1024)

    def forward(self, x):
        x = self.bn1(self.act1(self.conv1(x)))
        x = self.bn2(self.act2(self.conv2(x)))
        x = self.bn3(self.act3(self.conv3(x)))
        x = self.bn4(self.act4(self.conv4(x)))
        x = self.bn5(self.act5(self.conv5(x)))
        if self.arc == 'celeba':
            x = self.act6(self.linear6(self.flatten6(x)))
        elif self.arc == 'places2':
            x = self.bn6(self.act6(self.conv6(x)))
            x = self.act7(self.linear7(self.flatten7(x)))
        return x


class ContextDiscriminator(nn.Module):
    def __init__(self, local_input_shape, global_input_shape, arc='celeba'):
        super(ContextDiscriminator, self).__init__()
        self.arc = arc
        self.input_shape = [local_input_shape, global_input_shape]
        self.output_shape = (1,)
        self.model_ld = LocalDiscriminator(local_input_shape)
        self.model_gd = GlobalDiscriminator(global_input_shape, arc=arc)
        # input_shape: [(None, 1024), (None, 1024)]
        in_features = self.model_ld.output_shape[-1] + self.model_gd.output_shape[-1]
        self.concat1 = Concatenate(dim=-1)
        # input_shape: (None, 2048)
        self.linear1 = nn.Linear(in_features, 1)
        self.act1 = nn.Sigmoid()
        # output_shape: (None, 1)

    def forward(self, x):
        x_ld, x_gd = x
        x_ld = self.model_ld(x_ld)
        x_gd = self.model_gd(x_gd)
        out = self.act1(self.linear1(self.concat1([x_ld, x_gd])))
        return out
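        A quick way to sanity-check the definitions above is to push dummy tensors through both networks and confirm the output shapes. The sizes below are illustrative (160x160 global crops with 96x96 local patches, roughly the CelebA configuration); Flatten and Concatenate are the small helpers assumed to come from the project's layers.py.

cn = CompletionNetwork()
img = torch.rand(1, 3, 160, 160)                      # image scaled to [0, 1]
mask = torch.zeros(1, 1, 160, 160)
mask[:, :, 48:112, 48:112] = 1.0                      # hole to fill
out = cn(torch.cat((img * (1 - mask), mask), dim=1))  # 4-channel input: RGB + mask
print(out.shape)                                      # torch.Size([1, 3, 160, 160])

cd = ContextDiscriminator(local_input_shape=(3, 96, 96),
                          global_input_shape=(3, 160, 160),
                          arc='celeba')
pred = cd((torch.rand(1, 3, 96, 96), out))            # (local patch, full image)
print(pred.shape)                                     # torch.Size([1, 1])

        Note that the completion network takes a 4-channel input (the masked RGB image plus the binary mask) and, because every stride-2 layer is mirrored by a transposed convolution, returns an image with the same height and width it was given.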

Step 4: Running the Interactive Code
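        The full interactive demo ships with the project download, but the core inference loop is short. The following is a minimal sketch, assuming a trained CompletionNetwork checkpoint saved as model_cn.pth, an input photo old_photo.jpg whose sides are multiples of 4, and a hand-drawn scratch mask; the file names, mask region, and mean pixel value mpv are all placeholders to adapt to your own data.

import torch
from PIL import Image
import torchvision.transforms as T

model = CompletionNetwork()
model.load_state_dict(torch.load('model_cn.pth', map_location='cpu'))  # placeholder checkpoint
model.eval()

img = T.ToTensor()(Image.open('old_photo.jpg').convert('RGB')).unsqueeze(0)
mask = torch.zeros(1, 1, img.shape[2], img.shape[3])
mask[:, :, 60:100, 40:160] = 1.0                       # placeholder scratch region (1 = damaged)

mpv = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)   # assumed mean pixel value
x = img * (1 - mask) + mpv * mask                      # fill the hole with the mean color
with torch.no_grad():
    out = model(torch.cat((x, mask), dim=1))           # 4 channels: masked image + mask
result = img * (1 - mask) + out * mask                 # keep known pixels, paste the completion
T.ToPILImage()(result.squeeze(0).clamp(0., 1.)).save('restored.jpg')

        The final composite keeps the original pixels outside the mask and uses the network output only inside the hole, so undamaged areas of the photo are left untouched.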

Step 5: Full Project Contents

 For the complete project files, see the download link given in the description of the demo and introduction video: ➷➷➷

Image Inpainting: Restoring Scratched Old Photos with the GLCIC Deep Neural Network (bilibili)

