当前位置: 首页 > news >正文

PyTorch深度学习实战——图像着色

PyTorch深度学习实战——图像着色

    • 0. 前言
    • 1. 模型与数据集分析
      • 1.1 数据集介绍
      • 1.2 模型策略
    • 2. 实现图像着色
    • 相关链接

0. 前言

图像着色指的是将黑白或灰度图像转换为彩色图像的过程,传统的图像处理技术通常基于直方图匹配和颜色传递的方法或基于用户交互的方法等完成图像着色操作,不但耗时且需要专业知识,而基于深度学习的方法能够实现自动着色,极大的提高了效率。在训练图着色模型时,我们可以将原始图像转换为黑白图像作为网络输入,原始彩色图像作为输出。

1. 模型与数据集分析

在本节中,我们将利用 CIFAR-10 数据集执行图像着色。

1.1 数据集介绍

CIFAR-10 数据集是一个广泛应用于计算机视觉领域的图像分类数据集。它由 10 个不同类别的彩色图像组成,每个类别包含 600032 x 32 像素的图像。该数据集涵盖了各种不同的对象类别,包括飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。与一些只包含灰度图像的数据集相比,CIFAR-10 数据集的图像是彩色的,但由于图像分辨率相对较低,图像中的细节和特征相对较少。
CIFAR-10 数据集在计算机视觉领域的研究和开发中得到了广泛的应用,许多图像分类算法和深度学习模型都在 CIFAR-10 上进行了测试和验证。它提供了一个标准化的基准,用于比较不同算法的性能。

1.2 模型策略

了解了所用数据集后,本节中,我们继续介绍图像着色模型策略:

  1. 获取训练数据集中的原始彩色图像,将其转换为灰度图像,构造输入(灰度)-输出(原始彩色图像)对
  2. 执行归一化输入和输出图像
  3. 构建 U-Net 架构
  4. 训练模型

2. 实现图像着色

接下来,使用 PyTorch 实现以上策略,构建图像着色模型。

(1) 导入所需库:

import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'from torchvision import datasets
from torch.utils.data import DataLoader, Dataset
from torch import nn
from torch import optim
import numpy as np
import torchvision
from matplotlib import pyplot as plt

(2) 下载数据集,并定义训练、验证数据集和数据加载器。

下载数据集:

data_folder = 'cifar10/cifar/' 
datasets.CIFAR10(data_folder, download=True)

定义训练、验证数据集和数据加载器:

class Colorize(torchvision.datasets.CIFAR10):def __init__(self, root, train):super().__init__(root, train)def __getitem__(self, ix):im, _ = super().__getitem__(ix)bw = im.convert('L').convert('RGB')bw, im = np.array(bw)/255., np.array(im)/255.bw, im = [torch.tensor(i).permute(2,0,1).to(device).float() for i in [bw,im]]return bw, imtrn_ds = Colorize('cifar10/cifar/', train=True)
val_ds = Colorize('cifar10/cifar/', train=False)trn_dl = DataLoader(trn_ds, batch_size=256, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=256, shuffle=False)

输入和输出图像的样本如下:

a,b = trn_ds[0]
plt.subplot(121)
plt.imshow(a.permute(1,2,0).cpu(), cmap='gray')
plt.subplot(122)
plt.imshow(b.permute(1,2,0).cpu())
plt.show()

样本示例
(3) 定义网络架构:

class Identity(nn.Module):def __init__(self):super().__init__()def forward(self, x):return xclass DownConv(nn.Module):def __init__(self, ni, no, maxpool=True):super().__init__()self.model = nn.Sequential(nn.MaxPool2d(2) if maxpool else Identity(),nn.Conv2d(ni, no, 3, padding=1),nn.BatchNorm2d(no),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(no, no, 3, padding=1),nn.BatchNorm2d(no),nn.LeakyReLU(0.2, inplace=True),)def forward(self, x):return self.model(x)class UpConv(nn.Module):def __init__(self, ni, no, maxpool=True):super().__init__()self.convtranspose = nn.ConvTranspose2d(ni, no, 2, stride=2)self.convlayers = nn.Sequential(nn.Conv2d(no+no, no, 3, padding=1),nn.BatchNorm2d(no),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(no, no, 3, padding=1),nn.BatchNorm2d(no),nn.LeakyReLU(0.2, inplace=True),)def forward(self, x, y):x = self.convtranspose(x)x = torch.cat([x,y], axis=1)x = self.convlayers(x)return xclass UNet(nn.Module):def __init__(self):super().__init__()self.d1 = DownConv( 3, 64, maxpool=False)self.d2 = DownConv( 64, 128)self.d3 = DownConv( 128, 256)self.d4 = DownConv( 256, 512)self.d5 = DownConv( 512, 1024)self.u5 = UpConv (1024, 512)self.u4 = UpConv ( 512, 256)self.u3 = UpConv ( 256, 128)self.u2 = UpConv ( 128, 64)self.u1 = nn.Conv2d(64, 3, kernel_size=1, stride=1)def forward(self, x):x0 = self.d1( x) # 32x1 = self.d2(x0) # 16x2 = self.d3(x1) # 8x3 = self.d4(x2) # 4x4 = self.d5(x3) # 2X4 = self.u5(x4, x3)# 4X3 = self.u4(X4, x2)# 8X2 = self.u3(X3, x1)# 16X1 = self.u2(X2, x0)# 32X0 = self.u1(X1) # 3return X0

