当前位置: 首页 > article >正文

生成对抗网络(GAN)原理

生成对抗网络(GAN)原理

  • 介绍
    • 示例代码
    • 一、GAN 的基本结构
      • 1. 生成器(Generator,记作 G)
      • 2. 判别器(Discriminator,记作 D)
    • 二、对抗过程(博弈思想)
    • 三、训练过程
    • 四、存在的问题与改进方向
      • 1. 模式崩溃(Mode Collapse)
      • 2. 训练不稳定
      • 3. 衡量指标困难
    • 五、GAN 的改进与变种
    • 六、GAN 的应用
    • ✅ DCGAN 与普通 GAN 的区别
    • ✅ DCGAN 示例(基于 MNIST)
      • 🔧 安装依赖
      • 🧠 DCGAN 架构代码(Generator + Discriminator)
    • 🧪 使用说明

介绍

示例代码

生成对抗网络(Generative Adversarial Network,GAN)是由 Ian Goodfellow 等人在 2014 年提出的一种深度生成模型。它通过两个神经网络之间的博弈(对抗)过程,学习数据的生成分布,从而生成以假乱真的数据(如图像、语音等)。GAN 是近年来生成模型领域的重要突破,广泛应用于图像生成、风格迁移、图像修复等任务中。


一、GAN 的基本结构

GAN 主要由两个部分组成:

1. 生成器(Generator,记作 G)

  • 目标:生成尽可能真实的数据,欺骗判别器。
  • 输入:随机噪声向量(一般从正态分布或均匀分布中采样)
  • 输出:“伪造”的样本,尽可能与真实样本相似。

2. 判别器(Discriminator,记作 D)

  • 目标:判断输入数据是真实的样本还是生成器生成的伪造样本。
  • 输入:真实样本或生成样本
  • 输出:一个概率值,表示输入是“真实”的概率。

二、对抗过程(博弈思想)

GAN 的训练过程是一个零和博弈(min-max game):

  • 生成器试图最小化判别器对生成样本的识别能力;
  • 判别器试图最大化识别真实样本与生成样本的能力。

这个过程可以表示为一个最优化问题:

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log (1 - D(G(z)))] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

其中:

  • p d a t a ( x ) p_{data}(x) pdata(x) 是真实数据的分布;
  • p z ( z ) p_z(z) pz(z) 是生成器输入噪声的分布(如高斯分布);
  • D ( x ) D(x) D(x) 是判别器输出 x 是真实数据的概率;
  • G ( z ) G(z) G(z) 是生成器输出的伪造样本。

三、训练过程

  1. 固定生成器 G,训练判别器 D

    • 给 D 一部分真实样本(标签为 1);
    • 给 D 一部分 G 生成的样本(标签为 0);
    • 通过交叉熵损失训练 D,使其能区分真假样本。
  2. 固定判别器 D,训练生成器 G

    • 通过 G 生成假样本;
    • D 会判断其为假;
    • G 的目标是欺骗 D,即最大化 D ( G ( z ) ) D(G(z)) D(G(z)),让 D 判错;
    • 通常优化的是 log ⁡ D ( G ( z ) ) \log D(G(z)) logD(G(z)) 的反函数,例如 log ⁡ ( 1 − D ( G ( z ) ) ) \log(1 - D(G(z))) log(1D(G(z))) 或更稳定的变体(如使用 feature matching 或 Wasserstein loss)。
  3. 交替训练 D 和 G,直到生成器生成的样本无法被判别器区分为假(判别器输出接近 0.5)。


四、存在的问题与改进方向

1. 模式崩溃(Mode Collapse)

生成器只学会生成一小部分模式样本,导致多样性丢失。

2. 训练不稳定

D 和 G 的能力不均衡、学习率不合适等因素可能导致 GAN 训练震荡或失败。

3. 衡量指标困难

GAN 的损失函数不能很好地反映生成质量,因此通常使用 FID、IS 等指标辅助评估。


五、GAN 的改进与变种

为了克服原始 GAN 的不足,研究人员提出了许多变种:

名称简介
DCGAN使用卷积神经网络的 GAN,适合图像数据生成
WGANWasserstein GAN,引入 Wasserstein 距离,解决训练不稳定问题
WGAN-GP在 WGAN 基础上加上梯度惩罚项,提高训练稳定性
CGAN条件 GAN,可以控制生成样本的类别(如生成特定数字)
CycleGAN用于图像风格转换(如马<->斑马、夏天<->冬天)
StyleGAN高质量人脸图像生成的里程碑,支持精细控制生成风格
BigGAN更大规模、更高分辨率的图像生成模型

六、GAN 的应用

  1. 图像生成(人脸、卡通、合成图像)
  2. 图像风格转换(如照片转油画)
  3. 图像修复与超分辨率重建
  4. 数据增强(医学图像、遥感等领域)
  5. 音频合成、文本生成(结合 RNN、Transformer)
  6. 深度伪造(Deepfake)

下面是一个使用 PyTorch 实现的 DCGAN(Deep Convolutional GAN)完整示例,用于在 MNISTCIFAR-10 数据集上生成图像。与普通 GAN 相比,DCGAN 使用卷积神经网络结构,更适合图像处理,生成效果更清晰、真实。


