
Generative AI in Practice | Self-Attention Generative Adversarial Network (SAGAN)

    • 0. Preface
    • 1. SAGAN Core Principles
      • 1.1 Self-Attention Mechanism
      • 1.2 Spectral Normalization
    • 2. Implementing SAGAN
      • 2.1 Generator
      • 2.2 Discriminator
    • 3. Model Training
      • 3.1 Data Loading
      • 3.2 Training Procedure

0. Preface

The Self-Attention Generative Adversarial Network (SAGAN) embeds a self-attention mechanism into the traditional deep convolutional GAN, effectively capturing long-range dependencies in images and thereby producing images with stronger global consistency and richer detail. SAGAN introduces self-attention modules into both the generator and the discriminator, and combines spectral normalization, conditional batch normalization, a projection discriminator, and hinge loss, significantly improving training stability and sample quality. This section covers the core principles of SAGAN and implements a SAGAN model using PyTorch.

1. SAGAN Core Principles

1.1 Self-Attention Mechanism

Traditional convolutional neural networks rely mainly on local receptive fields and struggle to capture global, cross-region structure in images; stacking many convolutional layers has the theoretical capacity to do so, but is hard to optimize and statistically less robust.
SAGAN computes self-attention over feature maps, so that the output at each position depends not only on its local neighborhood but also on cues from all positions, improving the global consistency of generated results.
For images, the self-attention module first maps the input feature map through three 1×1 convolutions into query, key, and value features, and then computes the attention output in the following steps:

  • Multiply the query and key tensors (matrix multiplication) to obtain the attention weight matrix, normalized with softmax
  • Multiply the attention weights with the value tensor to obtain the weighted feature representation
  • Scale the result by a learnable factor γ and add it to the original features as a residual connection; γ is initialized to 0, so the network can rely on local information first and gradually learn non-local dependencies

The following figure shows the attention module in SAGAN, where θ, φ, and g correspond to the query, key, and value, respectively:

[Figure: the self-attention mechanism]

Next, we implement the self-attention mechanism: the input is first mapped into query/key/value spaces, the attention matrix is computed, and the weighted values are then fused with the original input:

import torch
import torch.nn as nn

class Self_Attn(nn.Module):
    """Self-attention layer"""
    def __init__(self, in_dim, activation):
        super(Self_Attn, self).__init__()
        self.chanel_in = in_dim
        self.activation = activation
        # 1x1 convolutions project the input into query/key/value spaces
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # learnable scaling factor, initialized to 0 so the block starts as an identity mapping
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs:
            x : input feature maps (B x C x W x H)
        returns:
            out : self-attention value + input feature
            attention : B x N x N (N = W * H)
        """
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)                       # B x C' x N
        energy = torch.bmm(proj_query, proj_key)                                                # B x N x N
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)                   # B x C x N
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)
        # residual connection scaled by gamma
        out = self.gamma * out + x
        return out, attention
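As a quick shape check (a minimal sketch; the sizes are arbitrary), the module preserves the input shape and additionally returns the N×N attention map:

attn = Self_Attn(in_dim=64, activation='relu')
x = torch.randn(2, 64, 16, 16)
out, attention = attn(x)
print(out.shape)        # torch.Size([2, 64, 16, 16])
print(attention.shape)  # torch.Size([2, 256, 256]), N = 16 * 16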

1.2 Spectral Normalization

To stabilize adversarial training, SAGAN applies spectral normalization to all convolutional weights in both the generator and the discriminator, forcing the largest singular value of each weight matrix to be 1. This enforces a Lipschitz continuity constraint and suppresses exploding or vanishing gradients. The steps for performing spectral normalization are as follows:

  • The weight of a convolutional layer is a 4-D tensor, so the first step is to reshape it into a 2D matrix $W$; here we keep the output-channel dimension, so the reshaped weight has shape (C_out, C_in×H×W)
  • Initialize a vector $u$ from $N(0, 1)$
  • In a for loop (power iteration), compute the following:
    • Compute $v = W^\top u$ using a matrix transpose and matrix multiplication
    • Normalize $v$ by its L2 norm: $v = \frac{v}{\|v\|_2}$
    • Compute $u = Wv$
    • Normalize $u$ by its L2 norm: $u = \frac{u}{\|u\|_2}$
  • Compute the spectral norm as $\sigma = u^\top W v$
  • Finally, divide the weight by the spectral norm $\sigma$
from torch.nn import Parameter

def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)

class SpectralNorm(nn.Module):
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        # power iteration to approximate the largest singular value
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))

        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        # replace the raw weight with the re-parameterized version
        del self.module._parameters[self.name]
        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)
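A brief usage sketch (using the SpectralNorm class above; the shapes are illustrative): wrapping a convolution makes every forward pass divide the weight by its estimated largest singular value:

conv = SpectralNorm(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1))
x = torch.randn(8, 3, 64, 64)
y = conv(x)  # weight is renormalized before the convolution is applied
print(y.shape)  # torch.Size([8, 64, 32, 32])

Note that recent PyTorch versions also ship a built-in torch.nn.utils.spectral_norm implementing the same power-iteration scheme; the hand-rolled version above is kept for transparency.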

2. Implementing SAGAN

2.1 Generator

The generator takes a noise vector as input and passes it through a series of upsampling convolutional blocks, inserting self-attention at intermediate layers to produce images with globally consistent detail:

import numpy as np

class Generator(nn.Module):
    """Generator."""
    def __init__(self, batch_size, image_size=64, z_dim=100, conv_dim=64):
        super(Generator, self).__init__()
        self.imsize = image_size
        layer1 = []
        layer2 = []
        layer3 = []
        last = []

        repeat_num = int(np.log2(self.imsize)) - 3
        mult = 2 ** repeat_num  # 8 for a 64x64 image

        layer1.append(SpectralNorm(nn.ConvTranspose2d(z_dim, conv_dim * mult, 4)))
        layer1.append(nn.BatchNorm2d(conv_dim * mult))
        layer1.append(nn.ReLU())

        curr_dim = conv_dim * mult
        layer2.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
        layer2.append(nn.BatchNorm2d(int(curr_dim / 2)))
        layer2.append(nn.ReLU())

        curr_dim = int(curr_dim / 2)
        layer3.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
        layer3.append(nn.BatchNorm2d(int(curr_dim / 2)))
        layer3.append(nn.ReLU())

        if self.imsize == 64:
            layer4 = []
            curr_dim = int(curr_dim / 2)
            layer4.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
            layer4.append(nn.BatchNorm2d(int(curr_dim / 2)))
            layer4.append(nn.ReLU())
            self.l4 = nn.Sequential(*layer4)
            curr_dim = int(curr_dim / 2)

        self.l1 = nn.Sequential(*layer1)
        self.l2 = nn.Sequential(*layer2)
        self.l3 = nn.Sequential(*layer3)

        last.append(nn.ConvTranspose2d(curr_dim, 3, 4, 2, 1))
        last.append(nn.Tanh())
        self.last = nn.Sequential(*last)

        # self-attention on the mid-level (128-channel) and high-level (64-channel) feature maps
        self.attn1 = Self_Attn(128, 'relu')
        self.attn2 = Self_Attn(64, 'relu')

    def forward(self, z):
        z = z.view(z.size(0), z.size(1), 1, 1)
        out = self.l1(z)
        out = self.l2(out)
        out = self.l3(out)
        out, p1 = self.attn1(out)
        out = self.l4(out)
        out, p2 = self.attn2(out)
        out = self.last(out)
        return out, p1, p2
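A quick shape check for the generator (a sketch; batch size is arbitrary, and z_dim=128 matches the Trainer settings below):

G = Generator(batch_size=4, image_size=64, z_dim=128, conv_dim=64)
z = torch.randn(4, 128)
imgs, p1, p2 = G(z)
print(imgs.shape)  # torch.Size([4, 3, 64, 64])
print(p1.shape)    # torch.Size([4, 256, 256]), attention over the 16x16 feature map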

2.2 Discriminator

The discriminator likewise introduces self-attention to capture global dependencies:

class Discriminator(nn.Module):
    """Discriminator, Auxiliary Classifier."""
    def __init__(self, batch_size=64, image_size=64, conv_dim=64):
        super(Discriminator, self).__init__()
        self.imsize = image_size
        layer1 = []
        layer2 = []
        layer3 = []
        last = []

        layer1.append(SpectralNorm(nn.Conv2d(3, conv_dim, 4, 2, 1)))
        layer1.append(nn.LeakyReLU(0.1))
        curr_dim = conv_dim

        layer2.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
        layer2.append(nn.LeakyReLU(0.1))
        curr_dim = curr_dim * 2

        layer3.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
        layer3.append(nn.LeakyReLU(0.1))
        curr_dim = curr_dim * 2

        if self.imsize == 64:
            layer4 = []
            layer4.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, 4, 2, 1)))
            layer4.append(nn.LeakyReLU(0.1))
            self.l4 = nn.Sequential(*layer4)
            curr_dim = curr_dim * 2

        self.l1 = nn.Sequential(*layer1)
        self.l2 = nn.Sequential(*layer2)
        self.l3 = nn.Sequential(*layer3)

        last.append(nn.Conv2d(curr_dim, 1, 4))
        self.last = nn.Sequential(*last)

        # self-attention on the 256- and 512-channel feature maps
        self.attn1 = Self_Attn(256, 'relu')
        self.attn2 = Self_Attn(512, 'relu')

    def forward(self, x):
        out = self.l1(x)
        out = self.l2(out)
        out = self.l3(out)
        out, p1 = self.attn1(out)
        out = self.l4(out)
        out, p2 = self.attn2(out)
        out = self.last(out)
        return out.squeeze(), p1, p2
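And the corresponding check for the discriminator, which maps a batch of images to one realness score each (a sketch with arbitrary batch size):

D = Discriminator(image_size=64, conv_dim=64)
score, a1, a2 = D(torch.randn(4, 3, 64, 64))
print(score.shape)  # torch.Size([4])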

3. Model Training

3.1 Data Loading

In this section, we continue to use the CelebA face image dataset to build the SAGAN:

from torchvision import transforms
import torchvision.utils as vutils
import cv2, numpy as np
import torch
import os
from glob import glob
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from matplotlib import pyplot as plt
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"transform=transforms.Compose([transforms.Resize(64),transforms.CenterCrop(64),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])class Faces(Dataset):def __init__(self, folder):super().__init__()self.folder = folderself.images = sorted(glob(folder))def __len__(self):return len(self.images)def __getitem__(self, ix):image_path = self.images[ix]image = Image.open(image_path)image = transform(image)return imageds = Faces(folder='cropped_faces/*.jpg')
dataloader = DataLoader(ds, batch_size=64, shuffle=True, num_workers=8)
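A quick sanity check that the loader yields normalized 64×64 RGB batches:

batch = next(iter(dataloader))
print(batch.shape)               # torch.Size([64, 3, 64, 64])
print(batch.min(), batch.max())  # roughly in [-1, 1] after normalization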

3.2 Training Procedure

We follow the standard GAN training procedure. The loss function uses hinge loss, and Adam optimizers with different initial learning rates are used for the generator (1e-4) and the discriminator (4e-4).
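For reference, the hinge losses implemented in the Trainer below are

$$L_D = \mathbb{E}_x[\max(0, 1 - D(x))] + \mathbb{E}_z[\max(0, 1 + D(G(z)))]$$

$$L_G = -\mathbb{E}_z[D(G(z))]$$

i.e., the discriminator is penalized only when real samples score below 1 or fake samples score above -1, while the generator simply maximizes the discriminator's score on fakes. The full training loop: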

import time
import datetime
from torch.autograd import Variable
from torchvision.utils import save_image

# helper utilities (defined here so the block is self-contained)
def tensor2var(x, grad=False):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, requires_grad=grad)

def denorm(x):
    # map images from [-1, 1] back to [0, 1]
    out = (x + 1) / 2
    return out.clamp_(0, 1)

class Trainer(object):
    def __init__(self, data_loader):
        # Data loader
        self.data_loader = data_loader

        # adversarial loss: 'hinge' (as described above) or 'wgan-gp'
        self.adv_loss = 'hinge'

        # Model hyper-parameters
        self.imsize = 64
        self.g_num = 5
        self.z_dim = 128
        self.g_conv_dim = 64
        self.d_conv_dim = 64
        self.parallel = False

        self.lambda_gp = 10
        self.total_step = 50000
        self.d_iters = 5
        self.batch_size = 32
        self.num_workers = 2
        self.g_lr = 0.0001
        self.d_lr = 0.0004
        self.lr_decay = 0.95
        self.beta1 = 0.0
        self.beta2 = 0.9
        self.dataset = data_loader

        self.sample_path = 'sagan_samples'
        self.sample_step = 100
        self.log_step = 10
        os.makedirs(self.sample_path, exist_ok=True)

        self.build_model()

    def train(self):
        # Data iterator
        data_iter = iter(self.data_loader)
        step_per_epoch = len(self.data_loader)

        # Fixed input for debugging
        fixed_z = tensor2var(torch.randn(self.batch_size, self.z_dim))
        start = 0

        # Start time
        start_time = time.time()
        for step in range(start, self.total_step):
            # ================== Train D ================== #
            self.D.train()
            self.G.train()
            try:
                real_images = next(data_iter)
            except StopIteration:
                data_iter = iter(self.data_loader)
                real_images = next(data_iter)

            # Compute loss with real images
            # dr1, dr2, df1, df2, gf1, gf2 are attention maps
            real_images = tensor2var(real_images)
            d_out_real, dr1, dr2 = self.D(real_images)
            if self.adv_loss == 'wgan-gp':
                d_loss_real = -torch.mean(d_out_real)
            elif self.adv_loss == 'hinge':
                d_loss_real = torch.nn.ReLU()(1.0 - d_out_real).mean()

            # Compute loss with fake images
            z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
            fake_images, gf1, gf2 = self.G(z)
            d_out_fake, df1, df2 = self.D(fake_images)
            if self.adv_loss == 'wgan-gp':
                d_loss_fake = d_out_fake.mean()
            elif self.adv_loss == 'hinge':
                d_loss_fake = torch.nn.ReLU()(1.0 + d_out_fake).mean()

            # Backward + Optimize
            d_loss = d_loss_real + d_loss_fake
            self.reset_grad()
            d_loss.backward()
            self.d_optimizer.step()

            if self.adv_loss == 'wgan-gp':
                # Compute gradient penalty
                alpha = torch.rand(real_images.size(0), 1, 1, 1).cuda().expand_as(real_images)
                interpolated = Variable(alpha * real_images.data + (1 - alpha) * fake_images.data,
                                        requires_grad=True)
                out, _, _ = self.D(interpolated)

                grad = torch.autograd.grad(outputs=out,
                                           inputs=interpolated,
                                           grad_outputs=torch.ones(out.size()).cuda(),
                                           retain_graph=True,
                                           create_graph=True,
                                           only_inputs=True)[0]

                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
                d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)

                # Backward + Optimize
                d_loss = self.lambda_gp * d_loss_gp
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

            # ================== Train G ================== #
            # Create random noise
            z = tensor2var(torch.randn(real_images.size(0), self.z_dim))
            fake_images, _, _ = self.G(z)

            # Compute loss with fake images
            g_out_fake, _, _ = self.D(fake_images)  # batch x n
            if self.adv_loss == 'wgan-gp':
                g_loss_fake = -g_out_fake.mean()
            elif self.adv_loss == 'hinge':
                g_loss_fake = -g_out_fake.mean()

            self.reset_grad()
            g_loss_fake.backward()
            self.g_optimizer.step()

            # Print out log info
            if (step + 1) % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                print("Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_loss_real: {:.4f}, "
                      " ave_gamma_l3: {:.4f}, ave_gamma_l4: {:.4f}".
                      format(elapsed, step + 1, self.total_step, (step + 1),
                             self.total_step, d_loss_real.item(),
                             self.G.attn1.gamma.mean().item(),
                             self.G.attn2.gamma.mean().item()))

            # Sample images
            if (step + 1) % self.sample_step == 0:
                fake_images, _, _ = self.G(fixed_z)
                save_image(denorm(fake_images.data),
                           os.path.join(self.sample_path, '{}_fake.png'.format(step + 1)))

    def reset_grad(self):
        # zero both optimizers' gradients before each backward pass
        self.d_optimizer.zero_grad()
        self.g_optimizer.zero_grad()

    def build_model(self):
        self.G = Generator(self.batch_size, self.imsize, self.z_dim, self.g_conv_dim).cuda()
        self.D = Discriminator(self.batch_size, self.imsize, self.d_conv_dim).cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        # Loss and optimizer
        self.g_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.G.parameters()),
                                            self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D.parameters()),
                                            self.d_lr, [self.beta1, self.beta2])
        self.c_loss = torch.nn.CrossEntropyLoss()

        # print networks
        print(self.G)
        print(self.D)
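With the Faces data loader from Section 3.1, training is launched as follows (a minimal sketch):

trainer = Trainer(dataloader)
trainer.train()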

After the model is trained, we use the generator to produce face images.
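A minimal sampling sketch (reusing tensor2var and denorm from the Trainer block, plus the vutils and plt imports from Section 3.1):

trainer.G.eval()
with torch.no_grad():
    z = tensor2var(torch.randn(16, 128))
    samples, _, _ = trainer.G(z)
grid = vutils.make_grid(denorm(samples.cpu()), nrow=4)
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.axis('off')
plt.show()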

[Figure: generated samples]

