当前位置: 首页 > news >正文

【PyTorch】生成对抗网络

生成对抗网络是什么

概念

Generative Adversarial Nets,简称GAN
GAN:生成对抗网络 —— 一种可以生成特定分布数据的模型
《Generative Adversarial Nets》 Ian J Goodfellow-2014

GAN网络结构

Recent Progress on Generative Adversarial Networks (GANs): A Survey
在这里插入图片描述

How Generative Adversarial Networks and Their Variants Work: An Overview
在这里插入图片描述

Generative Adversarial Networks_ A Survey and Taxonomy

在这里插入图片描述

GAN的训练

训练目的

  1. 对于D:对真样本输出高概率
  2. 对于G:输出使D会给出高概率的数据

GAN 的训练和监督学习训练模式的差异

在监督学习的训练模式中,训练数经过模型得到输出值,然后使用损失函数计算输出值与标签之间的差异,根据差异值进行反向传播,更新模型的参数,如下图所示。
在这里插入图片描述
在 GAN 的训练模式中,Generator 接收随机数得到输出值,目标是让输出值的分布与训练数据的分布接近,但是这里不是使用人为定义的损失函数来计算输出值与训练数据分布之间的差异,而是使用 Discriminator 来计算这个差异。需要注意的是这个差异不是单个数字上的差异,而是分布上的差异。如下图所示。
在这里插入图片描述

具体训练过程

step1:训练D
输入:真实数据加G生成的假数据
输出:二分类概率

step2:训练G
输入:随机噪声z
输出:分类概率——D(G(z))

在这里插入图片描述

DCGAN

Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks
在这里插入图片描述

Discriminator:卷积结构的模型
Generator:卷积结构的模型

DCGAN 的定义如下:

from collections import OrderedDict
import torch
import torch.nn as nnclass Generator(nn.Module):def __init__(self, nz=100, ngf=128, nc=3):super(Generator, self).__init__()self.main = nn.Sequential(# input is Z, going into a convolutionnn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),nn.BatchNorm2d(ngf * 8),nn.ReLU(True),# state size. (ngf*8) x 4 x 4nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 4),nn.ReLU(True),# state size. (ngf*4) x 8 x 8nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 2),nn.ReLU(True),# state size. (ngf*2) x 16 x 16nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf),nn.ReLU(True),# state size. (ngf) x 32 x 32nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),nn.Tanh()# state size. (nc) x 64 x 64)def forward(self, input):return self.main(input)def initialize_weights(self, w_mean=0., w_std=0.02, b_mean=1, b_std=0.02):for m in self.modules():classname = m.__class__.__name__if classname.find('Conv') != -1:nn.init.normal_(m.weight.data, w_mean, w_std)elif classname.find('BatchNorm') != -1:nn.init.normal_(m.weight.data, b_mean, b_std)nn.init.constant_(m.bias.data, 0)class Discriminator(nn.Module):def __init__(self, nc=3, ndf=128):super(Discriminator, self).__init__()self.main = nn.Sequential(# input is (nc) x 64 x 64nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf) x 32 x 32nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 2),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*2) x 16 x 16nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 4),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*4) x 8 x 8nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 8),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*8) x 4 x 4nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),nn.Sigmoid())def forward(self, input):return self.main(input)def initialize_weights(self, w_mean=0., w_std=0.02, b_mean=1, b_std=0.02):for m in self.modules():classname = m.__class__.__name__if classname.find('Conv') != -1:nn.init.normal_(m.weight.data, w_mean, w_std)elif classname.find('BatchNorm') != -1:nn.init.normal_(m.weight.data, b_mean, b_std)nn.init.constant_(m.bias.data, 0)
http://www.lryc.cn/news/452397.html

相关文章:

  • Vue3轻松实现前端打印功能
  • SHA-1 是一种不可逆的、固定长度的哈希函数,在 Git 等场景用于生成唯一的标识符来管理对象和数据完整性
  • Activiti7 工作流引擎学习
  • pytorch使用LSTM模型进行股票预测
  • 掌握 C# 异常处理机制
  • 【Redis】Redis Cluster 简单介绍
  • 【EXCEL数据处理】000010 案列 EXCEL文本型和常规型转换。使用的软件是微软的Excel操作的。处理数据的目的是让数据更直观的显示出来,方便查看。
  • golang grpc进阶
  • Java JUC(三) AQS与同步工具详解
  • 使用rust写一个Web服务器——async-std版本
  • C语言复习概要(一)
  • 二、kafka生产与消费全流程
  • 本地搭建OnlyOffice在线文档编辑器结合内网穿透实现远程协作
  • ScrapeGraphAI 大模型增强的网络爬虫
  • PDF转换为TIF,JPG的一个简易工具(含下载链接)
  • Wireshark 解析QQ、微信的通信协议|TCP|UDP
  • 网络编程(5)——模拟伪闭包实现连接的安全回收
  • C#绘制动态曲线
  • 用Python实现运筹学——Day 10: 线性规划的计算机求解
  • [C++]使用C++部署yolov11目标检测的tensorrt模型支持图片视频推理windows测试通过
  • 霍夫曼树及其与B树和决策树的异同
  • CompletableFuture常用方法
  • 本地化测试对游戏漏洞修复的影响
  • 使用rust实现rtsp码流截图
  • Cpp::STL—string类的模拟实现(12)
  • 一文搞懂SentencePiece的使用
  • 一个简单的摄像头应用程序1
  • 通过PHP获取商品详情
  • 【Android】获取备案所需的公钥以及签名MD5值
  • 看480p、720p、1080p、2k、4k、视频一般需要多大带宽呢?