当前位置: 首页 > news >正文

Mindspore框架DCGAN模型实现漫画头像生成|(二)DCGAN模型构建

Mindspore框架DCGAN模型实现漫画头像生成

  1. Mindspore框架DCGAN模型实现漫画头像生成|(一)漫画头像数据集准备
  2. Mindspore框架DCGAN模型实现漫画头像生成|(二)DCGAN模型构建
  3. Mindspore框架DCGAN模型实现漫画头像生成|(三)DCGAN模型训练和推理
  4. Mindspore框架DCGAN模型实现漫画头像生成|(四)应用程序生成实践

Mindspore框架DCGAN模型实现漫画头像生成|(二)DCGAN模型构建

DCGAN,全称是 Deep Convolution Generative Adversarial Networks,深度卷积生成对抗网络

1. DCGAN模型特点

  • make GAN + CNN more stable and deeper,能够产生更高分辨率的图像;
  • 全卷积网络(all convolutional net):用步幅卷积(strided convolutions)替代确定性空间池化函数(deterministic spatial pooling functions)(比如最大池化),让网络自己学习downsampling方式。作者对 generator 和 discriminator 都采用了这种方法。
  • 取消全连接层:使用 全局平均池化(global average pooling)替代 fully connected layer。global average pooling会降低收敛速度,但是可以提高模型的稳定性。GAN的输入采用均匀分布初始化,可能会使用全连接层(矩阵相乘),然后得到的结果可以reshape成一个4 dimension的tensor,然后后面堆叠卷积层即可;对于鉴别器,最后的卷积层可以先flatten,然后送入一个sigmoid分类器。
  • 批归一化(Batch Normalization):BN 被证明是深度学习中非常重要的 加速收敛 和 减缓过拟合 的手段。这样有助于解决 poor initialization 问题并帮助梯度流向更深的网络。防止G把所有rand input都折叠到一个点,同时防止样本震荡和模型的不稳定,只对生成器(G)的输出层和鉴别器(D)的输入层使用BN。
  • Leaky Relu 激活函数: 生成器(G),输出层使用tanh 激活函数,其余层使用relu 激活函数。鉴别器(D),都采用leaky rectified activation。
  • DCGAN生成器G的结构如下:
    在这里插入图片描述

2. 构造网络:生成器G

生成器G的功能是将隐向量z映射到数据空间。由于数据是图像,这一过程也会创建与真实图像大小相同的 RGB 图像。

import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)# 通过输入部分中设置的nz、ngf和nc来影响代码中的生成器结构。
class Generator(nn.Cell):"""DCGAN网络生成器"""def __init__(self):super(Generator, self).__init__()self.generator = nn.SequentialCell(nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),nn.Tanh())def construct(self, x):return self.generator(x)generator = Generator()

注意:nz是隐向量z的长度,ngf与通过生成器传播的特征图的大小有关,nc是输出图像中的通道数
在这里插入图片描述

2. 构造网络:判别器D

判别器D是一个二分类网络模型,输出判定该图像为真实图的概率。形如:
在这里插入图片描述

通过一系列的Conv2d、BatchNorm2d和LeakyReLU层对其进行处理,最后通过Sigmoid激活函数得到最终概率。

class Discriminator(nn.Cell):"""DCGAN网络判别器"""def __init__(self):super(Discriminator, self).__init__()self.discriminator = nn.SequentialCell(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),)self.adv_layer = nn.Sigmoid()def construct(self, x):out = self.discriminator(x)out = out.reshape(out.shape[0], -1)return self.adv_layer(out)discriminator = Discriminator()

模型结构输出:
在这里插入图片描述

http://www.lryc.cn/news/410885.html

相关文章:

  • mongo-csharp-driver:MongoDB官方的C#客户端驱动程序!
  • 网络流量分析>>pcapng文件快速分析有用价值解析
  • 【大模型系列篇】Vanna-ai基于检索增强(RAG)的sql生成框架
  • 【Nacos安装】
  • js、ts、argular、nodejs学习心得
  • 【Unity】RPG2D龙城纷争(十八)平衡模拟器
  • java.lang.IllegalStateException: Duplicate key InventoryDetailDO
  • Python使用selenium访问网页完成登录——装饰器重试机制汇总
  • “微软蓝屏”事件引发的深度思考:网络安全与系统稳定性的挑战与应对
  • 2024.07纪念一 debezium : spring-boot结合debezium
  • mysql怎么查询json里面的字段
  • C++ 右值 左值引用
  • 「JavaEE」Spring IoC 1:Bean 的存储
  • springBoot快速搭建WebSocket
  • 掌控授权的艺术:Laravel自定义策略模式深度解析
  • Git操作指令(随时更新)
  • SpringSecurity自定义登录方式
  • 黑神话悟空是什么游戏 黑神话悟空配置要求 黑神话悟空好玩吗值得买吗 黑神话悟空苹果电脑可以玩吗
  • 深入浅出消息队列----【延迟消息的实现原理】
  • npm提示 certificate has expired 证书已过期 已解决
  • KEIL如何封装文件成lib
  • 【python】OpenCV—Faster Video File FPS
  • JavaScript变量的类型转换
  • 如何申请免费SSL证书以消除访问网站显示连接不安全提醒
  • 关于P2P(点对点)
  • 前端怎么本地起一个服务查看本地文件
  • 建造者模式(Builder Pattern)
  • 【MySQL】索引 【下】{聚簇索引VS非聚簇索引/创建主键索引/全文索引的创建/索引创建原则}
  • 论文快过(图像配准|Coarse_LoFTR_TRT)|适用于移动端的LoFTR算法的改进分析 1060显卡上45fps
  • 免费发送邮件两种接口方式:SMTP和邮件API