当前位置: 首页 > news >正文

生成模型实战 | InfoGAN详解与实现

生成模型实战 | InfoGAN详解与实现

    • 0. 前言
    • 1. InfoGAN 原理
      • 1.1 核心思想
      • 1.2 目标函数
    • 2. 实现细节
    • 3. 构建 InfoGAN
      • 3.1 模型构建
      • 3.2 模型训练

0. 前言

在传统的生成对抗网络 (Generative Adversarial Network, GAN) 中,生成器接收一个随机噪声向量作为输入并生成样本,但这个噪声向度的不同维度往往缺乏明确的语义解释。InfoGAN 提出了一种信息理论扩展,能够以无监督的方式学习解耦的、可解释的表示。这种解耦表示对于理解数据生成过程、控制生成样本的特征具有重要意义。
InfoGAN 通过最大化潜编码与生成样本之间的互信息,迫使生成器使用潜编码中的特定维度来控制生成样本的特定语义特征。这种方法不需要任何额外的监督信号,完全通过无监督学习实现特征解耦。本节将详细介绍 InfoGAN 的技术原理,并使用 PyTorch 从零开始实现 InfoGAN

1. InfoGAN 原理

1.1 核心思想

InfoGAN (Information Maximizing Generative Adversarial Network) 是生成对抗网络 (Generative Adversarial Network, GAN) 的一种变体,旨在学习可解释且解耦的潜表示。传统 GAN 的潜空间是高度纠缠的,难以控制生成结果的特定属性。InfoGAN 通过引入互信息最大化的思想,将潜编码分解为两部分:

  • 不可压缩噪声 zzz:用于捕捉数据中的随机变化
  • 潜编码 ccc:用于表示数据中具有语义意义的可解释变量
    • 例如:c=[ccat,ccont]c=[c_{cat},c_{cont}]c=[ccat,ccont],其中 ccatc_{cat}ccat 是离散的使用独热编码的类别编码,ccontc_{cont}ccont 是实值连续编码(例如粗细、倾斜度)

1.2 目标函数

我们希望生成器 G(z,c)G(z,c)G(z,c) 的输出与 ccc 具有较高互信息 I(c;G(z,c))I(c;G(z,c))I(c;G(z,c))。这意味着:给定生成图像 xxx,可以较精确地恢复 ccc,说明 ccc 控制了图像的某些可解释属性。为了强制编码的解耦,InfoGAN 提出了一种针对原始损失函数的正则化函数,该函数将潜在编码 cccG(z,c)G(z,c)G(z,c) 之间的互信息最大化:
I(c;G(z,c))=IG(c;z)I(c;G(z,c))=IG(c;z) I(c;G(z,c))=IG(c;z)
正则化器强制生成器考虑潜编码。在信息论领域,潜编码 cccG(z,c)G(z,c)G(z,c) 之间的互信息定义为:
I(G(c;z)=H(c)−H(c∣G(z,c))I(G(c;z)=H(c)-H(c|G(z,c)) I(G(c;z)=H(c)H(cG(z,c))
其中 H(c)H(c)H(c) 是潜编码 ccc 的熵,而 H(c∣G(z,c))H(c|G(z,c))H(cG(z,c)) 是得到生成器的输出 G(z,c)G(z,c)G(z,c)ccc 的条件熵。最大化互信息意味着在生成得到生成的输出时将 H(c∣G(z,c))H(c|G(z,c))H(cG(z,c)) 最小化或减小潜编码中的不确定性。但是由于估计 H(c∣G(z,c))H(c|G(z,c))H(cG(z,c)) 需要后验分布 p(c∣G(z,c))=p(c∣x)p(c|G(z,c))=p(c|x)p(cG(z,c))=p(cx),因此难以估算 H(c∣G(z,c))H(c|G(z,c))H(cG(z,c))。为了解决这一问题,InfoGAN 引入辅助网络 Q(c∣x)Q(c|x)Q(cx) (近似后验)来给出互信息下界 (variational lower bound):
I(c;G(z,c))≥LI(G,Q)=Ec∼p(c),x∼G(z,c)[logQ(c∣x)]+H(c)I(c;G(z,c)) \ge L_I(G,Q)=E_{c \sim p(c),x \sim G(z,c)}[logQ(c|x)]+H(c) I(c;G(z,c))LI(G,Q)=Ecp(c),xG(z,c)[logQ(cx)]+H(c)
于是我们可以加入最大化该下界的损失(或在最小化框架中加入负对数项)。

2. 实现细节

InfoGAN 架构如下图所示:

InfoGAN架构

InfoGAN 包含三个核心组件:

  • 生成器 G:将噪声 z 和潜编码 c 映射到数据空间
  • 判别器 D:区分真实样本和生成样本(也称伪造样本)
  • 辅助网络 Q:预测潜编码 ccc (通常与 D 共享大部分层)

潜编码 ccc 通常包含:

  • ccatc_{cat}ccat​:离散类别(例如 10 类),用 one-hot 表示。用交叉熵计算判别器预测与真实 ccatc_{cat}ccat​ 的损失
  • ccontc_{cont}ccont​:连续变量(例如 2 维),通常从均匀分布 (-1,1) 采样,InfoGAN 论文中用高斯似然,可以简化为 MSE 来估计(当假设固定方差时)

鉴别器损失函数:
L(D)=−Ex∼pdatalogD(x)−Ez,clog[1−D(G(z,c))]−λI(c;G(z,c))\mathcal L^{(D)} = -\mathbb E_{x\sim p_{data}}logD(x)-\mathbb E_{z,c}log[1 − D(G(z,c))]-\lambda I(c;G(z,c))L(D)=ExpdatalogD(x)Ez,clog[1D(G(z,c))]λI(c;G(z,c))
生成器损失函数:
L(G)=−Ez,clogD(G(z,c))−λI(c;G(z,c))\mathcal L^{(G)} = -\mathbb E_{z,c}logD(G(z,c))-\lambda I(c;G(z,c))L(G)=Ez,clogD(G(z,c))λI(c;G(z,c))
其中 Info 损失权重 λλλ 用于平衡对抗损失与互信息损失。

3. 构建 InfoGAN

接下来,我们将使用 MNIST 数据集训练 InfoGAN

3.1 模型构建

(1) 导入依赖库并设置超参数与设备:

import os
import random
import math
import argparse
from tqdm import tqdmimport torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms, datasets, utilsimport matplotlib.pyplot as pltdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)out_dir = "outputs_infogan"
os.makedirs(out_dir, exist_ok=True)

(2) 定义辅助函数 weights_init_normal() 用于对 Conv/Linear 层进行正态初始化,有助于训练稳定性:

def weights_init_normal(m):classname = m.__class__.__name__if classname.find('Conv') != -1 or classname.find('Linear') != -1:nn.init.normal_(m.weight.data, 0.0, 0.02)if hasattr(m, 'bias') and m.bias is not None:nn.init.constant_(m.bias.data, 0)

(3) 定义生成器,输入大小为 (z + c_cat_onehot + c_cont) 的向量,输出 1x28x28 的图像:

class Generator(nn.Module):def __init__(self, input_dim=74, img_channels=1):super().__init__()# 对于 MNIST,输入向量维度可以为 62 (z) + 10 (cat) + 2 (cont) = 74self.net = nn.Sequential(# project and reshapenn.Linear(input_dim, 128*7*7),nn.BatchNorm1d(128*7*7),nn.ReLU(True),# reshape to (128, 7, 7)Reshape((-1, 128, 7, 7)),# upsample to 14x14nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # out 14x14nn.BatchNorm2d(64),nn.ReLU(True),# upsample to 28x28nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),   # out 28x28nn.BatchNorm2d(32),nn.ReLU(True),nn.Conv2d(32, img_channels, kernel_size=3, stride=1, padding=1),nn.Tanh()  # outputs in [-1, 1])def forward(self, x):return self.net(x)class Reshape(nn.Module):def __init__(self, shape):super().__init__()self.shape = shape  # e.g. (-1, C, H, W)def forward(self, x):return x.view(*self.shape)

(4) 定义判别器,Discriminator_Q 同时负责判别(真假)和 Q 的参数预测(共享卷积特征),判别输出为 1logit (未做 sigmoid),用于 BCEWithLogitsLossQ 输出分类 logits (用于交叉熵)与连续变量均值(用于 MSE):

class Discriminator_Q(nn.Module):def __init__(self, img_channels=1, cat_classes=10, cont_dim=2):super().__init__()# 基础的判别器卷积层,用于提取特征self.conv = nn.Sequential(nn.Conv2d(img_channels, 32, 4, 2, 1),  # 14x14nn.LeakyReLU(0.1, inplace=True),nn.Conv2d(32, 64, 4, 2, 1),  # 7x7nn.BatchNorm2d(64),nn.LeakyReLU(0.1, inplace=True),nn.Conv2d(64, 128, 3, 1, 1), # 7x7nn.BatchNorm2d(128),nn.LeakyReLU(0.1, inplace=True),)# 判别器输出(真假)self.disc_head = nn.Sequential(nn.Flatten(),nn.Linear(128*7*7, 1024),nn.LeakyReLU(0.1, inplace=True),nn.Linear(1024, 1)  # 输出 logits)# Q head:预测 c 的后验参数(共享上面 conv 的特征)self.q_shared = nn.Sequential(nn.Flatten(),nn.Linear(128*7*7, 128),nn.LeakyReLU(0.1, inplace=True))# Q 的两个输出:分类 logits 和 连续变量的均值(我们假设方差为常数)self.q_cat = nn.Linear(128, cat_classes)     # 分类 logitsself.q_cont = nn.Linear(128, cont_dim)       # 连续变量均值预测def forward(self, x):features = self.conv(x)disc_logits = self.disc_head(features)q_feat = self.q_shared(features)cat_logits = self.q_cat(q_feat)cont_mu = self.q_cont(q_feat)return disc_logits, cat_logits, cont_mu

3.2 模型训练

(1) 准备 MNIST 数据加载器:

batch_size = 128
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # scale to [-1,1]
])
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

