生成模型实战 | GLOW详解与实现
生成模型实战 | GLOW详解与实现
- 0. 前言
- 1. 归一化流模型
- 1.1 归一化流与变换公式
- 1.2 RealNVP 的通道翻转
- 2. GLOW 架构
- 2.1 ActNorm
- 2.2 可逆 1×1 卷积
- 2.3 仿射耦合层
- 2.4 多尺度架构
- 3. 使用 PyTorch 实现 GLOW
- 3.1 数据处理
- 3.2 模型构建
- 3.3 模型训练
0. 前言
GLOW
(Generative Flow
) 是一种基于归一化流的生成模型,通过在每个流步骤中引入可逆的 1 × 1
卷积层,替代了 RealNVP 中通道翻转或固定置换的策略,从而使通道重排更具表达力,同时保持雅可比行列式和逆变换的高效计算能力。本文首先回顾归一化流与 RealNVP
的基本原理,接着剖析 GLOW
的四大核心模块:ActNorm
、可逆 1×1
卷积、仿射耦合层和多尺度架构,随后基于 PyTorch
实现 GLOW
模型,并在 CIFAR-10
数据集上进行训练。
1. 归一化流模型
1.1 归一化流与变换公式
在本节中,我们首先简要回顾归一化流模型的核心原理,归一化流利用可逆映射 fff 将简单分布 pZ(z)p_Z(z)pZ(z) 转换到样本分布 pX(x)p_X(x)pX(x),并通过以下变换公式实现实现精确对数似然计算和采样:
pX(x)=pZ(f(x))∣det(∂f(x)∂x)∣p_X(x)=p_Z(f(x)) |\text{det} (\frac {∂f(x)}{∂x})| pX(x)=pZ(f(x)) ∣det(∂x∂f(x))∣
1.2 RealNVP 的通道翻转
RealNVP
通过交替使用掩码耦合层 (masking coupling
) 和按通道翻转 (reverse channels
) 或固定置换,保证每个通道都能被多次变换。
2. GLOW 架构
2.1 ActNorm
ActNorm
是一种专为流模型设计的通道级归一化方法,于 GLOW
中首次提出。该层对输入激活 xxx 执行可学习仿射变换:
y=s⊙x+by=s⊙x+b y=s⊙x+b
其中 s,b∈RCs,b∈\mathbb R^Cs,b∈RC 分别为每个通道的尺度与偏移参数。这些参数在首次前向传播时通过对一个 minibatch
计算输出通道的均值 μμμ 和标准差 σσσ 进行数据依赖的初始化,使得初始化后的 yyy 满足 E[y]=0\mathbb E[y]=0E[y]=0 和 Var[y]=1Var[y]=1Var[y]=1。
与批归一化不同,ActNorm
仅在初始化时依赖 minibatch
,之后无需维护运行时统计量,从而提升了小批数据和解耦训练的稳定性。
import torch.nn as nn
import torchdef mean_dim(tensor, dim=None, keepdims=False):if dim is None:return tensor.mean()else:if isinstance(dim, int):dim = [dim]dim = sorted(dim)for d in dim:tensor = tensor.mean(dim=d, keepdim=True)if not keepdims:for i, d in enumerate(dim):tensor.squeeze_(d-i)return tensorclass ActNorm(nn.Module):def __init__(self, num_features, scale=1., return_ldj=False):super(ActNorm, self).__init__()self.register_buffer('is_initialized', torch.zeros(1))self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))self.logs = nn.Parameter(torch.zeros(1, num_features, 1, 1))self.num_features = num_featuresself.scale = float(scale)self.eps = 1e-6self.return_ldj = return_ldjdef initialize_parameters(self, x):if not self.training:returnwith torch.no_grad():bias = -mean_dim(x.clone(), dim=[0, 2, 3], keepdims=True)v = mean_dim((x.clone() + bias) ** 2, dim=[0, 2, 3], keepdims=True)logs = (self.scale / (v.sqrt() + self.eps)).log()self.bias.data.copy_(bias.data)self.logs.data.copy_(logs.data)self.is_initialized += 1.def _center(self, x, reverse=False):if reverse:return x - self.biaselse:return x + self.biasdef _scale(self, x, sldj, reverse=False):logs = self.logsif reverse:x = x * logs.mul(-1).exp()else:x = x * logs.exp()if sldj is not None:ldj = logs.sum() * x.size(2) * x.size(3)if reverse:sldj = sldj - ldjelse:sldj = sldj + ldjreturn x, sldjdef forward(self, x, ldj=None, reverse=False):if not self.is_initialized:self.initialize_parameters(x)if reverse:x, ldj = self._scale(x, ldj, reverse)x = self._center(x, reverse)else:x = self._center(x, reverse)x, ldj = self._scale(x, ldj, reverse)if self.return_ldj:return x, ldjreturn x
2.2 可逆 1×1 卷积
GLOW
中的可逆 1×1
卷积用一个 C×CC×CC×C 的可学习矩阵 WWW 取代了 RealNVP
中的固定通道翻转或置换操作。在空间位置 (i,j)(i,j)(i,j) 上,其映射可写为:
yi,j=Wxi,jy_{i,j}=W x_{i,j} yi,j=W xi,j
对应的对数行列式为:
log det ∣∂y∂x∣=H×W×log∣detW∣\text {log}\ \text{det} |\frac{∂y}{∂x}|=H×W×\text {log}|\text{det}W| log det ∣∂x∂y∣=H×W×log∣detW∣
其中 H,WH,WH,W 分别为空间高宽。
为了进一步加速行列式与逆矩阵的计算,通常将 WWW 参数化为 LU
分解形式,即 W=PLUW=PLUW=PLU,只需学习下三角矩阵 LLL 和上三角矩阵 UUU 的非对角元素,行列式则为 ∏iUii∏_iU_{ii}∏iUii。
通过这种可学习的通道重排,模型能够自动挖掘最优的特征混合方式,从而在生成质量与训练效率上均取得显著提升。
import numpy as npclass InvConv(nn.Module):def __init__(self, num_channels):super(InvConv, self).__init__()self.num_channels = num_channels# Initialize with a random orthogonal matrixw_init = np.random.randn(num_channels, num_channels)w_init = np.linalg.qr(w_init)[0].astype(np.float32)self.weight = nn.Parameter(torch.from_numpy(w_init))def forward(self, x, sldj, reverse=False):ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3)if reverse:weight = torch.inverse(self.weight.double()).float()sldj = sldj - ldjelse:weight = self.weightsldj = sldj + ldjweight = weight.view(self.num_channels, self.num_channels, 1, 1)z = F.conv2d(x, weight)return z, sldj
2.3 仿射耦合层
仿射耦合层最早在 RealNVP
中提出,是 GLOW
中不可或缺的组成部分。该层将输入 x∈RC×H×Wx∈\mathbb R^{C×H×W}x∈RC×H×W 沿通道维度划分为两部分 (xa,xb)(x_a,x_b)(xa,xb),并通过神经网络生成尺度和平移参数保证了整个变换的可逆性:
(s,t)=NN(xb),ya=s(xb)⊙xa+t(xb),yb=xb(s,t)=NN(x_b),ya=s(x_b)⊙x_a+t(x_b),y_b=x_b (s,t)=NN(xb),ya=s(xb)⊙xa+t(xb),yb=xb
其对数雅可比行列式可高效地计算为:
∑h,w,c∈alogsc(xb[h,w])∑_{h,w, c∈a}\text {log}s_c(x_b[h,w]) h,w, c∈a∑logsc(xb[h,w])
仅与输出尺度参数 sss 的元素相加相关,计算复杂度随输入维度线性增长。
import torch.nn.functional as Fclass Coupling(nn.Module):def __init__(self, in_channels, mid_channels):super(Coupling, self).__init__()self.nn = NN(in_channels, mid_channels, 2 * in_channels)self.scale = nn.Parameter(torch.ones(in_channels, 1, 1))def forward(self, x, ldj, reverse=False):x_change, x_id = x.chunk(2, dim=1)st = self.nn(x_id)s, t = st[:, 0::2, ...], st[:, 1::2, ...]s = self.scale * torch.tanh(s)# Scale and translateif reverse:x_change = x_change * s.mul(-1).exp() - tldj = ldj - s.flatten(1).sum(-1)else:x_change = (x_change + t) * s.exp()ldj = ldj + s.flatten(1).sum(-1)x = torch.cat((x_change, x_id), dim=1)return x, ldjclass NN(nn.Module):def __init__(self, in_channels, mid_channels, out_channels,use_act_norm=False):super(NN, self).__init__()norm_fn = ActNorm if use_act_norm else nn.BatchNorm2dself.in_norm = norm_fn(in_channels)self.in_conv = nn.Conv2d(in_channels, mid_channels,kernel_size=3, padding=1, bias=False)nn.init.normal_(self.in_conv.weight, 0., 0.05)self.mid_norm = norm_fn(mid_channels)self.mid_conv = nn.Conv2d(mid_channels, mid_channels,kernel_size=1, padding=0, bias=False)nn.init.normal_(self.mid_conv.weight, 0., 0.05)self.out_norm = norm_fn(mid_channels)self.out_conv = nn.Conv2d(mid_channels, out_channels,kernel_size=3, padding=1, bias=True)nn.init.zeros_(self.out_conv.weight)nn.init.zeros_(self.out_conv.bias)def forward(self, x):x = self.in_norm(x)x = F.relu(x)x = self.in_conv(x)x = self.mid_norm(x)x = F.relu(x)x = self.mid_conv(x)x = self.out_norm(x)x = F.relu(x)x = self.out_conv(x)return x
2.4 多尺度架构
GLOW
延续了 RealNVP
的多尺度架构思想,通过分层的流步骤和因子化操作将中间表示逐级分解。整体模型由 LLL 个尺度 (level
) 组成,每个尺度内部包含 KKK 次完整的流步骤 (step
),每步依次执行 ActNorm
、可逆 1×1
卷积和仿射耦合层。
在每个尺度结束时,先通过 squeeze
操作将特征图空间大小减少至原像素的四分之一(同时通道数扩大四倍),然后使用 split
操作将部分通道因子化为潜变量 zzz,余下通道继续进入下一级流。
这种多尺度分解在保持对数似然精度的同时,有效降低了计算与存储开销,并在不同尺度上捕捉图像的全局与局部结构信息。
3. 使用 PyTorch 实现 GLOW
在本节中,使用 PyTorch
实现 GLOW
,并在 CIFAR-10
数据集上进行训练。
3.1 数据处理
用 torchvision.transforms
对 CIFAR-10
图像进行处理:
transform_train = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.ToTensor()])transform_test = transforms.Compose([transforms.ToTensor()
])trainset = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transform_train)
trainloader = data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)testset = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform_test)
testloader = data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
3.2 模型构建
基于 ActNorm
、可逆 1×1
卷积、仿射耦合层和多尺度架构实现 GLOW
模型:
class Glow(nn.Module):def __init__(self, num_channels, num_levels, num_steps):super(Glow, self).__init__()# Use bounds to rescale images before converting to logits, not learnedself.register_buffer('bounds', torch.tensor([0.9], dtype=torch.float32))self.flows = _Glow(in_channels=4 * 3, # RGB image after squeezemid_channels=num_channels,num_levels=num_levels,num_steps=num_steps)def forward(self, x, reverse=False):if reverse:sldj = torch.zeros(x.size(0), device=x.device)else:# Expect inputs in [0, 1]if x.min() < 0 or x.max() > 1:raise ValueError('Expected x in [0, 1], got min/max {}/{}'.format(x.min(), x.max()))# De-quantize and convert to logitsx, sldj = self._pre_process(x)x = squeeze(x)x, sldj = self.flows(x, sldj, reverse)x = squeeze(x, reverse=True)return x, sldjdef _pre_process(self, x):y = (x * 255. + torch.rand_like(x)) / 256.y = (2 * y - 1) * self.boundsy = (y + 1) / 2y = y.log() - (1. - y).log()# Save log-determinant of Jacobian of initial transformldj = F.softplus(y) + F.softplus(-y) \- F.softplus((1. - self.bounds).log() - self.bounds.log())sldj = ldj.flatten(1).sum(-1)return y, sldjclass _Glow(nn.Module):def __init__(self, in_channels, mid_channels, num_levels, num_steps):super(_Glow, self).__init__()self.steps = nn.ModuleList([_FlowStep(in_channels=in_channels,mid_channels=mid_channels)for _ in range(num_steps)])if num_levels > 1:self.next = _Glow(in_channels=2 * in_channels,mid_channels=mid_channels,num_levels=num_levels - 1,num_steps=num_steps)else:self.next = Nonedef forward(self, x, sldj, reverse=False):if not reverse:for step in self.steps:x, sldj = step(x, sldj, reverse)if self.next is not None:x = squeeze(x)x, x_split = x.chunk(2, dim=1)x, sldj = self.next(x, sldj, reverse)x = torch.cat((x, x_split), dim=1)x = squeeze(x, reverse=True)if reverse:for step in reversed(self.steps):x, sldj = step(x, sldj, reverse)return x, sldjclass _FlowStep(nn.Module):def __init__(self, in_channels, mid_channels):super(_FlowStep, self).__init__()# Activation normalization, invertible 1x1 convolution, affine couplingself.norm = ActNorm(in_channels, return_ldj=True)self.conv = InvConv(in_channels)self.coup = Coupling(in_channels // 2, mid_channels)def forward(self, x, sldj=None, reverse=False):if reverse:x, sldj = self.coup(x, sldj, reverse)x, sldj = self.conv(x, sldj, reverse)x, sldj = self.norm(x, sldj, reverse)else:x, sldj = self.norm(x, sldj, reverse)x, sldj = self.conv(x, sldj, reverse)x, sldj = self.coup(x, sldj, reverse)return x, sldjdef squeeze(x, reverse=False):b, c, h, w = x.size()if reverse:# Unsqueezex = x.view(b, c // 4, 2, 2, h, w)x = x.permute(0, 1, 4, 2, 5, 3).contiguous()x = x.view(b, c // 4, h * 2, w * 2)else:# Squeezex = x.view(b, c, h // 2, 2, w // 2, 2)x = x.permute(0, 1, 3, 5, 2, 4).contiguous()x = x.view(b, c * 2 * 2, h // 2, w // 2)return x
3.3 模型训练
实例化模型、损失函数和优化器,并进行训练:
net = Glow(num_channels=num_channels,num_levels=num_levels,num_steps=num_steps)
net = net.to(device)loss_fn = NLLLoss().to(device)
optimizer = optim.Adam(net.parameters(), lr=lr)
scheduler = sched.LambdaLR(optimizer, lambda s: min(1., s / warm_up))@torch.enable_grad()
def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn, max_grad_norm):global global_stepprint('\nEpoch: %d' % epoch)net.train()loss_meter = AverageMeter()with tqdm(total=len(trainloader.dataset)) as progress_bar:for x, _ in trainloader:x = x.to(device)optimizer.zero_grad()z, sldj = net(x, reverse=False)loss = loss_fn(z, sldj)loss_meter.update(loss.item(), x.size(0))loss.backward()if max_grad_norm > 0:clip_grad_norm(optimizer, max_grad_norm)optimizer.step()scheduler.step(global_step)progress_bar.set_postfix(nll=loss_meter.avg,bpd=bits_per_dim(x, loss_meter.avg),lr=optimizer.param_groups[0]['lr'])progress_bar.update(x.size(0))global_step += x.size(0)@torch.no_grad()
def sample(net, batch_size, device):z = torch.randn((batch_size, 3, 32, 32), dtype=torch.float32, device=device)x, _ = net(z, reverse=True)x = torch.sigmoid(x)return x@torch.no_grad()
def test(epoch, net, testloader, device, loss_fn, num_samples):global best_lossnet.eval()loss_meter = AverageMeter()with tqdm(total=len(testloader.dataset)) as progress_bar:for x, _ in testloader:x = x.to(device)z, sldj = net(x, reverse=False)loss = loss_fn(z, sldj)loss_meter.update(loss.item(), x.size(0))progress_bar.set_postfix(nll=loss_meter.avg,bpd=bits_per_dim(x, loss_meter.avg))progress_bar.update(x.size(0))# Save checkpointif loss_meter.avg < best_loss:print('Saving...')state = {'net': net.state_dict(),'test_loss': loss_meter.avg,'epoch': epoch,}os.makedirs('ckpts', exist_ok=True)torch.save(state, 'ckpts/best.pth.tar')best_loss = loss_meter.avg# Save samples and dataimages = sample(net, num_samples, device)os.makedirs('samples', exist_ok=True)images_concat = torchvision.utils.make_grid(images, nrow=int(num_samples ** 0.5), padding=2, pad_value=255)torchvision.utils.save_image(images_concat, 'samples/epoch_{}.png'.format(epoch))start_epoch = 0
for epoch in range(start_epoch, start_epoch + num_epochs):train(epoch, net, trainloader, device, optimizer, scheduler,loss_fn, max_grad_norm)test(epoch, net, testloader, device, loss_fn, num_samples)
在 100
个 epoch
后,模型可生成逼真度较高的 32 × 32
彩色图像,样本在多通道细节和整体结构上均有良好效果,下图展示了训练过程中,不同 epoch
生成的图像对比: