
Generative Models in Action | GLOW Explained and Implemented

    • 0. Foreword
    • 1. Normalizing Flow Models
      • 1.1 Normalizing Flows and the Change-of-Variables Formula
      • 1.2 Channel Reversal in RealNVP
    • 2. GLOW Architecture
      • 2.1 ActNorm
      • 2.2 Invertible 1×1 Convolution
      • 2.3 Affine Coupling Layer
      • 2.4 Multi-Scale Architecture
    • 3. Implementing GLOW with PyTorch
      • 3.1 Data Processing
      • 3.2 Model Construction
      • 3.3 Model Training

0. Foreword

GLOW (Generative Flow) is a generative model based on normalizing flows. By introducing an invertible 1 × 1 convolution in each flow step, it replaces the channel-reversal or fixed-permutation strategy of RealNVP, making the channel shuffling more expressive while keeping the Jacobian determinant and the inverse transform cheap to compute. This article first reviews the basics of normalizing flows and RealNVP, then dissects the four core modules of GLOW: ActNorm, the invertible 1×1 convolution, the affine coupling layer, and the multi-scale architecture. Finally, it implements GLOW in PyTorch and trains it on the CIFAR-10 dataset.

1. Normalizing Flow Models

1.1 Normalizing Flows and the Change-of-Variables Formula

In this section, we briefly review the core principle of normalizing flow models. A normalizing flow uses an invertible map $f$ to transform a simple base distribution $p_Z(z)$ into the data distribution $p_X(x)$, and the following change-of-variables formula enables exact log-likelihood computation and sampling:
$$p_X(x) = p_Z(f(x)) \left| \det\left( \frac{\partial f(x)}{\partial x} \right) \right|$$
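The change-of-variables formula above can be checked numerically on a toy one-dimensional flow. This sketch (hypothetical scalars, not from the article) uses the affine map $f(x) = (x - b)/s$, whose Jacobian determinant is $1/s$, and compares against the exact density of $X = sZ + b$:

```python
import torch

# Toy invertible map f(x) = (x - b) / s sending data x to latent z, with a
# standard-normal base density p_Z. The change-of-variables formula gives
# log p_X(x) = log p_Z(f(x)) + log|det df/dx| = log p_Z(z) - log s.
s, b = 2.0, 0.5
x = torch.tensor([1.3])
z = (x - b) / s

log_px = torch.distributions.Normal(0.0, 1.0).log_prob(z) - torch.log(torch.tensor(s))

# Cross-check: X = s*Z + b is exactly N(b, s^2), so its density is known in closed form.
ref = torch.distributions.Normal(b, s).log_prob(x)
print(torch.allclose(log_px, ref))  # True
```

The same bookkeeping, with the Jacobian accumulated layer by layer, is what GLOW's `sldj` variable does in the code later in this article.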

1.2 Channel Reversal in RealNVP

RealNVP alternates masked coupling layers with channel reversal (reverse channels) or a fixed permutation, which guarantees that every channel gets transformed multiple times.
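To make the reversal concrete, here is a minimal sketch (with a hypothetical 4-channel tensor, not from the original implementation) showing how flipping the channel order between two coupling layers changes which half of the channels is transformed:

```python
import torch

# Channels [0, 1, 2, 3]: a coupling layer transforms one half of the channels
# and passes the other half through unchanged. Flipping the channel order
# between layers ensures the untouched half is transformed in the next layer.
x = torch.arange(4.).view(1, 4, 1, 1)
transformed, identity = x.chunk(2, dim=1)    # transforms [0, 1], keeps [2, 3]

x = x.flip(dims=[1])                          # channel order becomes [3, 2, 1, 0]
transformed2, identity2 = x.chunk(2, dim=1)   # now transforms [3, 2], keeps [1, 0]
print(identity2.flatten().tolist())           # [1.0, 0.0]
```

GLOW's insight is that this hard-coded flip is just one particular permutation matrix; Section 2.2 replaces it with a learned invertible matrix.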

2. GLOW Architecture

2.1 ActNorm

ActNorm is a channel-wise normalization method designed specifically for flow models, first introduced in GLOW. The layer applies a learnable affine transform to the input activations $x$:
$$y = s \odot x + b$$
where $s, b \in \mathbb{R}^C$ are per-channel scale and shift parameters. On the first forward pass, they are initialized in a data-dependent way from the per-channel mean $\mu$ and standard deviation $\sigma$ computed over one minibatch, so that after initialization $y$ satisfies $\mathbb{E}[y] = 0$ and $\mathrm{Var}[y] = 1$.
Unlike batch normalization, ActNorm depends on a minibatch only at initialization and maintains no running statistics afterwards, which improves stability for small-batch and decoupled training.

import torch
import torch.nn as nn


def mean_dim(tensor, dim=None, keepdims=False):
    # Mean over one or several dimensions, optionally keeping them
    if dim is None:
        return tensor.mean()
    else:
        if isinstance(dim, int):
            dim = [dim]
        dim = sorted(dim)
        for d in dim:
            tensor = tensor.mean(dim=d, keepdim=True)
        if not keepdims:
            for i, d in enumerate(dim):
                tensor.squeeze_(d - i)
        return tensor


class ActNorm(nn.Module):
    def __init__(self, num_features, scale=1., return_ldj=False):
        super(ActNorm, self).__init__()
        self.register_buffer('is_initialized', torch.zeros(1))
        self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.logs = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.num_features = num_features
        self.scale = float(scale)
        self.eps = 1e-6
        self.return_ldj = return_ldj

    def initialize_parameters(self, x):
        # Data-dependent initialization: zero mean, unit variance per channel
        if not self.training:
            return
        with torch.no_grad():
            bias = -mean_dim(x.clone(), dim=[0, 2, 3], keepdims=True)
            v = mean_dim((x.clone() + bias) ** 2, dim=[0, 2, 3], keepdims=True)
            logs = (self.scale / (v.sqrt() + self.eps)).log()
            self.bias.data.copy_(bias.data)
            self.logs.data.copy_(logs.data)
            self.is_initialized += 1.

    def _center(self, x, reverse=False):
        if reverse:
            return x - self.bias
        else:
            return x + self.bias

    def _scale(self, x, sldj, reverse=False):
        logs = self.logs
        if reverse:
            x = x * logs.mul(-1).exp()
        else:
            x = x * logs.exp()
        if sldj is not None:
            # Each channel scale contributes once per spatial position
            ldj = logs.sum() * x.size(2) * x.size(3)
            if reverse:
                sldj = sldj - ldj
            else:
                sldj = sldj + ldj
        return x, sldj

    def forward(self, x, ldj=None, reverse=False):
        if not self.is_initialized:
            self.initialize_parameters(x)
        if reverse:
            x, ldj = self._scale(x, ldj, reverse)
            x = self._center(x, reverse)
        else:
            x = self._center(x, reverse)
            x, ldj = self._scale(x, ldj, reverse)
        if self.return_ldj:
            return x, ldj
        return x

2.2 Invertible 1×1 Convolution

The invertible 1×1 convolution in GLOW replaces RealNVP's fixed channel reversal or permutation with a learnable $C \times C$ matrix $W$. At each spatial position $(i, j)$, the mapping is:
$$y_{i,j} = W x_{i,j}$$
The corresponding log-determinant is:
$$\log \left| \det \frac{\partial y}{\partial x} \right| = H \times W \times \log |\det W|$$
where $H$ and $W$ denote the spatial height and width.
To further speed up computing the determinant and the inverse, $W$ is often parameterized via its LU decomposition, $W = PLU$: only the triangular factors $L$ (unit lower-triangular) and $U$ (upper-triangular) need to be learned, with the permutation $P$ kept fixed, and the determinant is simply $\prod_i U_{ii}$.
With this learnable channel shuffling, the model can discover good ways to mix features on its own, which improves both sample quality and training efficiency.

import numpy as np
import torch.nn.functional as F


class InvConv(nn.Module):
    def __init__(self, num_channels):
        super(InvConv, self).__init__()
        self.num_channels = num_channels
        # Initialize with a random orthogonal matrix
        w_init = np.random.randn(num_channels, num_channels)
        w_init = np.linalg.qr(w_init)[0].astype(np.float32)
        self.weight = nn.Parameter(torch.from_numpy(w_init))

    def forward(self, x, sldj, reverse=False):
        # log|det W| contributes once per spatial position
        ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3)
        if reverse:
            weight = torch.inverse(self.weight.double()).float()
            sldj = sldj - ldj
        else:
            weight = self.weight
            sldj = sldj + ldj
        weight = weight.view(self.num_channels, self.num_channels, 1, 1)
        z = F.conv2d(x, weight)
        return z, sldj
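The implementation above learns $W$ directly and calls `slogdet` on every forward pass. The LU parameterization mentioned earlier can be sketched as follows; this is a hypothetical variant (class name `InvConvLU` and the use of `torch.linalg.lu` are my assumptions, not the article's code), shown to illustrate why $W = PLU$ makes the log-determinant free to read off:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class InvConvLU(nn.Module):
    # Hypothetical LU-parameterized 1x1 convolution: W = P @ L @ U, with P a
    # fixed permutation, L unit lower-triangular, and U upper-triangular.
    # The diagonal of U is stored as a sign and a log-magnitude, so
    # log|det W| is just log_s.sum() - no determinant computation needed.
    def __init__(self, num_channels):
        super().__init__()
        w = torch.linalg.qr(torch.randn(num_channels, num_channels)).Q
        p, l, u = torch.linalg.lu(w)
        s = torch.diagonal(u)
        self.register_buffer('p', p)
        self.register_buffer('sign_s', torch.sign(s))
        self.lower = nn.Parameter(l)
        self.upper = nn.Parameter(torch.triu(u, diagonal=1))
        self.log_s = nn.Parameter(torch.log(torch.abs(s)))

    def _weight(self):
        c = self.log_s.numel()
        eye = torch.eye(c, device=self.log_s.device)
        l = torch.tril(self.lower, diagonal=-1) + eye    # unit lower-triangular
        u = torch.triu(self.upper, diagonal=1) + torch.diag(self.sign_s * self.log_s.exp())
        return self.p @ l @ u

    def forward(self, x, sldj, reverse=False):
        # det P = +-1 and det L = 1, so log|det W| = sum of log|U_ii|
        ldj = self.log_s.sum() * x.size(2) * x.size(3)
        weight = self._weight()
        if reverse:
            weight = torch.inverse(weight)
            sldj = sldj - ldj
        else:
            sldj = sldj + ldj
        c = x.size(1)
        return F.conv2d(x, weight.view(c, c, 1, 1)), sldj
```

The per-step determinant cost drops from $O(C^3)$ to $O(C)$, at the price of a slightly less flexible parameterization (the permutation $P$ stays fixed after initialization).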

2.3 Affine Coupling Layer

The affine coupling layer was first proposed in RealNVP and is an indispensable component of GLOW. The layer splits the input $x \in \mathbb{R}^{C \times H \times W}$ along the channel dimension into two parts $(x_a, x_b)$ and uses a neural network to produce scale and translation parameters, which keeps the whole transform invertible:
$$(s, t) = \mathrm{NN}(x_b), \qquad y_a = s(x_b) \odot x_a + t(x_b), \qquad y_b = x_b$$

Its log-Jacobian determinant can be computed efficiently as:
$$\sum_{h,\,w,\,c \in a} \log s_c(x_b[h, w])$$
which involves only a sum over the elements of the scale output $s$, so the cost grows linearly with the input dimension.

import torch.nn.functional as F


class Coupling(nn.Module):
    def __init__(self, in_channels, mid_channels):
        super(Coupling, self).__init__()
        self.nn = NN(in_channels, mid_channels, 2 * in_channels)
        self.scale = nn.Parameter(torch.ones(in_channels, 1, 1))

    def forward(self, x, ldj, reverse=False):
        x_change, x_id = x.chunk(2, dim=1)
        st = self.nn(x_id)
        s, t = st[:, 0::2, ...], st[:, 1::2, ...]
        s = self.scale * torch.tanh(s)
        # Scale and translate
        if reverse:
            x_change = x_change * s.mul(-1).exp() - t
            ldj = ldj - s.flatten(1).sum(-1)
        else:
            x_change = (x_change + t) * s.exp()
            ldj = ldj + s.flatten(1).sum(-1)
        x = torch.cat((x_change, x_id), dim=1)
        return x, ldj


class NN(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels, use_act_norm=False):
        super(NN, self).__init__()
        norm_fn = ActNorm if use_act_norm else nn.BatchNorm2d
        self.in_norm = norm_fn(in_channels)
        self.in_conv = nn.Conv2d(in_channels, mid_channels,
                                 kernel_size=3, padding=1, bias=False)
        nn.init.normal_(self.in_conv.weight, 0., 0.05)
        self.mid_norm = norm_fn(mid_channels)
        self.mid_conv = nn.Conv2d(mid_channels, mid_channels,
                                  kernel_size=1, padding=0, bias=False)
        nn.init.normal_(self.mid_conv.weight, 0., 0.05)
        self.out_norm = norm_fn(mid_channels)
        self.out_conv = nn.Conv2d(mid_channels, out_channels,
                                  kernel_size=3, padding=1, bias=True)
        nn.init.zeros_(self.out_conv.weight)
        nn.init.zeros_(self.out_conv.bias)

    def forward(self, x):
        x = self.in_norm(x)
        x = F.relu(x)
        x = self.in_conv(x)
        x = self.mid_norm(x)
        x = F.relu(x)
        x = self.mid_conv(x)
        x = self.out_norm(x)
        x = F.relu(x)
        x = self.out_conv(x)
        return x

2.4 Multi-Scale Architecture

GLOW inherits the multi-scale design of RealNVP, decomposing intermediate representations level by level through stacked flow steps and factor-out operations. The full model consists of $L$ levels, each containing $K$ complete flow steps, and every step applies ActNorm, an invertible 1×1 convolution, and an affine coupling layer in sequence.
At the end of each level, a squeeze operation first reduces the spatial size of the feature map to a quarter of the original (while quadrupling the channel count); a split operation then factors part of the channels out as latent variables $z$, and the remaining channels continue into the next level.
This multi-scale decomposition preserves log-likelihood accuracy while substantially reducing compute and memory cost, and captures both global and local image structure at different scales.
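The squeeze/split bookkeeping can be traced with plain arithmetic. This sketch assumes a CIFAR-10 input (3 × 32 × 32), $L = 3$ levels, and a split that factors out exactly half of the channels (as the implementation later in this article does):

```python
# Trace the shape seen by the flow steps at each level: squeeze multiplies
# channels by 4 and halves each spatial dimension; split (at all but the
# last level) factors half the channels out as latents z.
c, h, w = 3, 32, 32
num_levels = 3
shapes = []
for level in range(num_levels):
    c, h, w = c * 4, h // 2, w // 2   # squeeze: space -> channels
    shapes.append((c, h, w))
    if level < num_levels - 1:
        c //= 2                        # split: half the channels become z
print(shapes)  # [(12, 16, 16), (24, 8, 8), (48, 4, 4)]
```

Note the net effect per level: ×4 from squeeze and ÷2 from split, so the channel count seen by the flow steps doubles at each level while the spatial area shrinks by 4.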

3. Implementing GLOW with PyTorch

In this section, we implement GLOW with PyTorch and train it on the CIFAR-10 dataset.

3.1 Data Processing

Use torchvision.transforms to preprocess the CIFAR-10 images:

import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms

batch_size, num_workers = 64, 2   # example values

transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])
transform_test = transforms.Compose([
    transforms.ToTensor()
])

trainset = torchvision.datasets.CIFAR10(root='data', train=True, download=True,
                                        transform=transform_train)
trainloader = data.DataLoader(trainset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)

testset = torchvision.datasets.CIFAR10(root='data', train=False, download=True,
                                       transform=transform_test)
testloader = data.DataLoader(testset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers)

3.2 Model Construction

Build the GLOW model from ActNorm, the invertible 1×1 convolution, the affine coupling layer, and the multi-scale architecture:

class Glow(nn.Module):
    def __init__(self, num_channels, num_levels, num_steps):
        super(Glow, self).__init__()
        # Use bounds to rescale images before converting to logits, not learned
        self.register_buffer('bounds', torch.tensor([0.9], dtype=torch.float32))
        self.flows = _Glow(in_channels=4 * 3,  # RGB image after squeeze
                           mid_channels=num_channels,
                           num_levels=num_levels,
                           num_steps=num_steps)

    def forward(self, x, reverse=False):
        if reverse:
            sldj = torch.zeros(x.size(0), device=x.device)
        else:
            # Expect inputs in [0, 1]
            if x.min() < 0 or x.max() > 1:
                raise ValueError('Expected x in [0, 1], got min/max {}/{}'
                                 .format(x.min(), x.max()))
            # De-quantize and convert to logits
            x, sldj = self._pre_process(x)
        x = squeeze(x)
        x, sldj = self.flows(x, sldj, reverse)
        x = squeeze(x, reverse=True)
        return x, sldj

    def _pre_process(self, x):
        y = (x * 255. + torch.rand_like(x)) / 256.
        y = (2 * y - 1) * self.bounds
        y = (y + 1) / 2
        y = y.log() - (1. - y).log()
        # Save log-determinant of Jacobian of initial transform
        ldj = F.softplus(y) + F.softplus(-y) \
            - F.softplus((1. - self.bounds).log() - self.bounds.log())
        sldj = ldj.flatten(1).sum(-1)
        return y, sldj


class _Glow(nn.Module):
    def __init__(self, in_channels, mid_channels, num_levels, num_steps):
        super(_Glow, self).__init__()
        self.steps = nn.ModuleList([
            _FlowStep(in_channels=in_channels, mid_channels=mid_channels)
            for _ in range(num_steps)
        ])
        if num_levels > 1:
            self.next = _Glow(in_channels=2 * in_channels,
                              mid_channels=mid_channels,
                              num_levels=num_levels - 1,
                              num_steps=num_steps)
        else:
            self.next = None

    def forward(self, x, sldj, reverse=False):
        if not reverse:
            for step in self.steps:
                x, sldj = step(x, sldj, reverse)
        if self.next is not None:
            x = squeeze(x)
            x, x_split = x.chunk(2, dim=1)
            x, sldj = self.next(x, sldj, reverse)
            x = torch.cat((x, x_split), dim=1)
            x = squeeze(x, reverse=True)
        if reverse:
            for step in reversed(self.steps):
                x, sldj = step(x, sldj, reverse)
        return x, sldj


class _FlowStep(nn.Module):
    def __init__(self, in_channels, mid_channels):
        super(_FlowStep, self).__init__()
        # Activation normalization, invertible 1x1 convolution, affine coupling
        self.norm = ActNorm(in_channels, return_ldj=True)
        self.conv = InvConv(in_channels)
        self.coup = Coupling(in_channels // 2, mid_channels)

    def forward(self, x, sldj=None, reverse=False):
        if reverse:
            x, sldj = self.coup(x, sldj, reverse)
            x, sldj = self.conv(x, sldj, reverse)
            x, sldj = self.norm(x, sldj, reverse)
        else:
            x, sldj = self.norm(x, sldj, reverse)
            x, sldj = self.conv(x, sldj, reverse)
            x, sldj = self.coup(x, sldj, reverse)
        return x, sldj


def squeeze(x, reverse=False):
    b, c, h, w = x.size()
    if reverse:
        # Unsqueeze
        x = x.view(b, c // 4, 2, 2, h, w)
        x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
        x = x.view(b, c // 4, h * 2, w * 2)
    else:
        # Squeeze
        x = x.view(b, c, h // 2, 2, w // 2, 2)
        x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
        x = x.view(b, c * 2 * 2, h // 2, w // 2)
    return x
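As a quick sanity check, the squeeze operation must be exactly invertible, since the flow's log-likelihood depends on it being volume-preserving. The sketch below restates squeeze on its own so it runs standalone:

```python
import torch

def squeeze(x, reverse=False):
    # Space-to-channel reshaping: each 2x2 spatial block becomes 4 channels
    b, c, h, w = x.size()
    if reverse:
        x = x.view(b, c // 4, 2, 2, h, w)
        x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
        return x.view(b, c // 4, h * 2, w * 2)
    x = x.view(b, c, h // 2, 2, w // 2, 2)
    x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
    return x.view(b, c * 4, h // 2, w // 2)

x = torch.randn(2, 3, 32, 32)
y = squeeze(x)
print(y.shape)                                   # torch.Size([2, 12, 16, 16])
print(torch.equal(squeeze(y, reverse=True), x))  # True
```

Because squeeze is a pure reshape/permute, the round trip is bit-exact and contributes zero to the log-determinant, which is why the model adds no `sldj` term for it.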

3.3 Model Training

Instantiate the model, loss function, and optimizer, and run training:

import os

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as sched
import torchvision
from tqdm import tqdm

# Hyperparameters (example values)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_channels, num_levels, num_steps = 512, 3, 32
lr, warm_up, max_grad_norm = 1e-3, 500000, 10.
num_epochs, num_samples = 100, 64

net = Glow(num_channels=num_channels, num_levels=num_levels, num_steps=num_steps)
net = net.to(device)

loss_fn = NLLLoss().to(device)
optimizer = optim.Adam(net.parameters(), lr=lr)
scheduler = sched.LambdaLR(optimizer, lambda s: min(1., s / warm_up))


@torch.enable_grad()
def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn, max_grad_norm):
    global global_step
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = AverageMeter()
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, _ in trainloader:
            x = x.to(device)
            optimizer.zero_grad()
            z, sldj = net(x, reverse=False)
            loss = loss_fn(z, sldj)
            loss_meter.update(loss.item(), x.size(0))
            loss.backward()
            if max_grad_norm > 0:
                clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            scheduler.step(global_step)
            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=bits_per_dim(x, loss_meter.avg),
                                     lr=optimizer.param_groups[0]['lr'])
            progress_bar.update(x.size(0))
            global_step += x.size(0)


@torch.no_grad()
def sample(net, batch_size, device):
    z = torch.randn((batch_size, 3, 32, 32), dtype=torch.float32, device=device)
    x, _ = net(z, reverse=True)
    x = torch.sigmoid(x)
    return x


@torch.no_grad()
def test(epoch, net, testloader, device, loss_fn, num_samples):
    global best_loss
    net.eval()
    loss_meter = AverageMeter()
    with tqdm(total=len(testloader.dataset)) as progress_bar:
        for x, _ in testloader:
            x = x.to(device)
            z, sldj = net(x, reverse=False)
            loss = loss_fn(z, sldj)
            loss_meter.update(loss.item(), x.size(0))
            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=bits_per_dim(x, loss_meter.avg))
            progress_bar.update(x.size(0))
    # Save checkpoint
    if loss_meter.avg < best_loss:
        print('Saving...')
        state = {
            'net': net.state_dict(),
            'test_loss': loss_meter.avg,
            'epoch': epoch,
        }
        os.makedirs('ckpts', exist_ok=True)
        torch.save(state, 'ckpts/best.pth.tar')
        best_loss = loss_meter.avg
    # Save samples
    images = sample(net, num_samples, device)
    os.makedirs('samples', exist_ok=True)
    images_concat = torchvision.utils.make_grid(images, nrow=int(num_samples ** 0.5),
                                                padding=2, pad_value=255)
    torchvision.utils.save_image(images_concat, 'samples/epoch_{}.png'.format(epoch))


best_loss = float('inf')
global_step = 0
start_epoch = 0
for epoch in range(start_epoch, start_epoch + num_epochs):
    train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn, max_grad_norm)
    test(epoch, net, testloader, device, loss_fn, num_samples)
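The training loop references several helpers (NLLLoss, bits_per_dim, AverageMeter, clip_grad_norm) without defining them. Below are plausible sketches; the standard-normal prior and the 256 quantization levels are assumptions consistent with the dequantization in `_pre_process`, not the article's exact code:

```python
import numpy as np
import torch
import torch.nn as nn

class NLLLoss(nn.Module):
    # Negative log-likelihood: standard-normal prior on z plus the accumulated
    # log-determinant sldj, averaged over the batch. k is the assumed number
    # of discretization levels used during dequantization.
    def __init__(self, k=256):
        super().__init__()
        self.k = k

    def forward(self, z, sldj):
        prior_ll = -0.5 * (z ** 2 + np.log(2 * np.pi))
        prior_ll = prior_ll.flatten(1).sum(-1) - np.log(self.k) * np.prod(z.size()[1:])
        ll = prior_ll + sldj
        return -ll.mean()

def bits_per_dim(x, nll):
    # Convert mean NLL (nats per image) to bits per dimension
    dim = np.prod(x.size()[1:])
    return nll / (np.log(2) * dim)

class AverageMeter:
    # Running average of a scalar, weighted by batch size
    def __init__(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def clip_grad_norm(optimizer, max_norm, norm_type=2):
    # Clip gradients across every parameter group of the optimizer
    for group in optimizer.param_groups:
        torch.nn.utils.clip_grad_norm_(group['params'], max_norm, norm_type)
```

With these in place the loop runs end to end; bits per dimension is the usual metric for comparing flow models across image sizes.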

After 100 epochs, the model can generate fairly realistic 32 × 32 color images, with good results in both per-channel detail and overall structure. The figure below compares images generated at different epochs during training:

[Figure: generated samples]