(2) 实例化模型,初始化参数、设置损失函数与优化器:

nz = 62  # 噪声维度
cat_classes = 10
cont_dim = 2
gen_input_dim = nz + cat_classes + cont_dimG = Generator(input_dim=gen_input_dim).to(device)
D_Q = Discriminator_Q(cat_classes=cat_classes, cont_dim=cont_dim).to(device)G.apply(weights_init_normal)
D_Q.apply(weights_init_normal)# 损失与优化器
bce_loss = nn.BCEWithLogitsLoss()
ce_loss = nn.CrossEntropyLoss()  # for categorical code
mse_loss = nn.MSELoss()          # for continuous code (regression to true cont values)lr = 2e-4
beta1, beta2 = 0.5, 0.999optimizerD = optim.Adam(D_Q.parameters(), lr=lr, betas=(beta1, beta2))
q_params = list(D_Q.q_shared.parameters()) + list(D_Q.q_cat.parameters()) + list(D_Q.q_cont.parameters())
optimizerG = optim.Adam(list(G.parameters()) + q_params, lr=lr, betas=(beta1, beta2))

(3) 定义辅助函数,用于采样 zzzccatc_{cat}ccatccontc_{cont}ccont

def sample_noise(batch_size, nz):return torch.randn(batch_size, nz, device=device)def sample_categorical(batch_size, cat_classes):# 返回 index 与 one-hot tensorsidx = torch.randint(0, cat_classes, (batch_size,), device=device)one_hot = torch.zeros(batch_size, cat_classes, device=device)one_hot[torch.arange(batch_size), idx] = 1.0return idx, one_hotdef sample_continuous(batch_size, cont_dim):# 从均匀分布 U(-1, 1)return torch.rand(batch_size, cont_dim, device=device) * 2 - 1# 可视化保存函数
def save_image_grid(x, path, nrow=10):# x expected in [-1,1]utils.save_image((x+1)/2.0, path, nrow=nrow)  # convert to [0,1] for saving

