当前位置: 首页 > news >正文

G4 - 可控手势生成 CGAN

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目录

  • 代码
  • 总结与心得


代码

关于CGAN的原理上节已经讲过,这次主要是编写代码加载上节训练后的模型来进行指定条件的生成

图像的生成其实只需要使用Generator模型,判别器模型是在训练过程中才用的。

# 库引入
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torchdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 超参数
latent_dim = 100
n_classes = 3
embedding_dim = 100# 工具函数
def weights_init(m):classname = m.__class__.__name__if classname.find('Conv') != -1:torch.nn.init.normal_(m.weight, 0.0, 0.02)elif classname.find('BatchNorm') != -1:torch.nn.init.normal_(m.weight, 1.0, 0.02)torch.nn.init.zeros_(m.bias)# 模型
class Generator(nn.Module):def __init__(self):super().__init__()self.label_conditioned_generator = nn.Sequential(nn.Embedding(n_classes, embedding_dim),nn.Linear(embedding_dim, 16))self.latent = nn.Sequential(nn.Linear(latent_dim, 4*4*512),nn.LeakyReLU(0.2, inplace=True))self.model = nn.Sequential(nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),nn.ReLU(True),nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),nn.ReLU(True),nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),nn.ReLU(True),nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),nn.ReLU(True),nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),nn.Tanh())def forward(self, inputs):noise_vector, label = inputslabel_output = self.label_conditioned_generator(label)label_output = label_output.view(-1, 1, 4, 4)latent_output = self.latent(noise_vector)latent_output = latent_output.view(-1, 512, 4, 4)concat = torch.cat((latent_output, label_output), dim=1)image = self.model(concat)return imagegenerator = Generator().to(device)
generator.apply(weights_init)
print(generator)
Generator((label_conditioned_generator): Sequential((0): Embedding(3, 100)(1): Linear(in_features=100, out_features=16, bias=True))(latent): Sequential((0): Linear(in_features=100, out_features=8192, bias=True)(1): LeakyReLU(negative_slope=0.2, inplace=True))(model): Sequential((0): ConvTranspose2d(513, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(1): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(4): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)(5): ReLU(inplace=True)(6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(7): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)(8): ReLU(inplace=True)(9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)(11): ReLU(inplace=True)(12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)(13): Tanh())
)
from numpy.random import randint, randn
from numpy import linspace
from matplotlib import pyplot, gridspec# 加载训练好的权重
generator.load_state_dict(torch.load('generator_epoch_300.pth'), strict=False)
# 关闭梯度积累
generator.eval()# 生成随机变量
interpolated = randn(100)
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)# 生成条件变量
label = 0 # 生成第0个分类的图像
labels = torch.ones(1) * label
labels = labels.to(device).unsqueeze(1).long()# 执行生成
predictions = generator((interpolated, labels))
predictions = predictions.permute(0, 2, 3, 1).detach().cpu()# 屏蔽警告
import warnings
warnings.filterwarnings('ignore')# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']
# 防止负号无法显示
plt.rcParams['axes.unicode_minus']= False
# 设置图的分辨率
plt.rcParams['figure.dpi'] = 100# 绘图
plt.figure(figsize=(8, 3))
pred = (predictions[0, :, :, :] + 1) * 127.5
pred = np.array(pred)
plt.imshow(pred.astype(np.uint8))
plt.show()

生成分类0
我们将分类修改为1重新生成一次

生成分类1

总结与心得

在本次实验的过程中,我了解了CGAN模型在训练完成后,后续如何使用的步骤:

  1. 保存训练好的生成器的权重
  2. 使用生成器加载
  3. 生成随机分布变量用于生成图像
  4. 生成指定的标签,并转换成控制向量
  5. 执行生成操作

另外关于警告和matplotlib设置中文字体的方式也是经常会用到的技巧。

http://www.lryc.cn/news/358918.html

相关文章:

  • 使用 DuckDuckGo API 实现多种搜索功能
  • 【DrissionPage爬虫库 1】两种模式分别爬取Gitee开源项目
  • leetcode 115.不同的子序列
  • 二叉树的顺序实现-堆
  • 【Maven】Maven主要知识点目录整理
  • Coolmuster Android Assistant: 手机数据管理的全能助手
  • 03-树3 Tree Traversals Again(浙大数据结构PTA习题)
  • Java项目对接redis,客户端是选Redisson、Lettuce还是Jedis?
  • AngularJS Web前端框架:深入探索与应用实践
  • SQL 入门:使用 MySQL 进行数据库操作
  • window安装ffmpeg播放本地摄像头视频
  • 【嵌入式DIY实例】-OLED显示网络时钟
  • 【线程相关知识】
  • 鸿蒙ArkTS声明式开发:跨平台支持列表【透明度设置】 通用属性
  • 【SQL学习进阶】从入门到高级应用(九)
  • Web前端三大主流框架技术分享
  • dockers安装mysql
  • 100道面试必会算法-27-美团2024面试第一题-前缀和矩阵
  • 从摇一摇到弹窗,AD无处不在?为了不再受打扰,推荐几款好用的屏蔽软件,让手机电脑更清爽
  • HackTheBox-Machines--Nibbles
  • 东方博宜1703 - 小明买水果
  • mac电脑用谷歌浏览器对安卓手机H5页面进行inspect
  • 动手学深度学习(Pytorch版)代码实践-深度学习基础-01基础函数的使用
  • vm-bhyve:bhyve虚拟机的管理系统@FreeBSD
  • 【Java】刚刚!突然!紧急通知!垃圾回收!
  • [Algorithm][动态规划][子序列问题][最长递增子序列][摆动序列]详细讲解
  • 【稳定检索】2024年心理学与现代化教育、媒体国际会议(PMEM 2024)
  • 深入了解diffusion model
  • TransmittableThreadLocal原理
  • 华为昇腾310B初体验,OrangePi AIpro开发板使用测评