(4) 定义模型、优化器和损失函数:

def get_model():model = UNet().to(device)optimizer = optim.Adam(model.parameters(), lr=1e-3)loss_fn = nn.MSELoss()return model, optimizer, loss_fn

(5) 定义模型在批数据进行训练和验证的函数:

def train_batch(model, data, optimizer, criterion):model.train()x, y = data_y = model(x)optimizer.zero_grad()loss = criterion(_y, y)loss.backward()optimizer.step()return loss.item()@torch.no_grad()
def validate_batch(model, data, criterion):model.eval()x, y = data_y = model(x)loss = criterion(_y, y)return loss.item()

(6) 训练模型:

model, optimizer, criterion = get_model()
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)_val_dl = DataLoader(val_ds, batch_size=1, shuffle=True)n_epochs = 100
train_loss_epochs = []
val_loss_epochs = []for ex in range(n_epochs):N = len(trn_dl)trn_loss = []val_loss = []for bx, data in enumerate(trn_dl):loss = train_batch(model, data, optimizer, criterion)pos = (ex + (bx+1)/N)trn_loss.append(loss)train_loss_epochs.append(np.average(trn_loss))N = len(val_dl)for bx, data in enumerate(val_dl):loss = validate_batch(model, data, criterion)pos = (ex + (bx+1)/N)val_loss.append(loss)val_loss_epochs.append(np.average(val_loss))exp_lr_scheduler.step()if (ex+1)%10 == 0:for _ in range(5):a,b = next(iter(_val_dl))_b = model(a)plt.subplot(131)plt.imshow(a[0].permute(1,2,0).cpu(), cmap='gray')plt.subplot(132)plt.imshow(b[0].permute(1,2,0).cpu())plt.subplot(133)plt.imshow(_b[0].permute(1,2,0).detach().cpu().numpy())plt.show()
epochs = np.arange(n_epochs)+1
plt.plot(epochs, train_loss_epochs, 'bo', label='Training loss')
plt.plot(epochs, val_loss_epochs, 'r', label='Test loss')
plt.title('Training and Test loss over increasing epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
plt.show()

着色结果

从前面的输出中,可以看到模型能够很好地为灰度图像着色。

相关链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——使用U-Net架构进行图像分割
PyTorch深度学习实战(24)——从零开始实现Mask R-CNN实例分割

http://www.lryc.cn/news/219905.html

相关文章:

  • InfiniBand 的前世今生
  • 分享一下微信小程序里怎么添加社区团购功能
  • 软考高项-IT部分
  • hugetlb核心组件
  • vscode配置环境变量
  • react:封装组件
  • 基于深度学习的视频多目标跟踪实现 计算机竞赛
  • linux中各种最新网卡2.5G网卡驱动,不同型号的网卡需要不同的驱动,整合各种网卡驱动,包括有线网卡、无线网卡、Wi-Fi热点
  • asp.net上传文件
  • JavaEE平台技术——预备知识(Web、Sevlet、Tomcat)
  • 基础课23——设计客服机器人
  • mybatis在springboot当中的使用
  • 如何处理前端本地存储和缓存
  • 导轨式安装压力应变桥信号处理差分信号输入转换变送器0-10mV/0-20mV/0-±10mV/0-±20mV转0-5V/0-10V/4-20mA
  • 人体姿态估计和手部姿态估计任务中神经网络的选择
  • odoo16 one2many字段的 domain
  • 一份优秀测试用例的设计策略
  • 自动驾驶行业观察之2023上海车展-----智驾供应链(3)
  • 倒计时丨3天后,我们直播间见!
  • c语言经典算法—二分查找,冒泡,选择,插入,归并,快排,堆排
  • 网站SSL证书有什么用
  • ubuntu 20.04 server安装
  • 造数工具调研
  • Linux文件系统目录结构
  • CANoe新建XML自动化Test Modules
  • 国内某发动机制造工厂RFID智能制造应用解决方案
  • 【SpringCloud Alibaba -- Nacos】Linux 搭建 Nacos 集群
  • 程序员使用 ChatGPT的 10 种最佳方式
  • 各种各类好用热门API推荐
  • 高速串行总线——SATA