(4) 运行训练循环,训练过程先更新 D (让真实图像为 1,生成图像为 0),再更新 G+Q (让生成图像骗过 D (标签为 1),同时最小化 Info 损失使 Q 能正确预测潜编码:

num_epochs = 100
lambda_info = 1.0  # Info loss 权重fixed_z = torch.randn(100, nz, device=device)  # 用于生成固定输出做演示
fixed_cat = torch.zeros(100, cat_classes, device=device)
for i in range(10):fixed_cat[i*10:(i+1)*10, i] = 1.0
fixed_cont = torch.zeros(100, cont_dim, device=device)
fixed_input = torch.cat([fixed_z, fixed_cat, fixed_cont], dim=1)iters = 0
for epoch in range(1, num_epochs+1):pbar = tqdm(train_loader)for real_imgs, _ in pbar:batch_size_cur = real_imgs.size(0)real_imgs = real_imgs.to(device)# 1. 更新 D(判别器)D_Q.zero_grad()# 真实图像real_labels = torch.ones(batch_size_cur, 1, device=device)fake_labels = torch.zeros(batch_size_cur, 1, device=device)disc_real_logits, _, _ = D_Q(real_imgs)lossD_real = bce_loss(disc_real_logits, real_labels)# 生成伪造图像z = sample_noise(batch_size_cur, nz)idx, c_cat_onehot = sample_categorical(batch_size_cur, cat_classes)c_cont = sample_continuous(batch_size_cur, cont_dim)gen_input = torch.cat([z, c_cat_onehot, c_cont], dim=1)fake_imgs = G(gen_input)disc_fake_logits, _, _ = D_Q(fake_imgs.detach())lossD_fake = bce_loss(disc_fake_logits, fake_labels)lossD = lossD_real + lossD_fakelossD.backward()optimizerD.step()# 2. 更新 G 和 Q (联合)G.zero_grad()# Q 的参数包含在 optimizerG 中,因此在 optimizerG.step() 时一并更新disc_fake_logits2, cat_logits, cont_mu = D_Q(fake_imgs)# 对抗目标:让判别器认为生成样本为真实(label=1)adv_loss = bce_loss(disc_fake_logits2, real_labels)# Info 损失:分类用交叉熵(给定真实 idx),连续用 MSE(目标是 c_cont)info_loss_cat = ce_loss(cat_logits, idx)  # idx 是类别标签info_loss_cont = mse_loss(cont_mu, c_cont)info_loss = info_loss_cat + info_loss_cont# 总损失(G 和 Q 的联合损失)gen_loss = adv_loss + lambda_info * info_lossgen_loss.backward()optimizerG.step()iters += 1if iters % 200 == 0:pbar.set_description(f"Epoch[{epoch}/{num_epochs}] D:{lossD.item():.4f} G:{gen_loss.item():.4f} Info:{info_loss.item():.4f}")# 每个 epoch 结束后,生成固定的图片以供观察with torch.no_grad():fake = G(fixed_input).cpu()save_image_grid(fake, os.path.join(out_dir, f"epoch_{epoch:03d}.png"), nrow=10)print(f"Saved sample for epoch {epoch}")

训练过程如下所示:

训练过程

查看训练过程生成的伪造图图像,可以看到随着训练的进行,生成的图像越来越逼真:

生成结果

(5) 训练完成后,生成用于验证“解耦”效果的可视化:

import numpy as npdef visualize_variation(G, nz, cat_classes, cont_dim, out_dir):G.eval()# A) 类别可视化:为每个类别生成 10 张图(z 随机,cont 设置为 0)n_per_class = 10z = torch.randn(cat_classes*n_per_class, nz, device=device)cat = torch.zeros(cat_classes*n_per_class, cat_classes, device=device)for i in range(cat_classes):cat[i*n_per_class:(i+1)*n_per_class, i] = 1.0cont = torch.zeros(cat_classes*n_per_class, cont_dim, device=device)with torch.no_grad():imgs = G(torch.cat([z, cat, cont], dim=1)).cpu()save_image_grid(imgs, os.path.join(out_dir, "viz_by_class.png"), nrow=n_per_class)# B) 连续因子变化:固定类别为 e.g. digit '7'(index 7),对 cont[0] 从 -2 -> 2 扫动chosen_class = 7steps = 11z_fixed = torch.randn(steps, nz, device=device)cat_fixed = torch.zeros(steps, cat_classes, device=device)cat_fixed[:, chosen_class] = 1.0cont_vals = torch.linspace(-2, 2, steps, device=device).unsqueeze(1)  # only vary cont dim 0cont_fixed = torch.zeros(steps, cont_dim, device=device)cont_fixed[:, 0] = cont_vals.squeeze()with torch.no_grad():imgs2 = G(torch.cat([z_fixed, cat_fixed, cont_fixed], dim=1)).cpu()save_image_grid(imgs2, os.path.join(out_dir, f"viz_cont0_class{chosen_class}.png"), nrow=steps)print("Saved visualization images in", out_dir)visualize_variation(G, nz, cat_classes, cont_dim, out_dir)

可视化 A:显示每个类别的生成样本,能初步检验 ccatc_{cat}ccat 是否控制数字类别:

可视化

可视化 B:固定类别并遍历连续因子 cont[0] 看图像如何发生可解释变化(例如笔画粗细、倾斜等):

可视化

http://www.lryc.cn/news/624012.html

相关文章:

  • 停车位 车辆
  • AI出题人给出的Java后端面经(十七)(日更)
  • 【URP】[法线贴图]为什么主要是蓝色的?
  • YoloV9改进策略:Block改进-DCAFE,并行双坐标注意力机制,增强长程依赖与抗噪性-即插即用
  • LangChain4j
  • Java 学习笔记(基础篇4)
  • C++零拷贝网络编程实战:从理论到生产环境的性能优化之路
  • JavaScript 性能优化实战:从评估到落地的全链路指南
  • SparkSQL性能优化实践指南
  • 第16节:自定义几何体 - 从顶点构建3D世界
  • 【FreeRTOS】刨根问底6: 应该如何防止任务栈溢出?
  • 【网络安全】Webshell的绕过——绕过动态检测引擎WAF-缓存绕过(Hash碰撞)
  • 什么是GD库?PHP中7大类64个GD库函数用法详解
  • 日语学习-日语知识点小记-进阶-JLPT-N1阶段蓝宝书,共120语法(3):21-30语法
  • 【AI论文】序曲(PRELUDE):一项旨在考察对长文本语境进行全局理解与推理能力的基准测试
  • PHP静态类self和static用法
  • 6-服务安全检测和防御技术
  • Tomcat Service 服务原理
  • Coin与Token的区别解析
  • java八股文-(spring cloud)微服务篇-参考回答
  • C语言基础:(十六)深入理解指针(6)
  • Centos 更新/修改宝塔版本
  • Rust 入门 生命周期(十八)
  • react echarts图表监听窗口变化window.addEventListener(‘resize’)与ResizeObserver()
  • 音乐创作魔法:解锁和弦与旋律的变化技巧
  • 3D打印——给开发板做外壳
  • 如何做HTTP优化
  • 【JAVA 核心编程】面向对象高级:类变量与方法 抽象类与接口
  • PowerPoint和WPS演示让多个对象通过动画同时出现
  • NY270NY273美光固态闪存NY277NY287