生成式人工智能实战 | WGAN(Wasserstein Generative Adversarial Network, GAN)
生成式人工智能实战 | WGAN
- 0. 前言
- 1. WGAN 与梯度惩罚
- 2. WGAN 工作原理
- 2.1 Wasserstein 损失
- 2.2 Lipschitz 约束
- 2.3 强制 Lipschitz 约束
- 3. 实现 WGAN
- 3.1 数据加载与处理
- 3.2 模型构建
- 3.3 模型训练
0. 前言
生成对抗网络 (Generative Adversarial Network, GAN) 模型训练过程通常会面临一些问题,如模式崩溃(生成器找到了一种能够有效欺骗判别器的输出类型,然后将其输出收缩到这些少数几种模式,忽视其他变化)、梯度消失和收敛速度慢等。Wasserstein GAN
(WGAN
) 引入了 Wasserstein
距离作为损失函数,提供了更平滑的梯度流和更稳定的训练,减轻了模式崩溃等问题。
除了 WGAN
,Progressive GAN
是另一种稳定训练的方法,通过将高分辨率图像生成这一复杂任务分解为逐步生成不同分辨率的图像,从低分辨率开始逐步增加图像的分辨率,并且每个阶段都只关注图像的一个特定细节层次,增强了 GAN
训练的稳定性,从而使学习过程更加可控和高效。
1. WGAN 与梯度惩罚
Wasserstein GAN
(WGAN
) 是一种用于提高 GAN
模型训练稳定性和性能的技术。常规生成对抗网络 (Generative Adversarial Network, GAN) 包含两个组件——生成器和判别器。生成器创建虚假数据,而判别器评估数据是真实还是虚假。训练涉及一个竞争性的零和博弈,其中生成器试图欺骗判别器,而判别器则试图准确区分真实和虚假的数据实例。
研究人员提出使用 Wasserstein
距离(衡量两个分布之间相似性度量)代替二元交叉熵作为损失函数,以通过梯度惩罚项来稳定训练。该技术提供了更平滑的梯度流,并减轻了模式崩溃等问题。WGAN
结构如下图所示,与真实和虚假图像相关的损失是 Wasserstein
损失,而不是常规的二元交叉熵损失。
此外,为了使 Wasserstein
距离能够正确工作,判别器(在 WGAN
中也称评论家 critic
)必须是 1-Lipschitz
连续的,这意味着评论家函数的梯度范数必须处处小于等于 1
。原始的 WGAN
论文提出了权重裁剪来强制执行 Lipschitz
约束。
为了解决权重裁剪的问题,在损失函数中添加梯度惩罚,以更有效地执行 Lipschitz
约束。为了实现带有梯度惩罚的 WGAN
,首先在真实数据点和生成数据点之间的直线上随机采样点。由于真实和虚假图像都有标签,插值图像也会附带一个标签,该标签是两个原始标签的插值值。然后,我们计算评论家输出相对于这些采样点的梯度。最后,将一个惩罚项加入损失函数,这个惩罚项与梯度范数偏离 1
的程度成正比,该惩罚项称为梯度惩罚。也就是说,WGAN
中的梯度惩罚是一种通过更有效地强制执行 Lipschitz
约束,改善训练稳定性和样本质量的技术,解决了原始 WGAN
模型的局限性。
2. WGAN 工作原理
2.1 Wasserstein 损失
首先我们来回顾一下二元交叉嫡, 在训练 DCGAN 判别器和生成器时采用了这种损失函数:
− 1 n ∑ i = 1 n ( y i l o g ( p i ) + ( 1 − y i ) l o g ( 1 − p i ) ) -\frac 1 n \sum_{i=1}^n(y_ilog(p_i)+(1-y_i)log(1-p_i)) −n1i=1∑n(yilog(pi)+(1−yi)log(1−pi))
为了训练 GAN
的判别器 D
,我们根据以下两者计算损失:真实图像的预测 p i = D ( x i ) p_i=D(x_i) pi=D(xi) 与标签 y i = 1 y_i=1 yi=1 之间的误差,以及生成图像的预测 p i = D ( G ( z i ) ) p_i=D(G(z_i)) pi=D(G(zi))与标签 y i = 0 y_i=0 yi=0 之间的误差。因此,对于 GAN
的判别器来说,损失函数最小化的过程可以表示为:
min D − ( E x ∼ p X [ log D ( x ) ] + E z ∼ p Z [ log ( 1 − D ( G ( z ) ) ) ] ) \mathop {\min} \limits_{D}-(\mathbb E_{x\sim p_X}[\log D(x)]+\mathbb E_{z\sim p_Z}[\log (1-D(G(z)))]) Dmin−(Ex∼pX[logD(x)]+Ez∼pZ[log(1−D(G(z)))])
为了训练 GAN
的生成器 G
,我们根据生成图像的预测 p i = D ( G ( z i ) ) p_i=D(G(z_i)) pi=D(G(zi)) 与标签 y i = 1 y_i=1 yi=1 的误差计算损失。因此,对于 GAN
的生成器来说,将损失函数最小化的过程可以表示为:
min G − ( E z ∼ p Z [ log D ( G ( z ) ) ] ) \mathop {\min}\limits_{G}-(\mathbb E_{z\sim p_Z}[\log D(G(z))]) Gmin−(Ez∼pZ[logD(G(z))])
接下来,我们比较上述损失函数与 Wasserstein
损失函数。
Wasserstein
损失 (Wasserstein Loss
) 是用于 Wasserstein GAN
(WGAN
) 的一种损失函数。与传统的二元交叉熵损失函数不同,Wasserstein
损失引入了标签 1
和 -1
,将判别器的输出从概率值转变为分数 (score
),因此,WGAN
的判别器通常也被称为评论家 (critic
),并要求判别器是 1-Lipschitz
连续函数。
具体来说,Wasserstein
损失使用标签 y i = 1 y_i=1 yi=1 和 y i = − 1 y_i=-1 yi=−1 代替 y i = 1 y_i=1 yi=1 和 y i = 0 y_i=0 yi=0,同时还需要移除判别器最后一层的 Sigmoid
激活函数,如此一来预测结果 p i p_i pi 就不一定在 [ 0 , 1 ] [0,1] [0,1] 范围内了,它可以是 [ − ∞ , ∞ ] [-∞,∞] [−∞,∞] 范围内的任何值。Wasserstein
损失的定义如下:
− 1 n ∑ i = 1 n ( y i p i ) -\frac 1 n∑_{i=1}^n(y_ip_i) −n1i=1∑n(yipi)
在训练 WGAN
的判别器 D
时,我们将计算以下损失:判别器对真实图像的预测 p i = D ( x i ) p_i=D(x_i) pi=D(xi) 与标签 y i = 1 y_i=1 yi=1 之间的误差,判别器对生成图像的预测 p i = D ( G ( z i ) ) p_i=D(G(z_i)) pi=D(G(zi)) 与标签 y i = − 1 y_i=-1 yi=−1 之间的误差。因此,对于 WGAN
判别器,最小化损失函数的过程可以表示为:
min D − ( E x ∼ p X [ D ( x ) ] − E z ∼ p Z [ D ( G ( z ) ) ] ) \mathop {\min}\limits_ D - (\mathbb E_{x\sim p_X}[D(x)] - \mathbb E_{z\sim p_Z}[D(G(z))]) Dmin−(Ex∼pX[D(x)]−Ez∼pZ[D(G(z))])
换句话说,WGAN
判别器试图最大化其对真实图像的预测和生成图像的预测之间的差异,且真实图像的得分更高。
而对于 WGAN
生成器 G
的训练,我们根据判别器对生成图像的预测 p i = D ( G ( z i ) ) p_i=D(G(z_i)) pi=D(G(zi)) 与标签 y i = 1 y_i=1 yi=1 计算损失。因此,对于 WGAN
生成器,最小化损失函数可以表示为:
min G − ( E z ∼ p Z [ D ( G ( z ) ) ] ) \mathop {\min}\limits_ G - (\mathbb E_{z\sim p_Z}[D(G(z))]) Gmin−(Ez∼pZ[D(G(z))])
换句话说,WGAN
生成器试图生成被判别器以极高分数判定为真实图像的图像(即,令判别器认为它们是真实的)。
2.2 Lipschitz 约束
由于我们允许判别器输出 [ − ∞ , ∞ ] [-∞,∞] [−∞,∞] 范围内的任意值,而不是按照 Sigmoid
函数那样将输出限制在 [ 0 , 1 ] [0,1] [0,1] 范围内,因此 Wasserstein
损失可能会非常大。因此,为了使 Wasserstein
损失函数正常工作,需要对判别器进行额外约束,即 1-Lipschitz
连续性约束。判别器是一个将图像转换为预测的函数 D
,如果对于任意两个输人图像 x 1 x_1 x1 和 x 2 x_2 x2,判别器函数 D
满足以下不等式,则该函数为 1-Lipschitz
连续:
∣ D ( x 1 ) − D ( x 2 ) ∣ ∣ x 1 − x 2 ∣ ≤ 1 \frac {|D(x_1) - D(x_2)|}{|x_1 - x_2|} ≤ 1 ∣x1−x2∣∣D(x1)−D(x2)∣≤1
其中, ∣ x 1 − x 2 ∣ |x_1 - x_2| ∣x1−x2∣ 表示两个图像的平均像素之差的绝对值, ∣ D ( x 1 ) − D ( x 2 ) ∣ |D(x_1) - D(x_2)| ∣D(x1)−D(x2)∣ 表示判别器预测之间的绝对值。这意味着判别器的预测变化速率在任何情况下都是有界的(即梯度的绝对值不能大于 1
)。可以在下图中的 Lipschitz
连续的一维函数中看到,无论将圆锥放在任何位置,曲线都不会进入圆锥内部。换句话说,曲线上任何一点的上升或下降速度都是有限的。
2.3 强制 Lipschitz 约束
在原始的 WGAN
论文中,作者通过在每个训练结束后将判别器的权重裁剪到一个较小范围内 [ − 0.01 , 0.01 ] [-0.01, 0.01] [−0.01,0.01] 来强制执行 Lipschitz
约束。
由于我们裁剪了判别器的权重,判别器的学习能力大大降低,因此,事实上,权重裁剪并不是一种理想的强制 Lipschitz
约束的方式。一个强大的判别器对于 WGAN
的成功至关重要,因为如果没有准确的梯度,生成器无法学习如何调整其权重以产生更好的样本。
因此,研究人员提出了许多其他方法来强制执行 Lipschitz
约束,并提高 WGAN
学习复杂特征的能力。其中一种方法是带有梯度惩罚 (Gradient Penalty
) 的 Wasserstein GAN
。
通过在判别器的损失函数中包含一个梯度惩罚项来直接强制执行 Lipschitz
约束,如果梯度范数偏离 1
时,该项会惩罚模型,从而使训练过程更加稳定。接下来,将这个额外的梯度惩罚项加入到判别器损失函数中。
3. 实现 WGAN
本节中,我们将继续使用 Celeb A 人脸图像数据集构建 WGAN
:
3.1 数据加载与处理
from torchvision import transforms
import torchvision.utils as vutils
import cv2, numpy as np
import torch
import os
from glob import glob
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from matplotlib import pyplot as plt
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"transform=transforms.Compose([transforms.Resize(64),transforms.CenterCrop(64),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])class Faces(Dataset):def __init__(self, folder):super().__init__()self.folder = folderself.images = sorted(glob(folder))def __len__(self):return len(self.images)def __getitem__(self, ix):image_path = self.images[ix]image = Image.open(image_path)image = transform(image)return imageds = Faces(folder='cropped_faces/*.jpg')
dataloader = DataLoader(ds, batch_size=64, shuffle=True, num_workers=8)
3.2 模型构建
(1) 在 WGAN
中,将判别器网络称为评论家 (Critic
)。评论家评估输入并给出一个介于 − ∞ −∞ −∞ 和 ∞ ∞ ∞ 之间的评分。评分越高,表示输入来自训练集的可能性越大(即越可能是真实数据):
class Critic(nn.Module):def __init__(self):super(Critic, self).__init__()self.conv1 = nn.Conv2d(channels, 32, 4, 2, 1) # 输入 64x64x3 图像,输出 32x32x32self.conv2 = nn.Conv2d(32, 64, 4, 2, 1) # 输出 16x16x64self.conv3 = nn.Conv2d(64, 128, 4, 2, 1) # 输出 8x8x128self.conv4 = nn.Conv2d(128, 256, 4, 2, 1) # 输出 4x4x256self.fc = nn.Linear(256 * 4 * 4, 1) # 最终输出一个标量def forward(self, x):x = F.leaky_relu(self.conv1(x), 0.2)x = F.leaky_relu(self.conv2(x), 0.2)x = F.leaky_relu(self.conv3(x), 0.2)x = F.leaky_relu(self.conv4(x), 0.2)x = x.view(x.size(0), -1) # 扁平化x = self.fc(x)return x
(2) 生成器的任务是创建数据实例,以便它们能被评论家评估并获得高分:
class Generator(nn.Module):def __init__(self, latent_dim):super(Generator, self).__init__()self.fc = nn.Linear(latent_dim, 256 * 4 * 4) # 隐藏层,输出 4x4x256 的特征图self.conv1 = nn.ConvTranspose2d(256, 128, 4, 2, 1) # 转置卷积层,输出 8x8x128self.conv2 = nn.ConvTranspose2d(128, 64, 4, 2, 1) # 转置卷积层,输出 16x16x64self.conv3 = nn.ConvTranspose2d(64, 32, 4, 2, 1) # 转置卷积层,输出 32x32x32self.conv4 = nn.ConvTranspose2d(32, channels, 4, 2, 1) # 输出 64x64x3 的图像def forward(self, z):x = F.relu(self.fc(z)).view(-1, 256, 4, 4) # 通过全连接层然后 reshapex = F.relu(self.conv1(x))x = F.relu(self.conv2(x))x = F.relu(self.conv3(x))x = torch.tanh(self.conv4(x)) # 使用tanh激活输出图像return x
(3) 评论家的损失函数包含三个部分:
critic_value(fake) − critic_value(real) + weight × GradientPenalty
第一项,critic_value(fake)
,表示如果图像是虚假的,评论家的目标是将其识别为虚假图像,并给出较低评分。第二项,− critic_value(real)
,表示如果图像是真的,评论家的目标是将其识别为真实图像,并给出较高评分。此外,评论家还希望最小化梯度惩罚项,weight × GradientPenalty
,其中 weight
是一个常量,用于确定对梯度范数偏离 1
的程度施加多少惩罚。梯度惩罚的计算方法如下所示:
def compute_gradient_penalty(Critic, real_samples, fake_samples):# 获取batch的大小batch_size = real_samples.size(0)# 在 [0, 1] 范围内进行插值epsilon = torch.rand(batch_size, 1, 1, 1).to(device)interpolated_images = epsilon * real_samples + (1 - epsilon) * fake_samplesinterpolated_images.requires_grad_(True)# 计算判别器输出d_interpolated = Critic(interpolated_images)# 计算梯度gradients = torch.autograd.grad(outputs=d_interpolated, inputs=interpolated_images,grad_outputs=torch.ones_like(d_interpolated).to(device),create_graph=True, retain_graph=True, only_inputs=True)[0]# 计算梯度的L2范数gradients = gradients.view(batch_size, -1)gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()return gradient_penalty
3.3 模型训练
(1) 定义超参数,并实例化生成器、评论家和优化器:
batch_size = 64
lr = 0.0001
beta1 = 0.5
beta2 = 0.999
epochs = 200
latent_dim = 100generator = Generator(latent_dim).to(device)
critic = Critic().to(device)opt_C = optim.RMSprop(critic.parameters(), lr=5e-5) # WGAN 推荐 RMSprop
opt_G = optim.RMSprop(generator.parameters(), lr=5e-5)
(2) 创建函数 save_generated_images
,定期观察生成的图像效果:
def save_generated_images(epoch, generator, latent_dim, save_dir='./generated_images/'):"""保存生成的图像并可视化"""z = torch.randn(32, latent_dim).to(device) # 生成16个随机噪声generated_images = generator(z) # 生成图像generated_images = generated_images.cpu().detach() # 转到CPU并分离图像张量# 将图像转换为适合显示的形式grid = torchvision.utils.make_grid(generated_images, nrow=16, normalize=True, padding=2)# 显示图像plt.figure(figsize=(8, 8))plt.imshow(grid.permute(1, 2, 0))plt.axis('off')plt.title(f"Generated Images at Epoch {epoch}")plt.show()# 保存图像if not os.path.exists(save_dir):os.makedirs(save_dir)plt.imsave(f"{save_dir}/generated_epoch_{epoch}.png", grid.permute(1, 2, 0).numpy())
(3) 训练模型 200
个 epoch
:
# 训练循环
for epoch in range(epochs):for i, real_images in enumerate(dataloader):real_images = real_images.to(device)# 每次更新 G 前,训练 Critic 5 次for _ in range(5):opt_C.zero_grad()# 生成假图像z = torch.randn(real_images.size(0), latent_dim).to(device)fake_images = generator(z)# 判别器对真实图像和假图像的判别real_validity = critic(real_images)fake_validity = critic(fake_images.detach())# 计算梯度惩罚gradient_penalty = compute_gradient_penalty(critic, real_images, fake_images)# 判别器损失d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + 10 * gradient_penaltyd_loss.backward(retain_graph=True)opt_C.step()# 训练生成器opt_G.zero_grad()# 生成器损失fake_validity = critic(fake_images)g_loss = -torch.mean(fake_validity)g_loss.backward()opt_G.step()# 每5个epoch可视化生成图像if epoch % 1 == 0:print(f"Epoch [{epoch}/{epochs}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")save_generated_images(epoch, generator, latent_dim)
(3) 观察 WGAN
生成图像: