当前位置: 首页 > news >正文

CFG的前世今生

文章目录

  • 简介
  • CG到CFG
    • Classifier Guidance
    • Classifier-Free Guidance

简介

DDPM将扩散模型在图片生成任务中做work后,大量研究人员开始对其进行迭代。虽然DDPM论文证明了扩散模型在图片生成任务中的潜力,但是其整体性能,特别是“有条件生成”,相较于当时的GAN系列模型还是存在差距,直到Openai的Diffusion Models Beat GANs on Image Synthesis这篇论文出现,扩散模型在有条件图片生成任务上超过了GANs,而CLASSIFIER-FREE DIFFUSION GUIDANCE这篇论文对上篇论文中的核心思想进行优化,提高模型的性能和计算效率,该方法就是目前在扩散模型生成领域广泛使用的CFG。

CG到CFG

Classifier Guidance

Diffusion Models Beat GANs on Image Synthesis论文中主要从两方面提高扩散模型性能,一方面是通过一系列消融实验筛选出更优的模型架构,提升扩散模型的无条件图片生成能力,此方面不是文本的重点;另一方面就是采用分类器引导(Classifier Guidance, CG),利用分类器梯度指导生成过程,大幅度提高了扩散模型在有条件图片生成任务中的多样性和保真性,使得扩散模型整体性能超过GANs。

CG的主要思想是使用一个分类器的梯度去指导预训练的扩散模型进行有条件生成。在图片生成任务中,就是基于带噪声信息的图片 x t x_t xt训练一个分类器 p ϕ ( y ∣ x t , t ) p_{\phi}(y|x_t,t) pϕ(yxt,t),然后使用该分类器的梯度 ∇ x t log ⁡ p ϕ ( y ∣ x t , t ) \nabla_{x_t} \log p_{\phi}(y|x_t,t) xtlogpϕ(yxt,t)引导扩散模型采样过程朝向任意类别标签 y y y。主要注意的是,分类器不仅要在原始干净的图像数据上训练,还需要在与扩散过程中噪声水平一致的加噪图片数据上训练,因为分类器会在扩散模型所有采样时间步上应用,其需要具备从含噪输入 x t x_t xt中识别列别 y y y的能力。

根据贝叶斯定律,条件概率 p ( x t ∣ y ) = p ( x t , y ) p ( y ) = p ( y ∣ x t ) p ( x t ) p ( y ) p(x_t|y)=\frac{p(x_t,y)}{p(y)}=\frac{p(y|x_t)p(x_t)}{p(y)} p(xty)=p(y)p(xt,y)=p(y)p(yxt)p(xt),两边取对数后求关于 x t x_t xt的梯度(对应于训练过程中的反向传播)得到 ∇ x t log ⁡ p ( x t ∣ y ) = ∇ x t log ⁡ p ( y ∣ x t ) + ∇ x t log ⁡ p ( x t ) \nabla_{x_t} \log p(x_t|y)= \nabla_{x_t} \log p(y|x_t)+\nabla_{x_t} \log p(x_t) xtlogp(xty)=xtlogp(yxt)+xtlogp(xt) p ( y ) p(y) p(y)这一样忽略是因为 y y y表示已知的类别等条件信息,与 x t x_t xt无关,倒数为0。其中 ∇ x t log ⁡ p ( x t ) \nabla_{x_t} \log p(x_t) xtlogp(xt)是扩散模型的边际分布梯度,表示模型学习的数据分布,而 ∇ x t log ⁡ p ( y ∣ x t ) \nabla_{x_t} \log p(y|x_t) xtlogp(yxt)是分类器梯度,表示当输入为含噪样本 x t x_t xt时,使 x t x_t xt更接近类别 y y y的方向(即让分类器对 y y y的预测概率增加的方向)。直观感受上,分类器训练过程中学习了类别 y y y的特征表示,对含噪声样本 的梯度 ∇ x t log ⁡ p ( y ∣ x t ) \nabla_{x_t} \log p(y|x_t) xtlogp(yxt)相当于找到了“如何调整 x t x_t xt使其更符合 y y y的特征”的方向。综上,使用分类器引导可以强制样本符合目标类别的特征,提升生图图像与类别标签的一致性,但是会一定程度地价格低原扩散模型的随机性,即降低模型多样性。