✅ DCGAN 与普通 GAN 的区别

模块普通 GANDCGAN
生成器结构全连接层(MLP)转置卷积(ConvTranspose2d)
判别器结构全连接层卷积层(Conv2d)
激活函数ReLU + SigmoidLeakyReLU + Sigmoid / No sigmoid
正则化无 / BatchNorm1dBatchNorm2d

✅ DCGAN 示例(基于 MNIST)

🔧 安装依赖

pip install torch torchvision matplotlib

🧠 DCGAN 架构代码(Generator + Discriminator)

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import os# 超参数
img_size = 64
channels = 1  # 改为 3 可用于 CIFAR-10
latent_dim = 100
batch_size = 128
lr = 0.0002
epochs = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 创建目录
os.makedirs("dcgan_images", exist_ok=True)# 数据预处理(MNIST 被 resize 成 64x64 )
transform = transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),transforms.Normalize([0.5], [0.5])
])
dataloader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True, transform=transform),batch_size=batch_size, shuffle=True
)# 生成器(使用转置卷积)
class Generator(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(# 输入是 latent_dim 向量,输出 1024nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),nn.BatchNorm2d(512),nn.ReLU(True),nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),nn.BatchNorm2d(256),nn.ReLU(True),nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),nn.BatchNorm2d(128),nn.ReLU(True),nn.ConvTranspose2d(128, channels, 4, 2, 1, bias=False),nn.Tanh())def forward(self, z):z = z.view(z.size(0), latent_dim, 1, 1)return self.model(z)# 判别器(使用卷积)
class Discriminator(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Conv2d(channels, 128, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(128, 256, 4, 2, 1, bias=False),nn.BatchNorm2d(256),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(256, 512, 4, 2, 1, bias=False),nn.BatchNorm2d(512),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(512, 1, 4, 1, 0, bias=False),nn.Sigmoid())def forward(self, x):return self.model(x).view(-1, 1).squeeze(1)# 初始化模型
G = Generator().to(device)
D = Discriminator().to(device)# 损失和优化器
criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))# 训练 DCGAN
for epoch in range(epochs):for i, (real_imgs, _) in enumerate(dataloader):real_imgs = real_imgs.to(device)b_size = real_imgs.size(0)# 标签valid = torch.ones(b_size, device=device)fake = torch.zeros(b_size, device=device)# ========== 训练判别器 ==========optimizer_D.zero_grad()real_loss = criterion(D(real_imgs), valid)z = torch.randn(b_size, latent_dim, device=device)gen_imgs = G(z)fake_loss = criterion(D(gen_imgs.detach()), fake)d_loss = real_loss + fake_lossd_loss.backward()optimizer_D.step()# ========== 训练生成器 ==========optimizer_G.zero_grad()g_loss = criterion(D(gen_imgs), valid)g_loss.backward()optimizer_G.step()if i % 100 == 0:print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] "f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")# 保存生成图像with torch.no_grad():z = torch.randn(64, latent_dim, device=device)gen_imgs = G(z)grid = make_grid(gen_imgs, nrow=8, normalize=True)save_image(grid, f"dcgan_images/{epoch:03d}.png")print("DCGAN 训练完成,图像保存在 dcgan_images 文件夹中。")

🧪 使用说明

  • 若想改用 彩色图像(如 CIFAR-10),需:

    • channels = 3
    • 使用 datasets.CIFAR10 替代 MNIST
    • 修改 transforms.Normalize([0.5]*3, [0.5]*3)
http://www.lryc.cn/news/2383655.html

相关文章:

  • 【SpringBoot实战指南】使用 Spring Cache
  • centos8 配置网桥,并禁止kvm默认网桥
  • C++:list容器,deque容器
  • 【Node.js】全栈开发实践
  • 自定义类型-联合体
  • Qt项目开发中所遇
  • ubuntu sh安装包的安装方式
  • Redis语法大全
  • OpenAI宣布:核心API支持MCP,助力智能体开发
  • 我的爬虫夜未眠:一场与IP限流的攻防战
  • git:The following paths are ignored by one of your
  • 算法--js--组合总和
  • 微服务中的 AKF 拆分原则:构建可扩展系统的核心方法论
  • vue element-plus 集成多语言
  • 如何测试JWT的安全性:全面防御JSON Web Token的安全漏洞
  • 车载网关策略 --- 车载网关重置前的请求转发机制
  • EtpBot:安卓自动化脚本开发神器
  • 连锁企业管理系统对门店运营的促进作用
  • 现代生活健康养生新策略
  • 车载以太网网络测试-27【SOME/IP-SD简述】
  • 云南安全员考试报名需要具备哪些条件?
  • Android Binder线程池饥饿与TransactionException:从零到企业级解决方案(含实战代码+调试技巧)
  • FFmpeg 超级详细安装与配置教程(Windows 系统)
  • 【Redis8】最新安装版与手动运行版
  • PyQt 探索QMainWindow:打造专业的PyQt5主窗
  • Spring Boot 集成 Elasticsearch【实战】
  • 06算法学习_58. 区间和
  • 如何在Java中进行PDF合并
  • Python爬虫之路(14)--playwright浏览器自动化
  • Python开启智能之眼:OpenCV+深度学习实战