如果对常规扩散模型(如DDPM)了解的读者应该都直达扩散模型就是在学习 p θ ( x t ∣ x t + 1 ) p_{\theta}(x_t|x_{t+1}) pθ(xtxt+1),即从当前步预测前一步样本,此过程被限制为符合高斯分布,即 p θ ( x t ∣ x t + 1 ) = N ( μ , Σ ) p_{\theta}(x_t|x_{t+1})=\mathcal{N}(\mu,\Sigma) pθ(xtxt+1)=N(μ,Σ),对其求导有以下推导过程:
N ( μ , Σ ) = 1 ( 2 π ) d det ⁡ ( Σ ) exp ⁡ [ − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) ] 正态分布的PDF表达式 log ⁡ p θ ( x t ∣ x t + 1 ) = log ⁡ [ 1 ( 2 π ) d det ⁡ ( Σ ) exp ⁡ ( − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) ) ] = log ⁡ ( 1 ( 2 π ) d det ⁡ ( Σ ) ) + log ⁡ ( exp ⁡ ( − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) ) ) = − 1 2 log ⁡ ( ( 2 π ) d det ⁡ ( Σ ) ) − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) = − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) + C \begin{align*} \mathcal{N}(\mu, \Sigma) &= \frac{1}{\sqrt{(2\pi)^d \det(\Sigma)}} \exp\left[ -\frac{1}{2}(x_t - \mu)^T \Sigma^{-1}(x_t - \mu) \right] \quad \text{正态分布的PDF表达式} \\ \log p_{\theta}(x_t | x_{t+1}) &= \log\left[ \frac{1}{\sqrt{(2\pi)^d \det(\Sigma)}} \exp\left( -\frac{1}{2}(x_t - \mu)^T \Sigma^{-1}(x_t - \mu) \right) \right] \\ &= \log\left( \frac{1}{\sqrt{(2\pi)^d \det(\Sigma)}} \right) + \log\left( \exp\left( -\frac{1}{2}(x_t - \mu)^T \Sigma^{-1}(x_t - \mu) \right) \right) \\ &= -\frac{1}{2} \log\left( (2\pi)^d \det(\Sigma) \right) - \frac{1}{2}(x_t - \mu)^T \Sigma^{-1}(x_t - \mu) \\ &= -\frac{1}{2}(x_t - \mu)^T \Sigma^{-1}(x_t - \mu) + C \quad \tag1 \end{align*} N(μ,Σ)logpθ(xtxt+1)=(2π)ddet(Σ) 1exp[21(xtμ)TΣ1(xtμ)]正态分布的PDF表达式=log[(2π)ddet(Σ) 1exp(21(xtμ)TΣ1(xtμ))]=log((2π)ddet(Σ) 1)+log(exp(21(xtμ)TΣ1(xtμ)))=21log((2π)ddet(Σ))21(xtμ)TΣ1(xtμ)=21(xtμ)TΣ1(xtμ)+C(1)

CG论文中证明可以将条件扩散进行以下转换,其中 Z Z Z是一个归一化常数:
p θ , ϕ ( x t ∣ x t + 1 , y ) = Z p θ ( x t ∣ x t + 1 ) p ϕ ( y ∣ x t ) (2) p_{\theta,\phi}(x_t|x_{t+1},y)=Z p_{\theta}(x_t|x_{t+1})p_{\phi}(y|x_t) \tag2 pθ,ϕ(xtxt+1,y)=Zpθ(xtxt+1)pϕ(yxt)(2)

CG论文中假设与 Σ − 1 \Sigma^{-1} Σ1相比, log ⁡ p ϕ ( y ∣ x t ) \log p_{\phi}(y|x_t) logpϕ(yxt)的曲率较低,可以在 x t = μ x_t=\mu xt=μ附近使用泰勒展开来近似 log ⁡ p ϕ ( y ∣ x t ) \log p_{\phi}(y|x_t) logpϕ(yxt),即
log ⁡ p ϕ ( y ∣ x t ) ≈ log ⁡ p ϕ ( y ∣ x t ) ∣ x t = μ + ( x t − μ ) ∇ x t log ⁡ p ϕ ( y ∣ x t ) ∣ x t = μ = ( x t − μ ) g + C 1 \begin{align*} \log p_{\phi}(y|x_t) & \approx \log p_{\phi}(y|x_t)|_{x_t=\mu} + (x_t-\mu)\nabla_{x_t} \log p_{\phi}(y|x_t)|_{x_t=\mu} \\ & = (x_t-\mu)g + C_1 \tag3 \end{align*} logpϕ(yxt)logpϕ(yxt)xt=μ+(xtμ)xtlogpϕ(yxt)xt=μ=(xtμ)g+C1(3)
其中 g = ∇ x t log ⁡ p ϕ ( y ∣ x t ) ∣ x t = μ g=\nabla_{x_t} \log p_{\phi}(y|x_t)|_{x_t=\mu} g=xtlogpϕ(yxt)xt=μ C 1 C_1 C1是一个常数,故可进一步推导:
log ⁡ ( p θ ( x t ∣ x t + 1 ) p ϕ ( y ∣ x t ) ) ≈ log ⁡ p θ ( x t ∣ x t + 1 ) + log ⁡ p ϕ ( y ∣ x t ) = − 1 2 ( x t − μ ) T Σ − 1 ( x t − μ ) + ( x t − μ ) g + C 2 = − 1 2 ( x t − μ − Σ g ) T Σ − 1 ( x t − μ − Σ g ) + 1 2 g T Σ g + C 2 = − 1 2 ( x t − μ − Σ g ) T Σ − 1 ( x t − μ − Σ g ) + C 3 = log ⁡ p ( z ) + C 4 , z ∼ N ( μ + Σ g , Σ ) = log ⁡ e C 4 p ( z ) \begin{align*} \log \left(p_{\theta}(x_t | x_{t+1}) p_{\phi}(y | x_t)\right) &\approx \log p_{\theta}(x_t | x_{t+1})+\log p_{\phi}(y|x_t)\\ &= -\frac{1}{2}(x_t - \mu)^T \Sigma^{-1}(x_t - \mu) + (x_t - \mu)g + C_2 \\ &= -\frac{1}{2}(x_t - \mu - \Sigma g)^T \Sigma^{-1}(x_t - \mu - \Sigma g) + \frac{1}{2}g^T \Sigma g + C_2 \\ &= -\frac{1}{2}(x_t - \mu - \Sigma g)^T \Sigma^{-1}(x_t - \mu - \Sigma g) + C_3 \\ &= \log p(z) + C_4, \quad z \sim \mathcal{N}(\mu + \Sigma g, \Sigma) \\ &= \log e^{C_4}p(z) \tag4 \end{align*} log(pθ(xtxt+1)pϕ(yxt))logpθ(xtxt+1)+logpϕ(yxt)=21(xtμ)TΣ1(xtμ)+(xtμ)g+C2=21(xtμΣg)TΣ1(xtμΣg)+21gTΣg+C2=21(xtμΣg)TΣ1(xtμΣg)+C3=logp(z)+C4,zN(μ+Σg,Σ)=logeC4p(z)(4)
其中的常数项 C 4 C_4 C4与公式(2)中的归一化系数 Z Z Z相对应,在求导过程中可以忽略。观察公式(4)可知,条件概率可以通过一个无条件的相似正态分布来近似,只是均值偏差了 Σ g \Sigma g Σg;以下图片中显示了具体的伪代码,其中为梯度引入了一个缩放因子 s s s来控制梯度强度。

在这里插入图片描述

图1

Classifier-Free Guidance

CG需要额外单独训练一个分类模型对扩散模型进行引导,实现比较难。CFG(无分类器引导)通过联合训练条件扩散模型和无条件扩散模型,将条件预测和无条件预测相结合,实现样本质量和多样性之间的平衡,效果与使用CG效果类似。

上一小节中的推导架构属于类DDPM形式,本小节的推导架构属于类SDE形式,即分数匹配类模型。CFG没有训练单独的分类器,而是选择训练一个通过分数估计器 ϵ θ ( z λ ) \epsilon_{\theta}(z_{\lambda}) ϵθ(zλ)参数化的无条件去噪扩散模型 p θ ( z ) p_{\theta}(z) pθ(z),以及一个通过分数估计器 ϵ θ ( z λ , c ) \epsilon_{\theta}(z_{\lambda},c) ϵθ(zλ,c)参数化的条件去噪扩散模型 p θ ( z ∣ c ) p_{\theta}(z|c) pθ(zc);使用一个神经网络实现这两个模型的参数化,实现方法是以一定的概率随机将类别信息 c c c替换为空字符 ∅ \emptyset ,即 ϵ θ ( z λ ) = ϵ θ ( z λ , c = ∅ ) \epsilon_{\theta}(z_{\lambda})=\epsilon_{\theta}(z_{\lambda},c=\emptyset) ϵθ(zλ)=ϵθ(zλ,c=)。可以单独训练两个模型,但是联合训练实现极其简单,不会使训练过程复杂,也不会增加训练参数。训练后,基于条件分数估计和无条件分数估计的线性组合进行采样即可:
ϵ ~ θ ( z λ , c ) = ( 1 + w ) ϵ θ ( z λ , c ) − w ϵ θ ( z λ ) = ϵ θ ( z λ , c ) + w ( ϵ θ ( z λ , c ) − ϵ θ ( z λ ) ) \begin{align*} \tilde{\epsilon}_{\theta}\left(z_{\lambda}, c\right) &= (1 + w)\epsilon_{\theta}\left(z_{\lambda}, c\right) - w\epsilon_{\theta}\left(z_{\lambda}\right) \tag5 \\ &= \epsilon_{\theta}\left(z_{\lambda}, c\right)+w(\epsilon_{\theta}\left(z_{\lambda}, c\right)-\epsilon_{\theta}\left(z_{\lambda}\right)) \tag6 \end{align*} ϵ~θ(zλ,c)=(1+w)ϵθ(zλ,c)wϵθ(zλ)=ϵθ(zλ,c)+w(ϵθ(zλ,c)ϵθ(zλ))(5)(6)
在CG小结中使在无条件的基础上添加 s s s加权的分类器梯度来引导扩散模型。公式(6)本质使用了相同的思想, ϵ θ ( z λ , c ) − ϵ θ ( z λ ) \epsilon_{\theta}\left(z_{\lambda}, c\right)-\epsilon_{\theta}\left(z_{\lambda}\right) ϵθ(zλ,c)ϵθ(zλ)很直观的表示了类别条件 c c c的方向, w w w则与 s s s等价,还是通过添加一个指向条件方向的量来引导扩散模型推理。CFG在训练和推理过程存在不同,两个阶段的伪代码如下所示。
在这里插入图片描述

图2 CFG联合训练

在这里插入图片描述

图3 CFG推理

CFG论文消融实验表明,随着 w ∈ { 0 , 0.1 , 0.2 , . . . , 4 } w \in \{0,0.1,0.2,...,4\} w{0,0.1,0.2,...,4}增加,FID单调下降,IS单调上升,即表明模型保真性和多样性均越来越好。且扩散模型训练时仅需要将相对较小的模型容量(训练过程中条件置空的概率,如0.1或0.2)用于无条件生成,即可实现对样本质量有效的无分类器引导。CFG的缺点是推理速度更低,因为其推理过程中要针对条件和无条件均进行前向传播。

http://www.lryc.cn/news/572088.html

相关文章:

  • Docker 日志
  • 技术文章大纲:SpringBoot自动化部署实战
  • 《状压DP(01矩阵约束问题)》基础概念
  • 计算机网络:(五)信道复用技术,数字传输系统,宽带接入技术
  • 03 面试官考察与 CAP 有关的分布式理论
  • 开源ChatBI :深入解密 Spring AI Alibaba 的中文NL2SQL智能引擎
  • 基于RAGFlow构建Text2SQL的实战教程
  • 密室出逃消消乐小游戏微信流量主小程序开源
  • 如何将文件从安卓设备传输到电脑?
  • XMOS基于边缘AI+DSP+MCU+I/O智算芯片的音频解决方案矩阵引领行业创新潮流
  • 吴恩达机器学习笔记:正则化2
  • 从Excel到知识图谱再到数据分析:数据驱动智能体构建指南
  • SCRM软件数据分析功能使用指南:从数据挖掘到商业决策
  • Spark 技术与实战学习心得:从入门到实践的深度探索
  • CppCon 2017 学习:Effective Qt: 2017 Edition
  • 锂电池保护板测试仪:守护电池安全的幕后保障
  • 【css】设置了margin-top为负数,div被img覆盖的解决方法
  • django调用 paramiko powershell 获取cpu 个数
  • IPv4编址及IPv4路由基础
  • Pinia + Vue Router 权限控制(终极完整版)
  • 无监督学习中的特征选择与检测(FSD)在医疗动线流程优化中的应用
  • 2025-05-05-80x86汇编语言环境配置
  • 使用随机森林实现目标检测
  • AI时代SEO关键词革新
  • 医疗低功耗智能AI网络搜索优化策略
  • 49-Oracle init.ora-PFILE-SPFILE-启动参数转换实操
  • 129. 求根节点到叶节点数字之和 --- DFS +回溯(js)
  • 详解鸿蒙Next仓颉开发语言中的全屏模式
  • 【hadoop】搭建考试环境(单机)
  • LVS+Keepalived+nginx