当前位置: 首页 > news >正文

采样算法二:去噪扩散隐式模型(DDIM)采样算法详解教程

参考

https://arxiv.org/pdf/2010.02502
在这里插入图片描述


一、背景与动机

去噪扩散隐式模型(DDIM) 是对DDPM的改进,旨在加速采样过程同时保持生成质量。DDPM虽然生成效果优异,但其采样需迭代数百至数千次,效率较低。DDIM通过以下关键创新解决该问题:

  • 非马尔可夫反向过程:打破严格的马尔可夫链假设,允许跳步采样。
  • 确定性生成路径:通过设定参数σ=0,实现确定性采样,减少随机性带来的不确定性。
  • 兼容性:使用与DDPM相同的训练模型,无需重新训练。

二、DDIM与DDPM的核心区别
特性DDPMDDIM
反向过程严格马尔可夫链非马尔可夫,允许跳跃式采样
采样速度慢(需完整迭代所有时间步)快(可跳过中间步,如50步代替1000步)
随机性控制固定方差调度(βₜ)可调参数σₜ(σ=0时为确定性采样)
训练目标需完整训练噪声预测模型直接复用DDPM的预训练模型

三、数学推导与关键公式
1. 前向过程的一致性

DDIM沿用DDPM的前向扩散过程定义,任意时刻( x_t )可表示为:
x t = α ˉ t x 0 + 1 − α ˉ t ϵ , ϵ ∼ N ( 0 , I ) x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon, \quad \epsilon \sim \mathcal{N}(0, \mathbf{I}) xt=αˉt x0+1αˉt ϵ,ϵN(0,I)
其中 α ˉ t = ∏ i = 1 t α i \bar{\alpha}_t = \prod_{i=1}^t \alpha_i αˉt=i=1tαi, α t = 1 − β t \alpha_t = 1 - \beta_t αt=1βt

2. 反向过程的重新参数化

DDIM将反向过程定义为非马尔可夫链,允许从任意时间步( t )直接推断( x_{t-Δ} )(Δ为跳跃步长)。其核心公式为:
x t − Δ = α ˉ t − Δ ( x t − 1 − α ˉ t ϵ θ ( x t , t ) α ˉ t ) ⏟ 预测的  x 0 + 1 − α ˉ t − Δ − σ t 2 ⋅ ϵ θ ( x t , t ) + σ t z x_{t-Δ} = \sqrt{\bar{\alpha}_{t-Δ}} \underbrace{\left( \frac{x_t - \sqrt{1 - \bar{\alpha}_t} \epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}} \right)}_{\text{预测的 } x_0} + \sqrt{1 - \bar{\alpha}_{t-Δ} - \sigma_t^2} \cdot \epsilon_\theta(x_t, t) + \sigma_t z xtΔ=αˉtΔ 预测的 x0 (αˉt xt1αˉt ϵθ(xt,t))+1αˉtΔσt2 ϵθ(xt,t)+σtz

  • 第一项:基于当前 x t x_t xt和预测噪声 ϵ θ \epsilon_\theta ϵθ估计的原始数据 x 0 x_0 x0
  • 第二项:沿预测噪声方向的确定性更新。
  • 第三项:可控的随机噪声项, z ∼ N ( 0 , I ) z \sim \mathcal{N}(0, \mathbf{I}) zN(0,I)
3. 参数σₜ的作用
  • σₜ=0:完全确定性采样(DDIM的标准设定),生成结果唯一。
  • σₜ=√[(1−αₜ₋₁)/(1−αₜ)] · √(1−αₜ/αₜ₋₁):恢复DDPM的采样过程。

四、DDIM采样算法步骤
  1. 输入

    • 预训练噪声预测模型 ϵ θ \epsilon_\theta ϵθ
    • 总时间步 T T T,子序列步数 S S S(如 S = 50 S=50 S=50
    • 方差调度参数 { α t } \{\alpha_t\} {αt}
    • 随机性控制参数 σ t \sigma_t σt
  2. 生成时间步子序列
    选择递减的子序列 { τ 1 , τ 2 , . . . , τ S } \{\tau_1, \tau_2, ..., \tau_S\} {τ1,τ2,...,τS},例如均匀间隔或余弦调度。

  3. 初始化:采样初始噪声 x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0, \mathbf{I}) xTN(0,I)

  4. 迭代去噪(从 τ S \tau_S τS τ 1 \tau_1 τ1):

    • 预测噪声 ϵ = ϵ θ ( x τ s , τ s ) \epsilon = \epsilon_\theta(x_{\tau_s}, \tau_s) ϵ=ϵθ(xτs,τs)
    • 估计原始数据
      x ^ 0 = x τ s − 1 − α ˉ τ s ϵ α ˉ τ s \hat{x}_0 = \frac{x_{\tau_s} - \sqrt{1 - \bar{\alpha}_{\tau_s}} \epsilon}{\sqrt{\bar{\alpha}_{\tau_s}}} x^0=αˉτs xτs1αˉτs ϵ
    • 计算下一步状态
      x τ s − 1 = α ˉ τ s − 1 x ^ 0 + 1 − α ˉ τ s − 1 − σ τ s 2 ⋅ ϵ + σ τ s z x_{\tau_{s-1}} = \sqrt{\bar{\alpha}_{\tau_{s-1}}} \hat{x}_0 + \sqrt{1 - \bar{\alpha}_{\tau_{s-1}} - \sigma_{\tau_s}^2} \cdot \epsilon + \sigma_{\tau_s} z xτs1=αˉτs1 x^0+1αˉτs1στs2 ϵ+στsz
      • σ τ s = 0 \sigma_{\tau_s}=0 στs=0时,最后一项消失,变为确定性更新。
  5. 输出 x 0 x_0 x0为生成的数据。


五、伪代码示例
def ddim_sample(model, T, S, alphas_bar, sigmas):# 生成时间步子序列(如从T到0每隔k步取一次)tau = np.linspace(T, 0, S+1, dtype=int)  # 示例:线性间隔x = torch.randn_like(data)  # x_T ~ N(0, I)for s in range(S, 0, -1):t_current = tau[s]t_prev = tau[s-1]# 预测噪声epsilon = model(x, t_current)# 估计x0x0_hat = (x - np.sqrt(1 - alphas_bar[t_current]) * epsilon) / np.sqrt(alphas_bar[t_current])# 计算系数coeff1 = np.sqrt(alphas_bar[t_prev])coeff2 = np.sqrt(1 - alphas_bar[t_prev] - sigmas[t_current]**2)# 更新xx = coeff1 * x0_hat + coeff2 * epsilon + sigmas[t_current] * torch.randn_like(x)return x

http://www.lryc.cn/news/545005.html

相关文章:

  • 北京大学DeepSeek提示词工程与落地场景(PDF无套路免费下载)
  • Hutool - POI:让 Excel 与 Word 操作变得轻而易举
  • IDEAPyCharm安装ProxyAI(CodeGPT)插件连接DeepSeek-R1教程
  • Iceberg Catalog
  • 2025年2月个人工作生活总结
  • vscode java环境中文乱码的问题
  • Java数据结构第十五期:走进二叉树的奇妙世界(四)
  • 【MySQL】CAST()在MySQL中的用法以及其他常用的数据类型转换函数
  • 使用Truffle、Ganache、MetaMask、Vue+Web3完成的一个简单区块链项目
  • 初出茅庐的小李博客之按键驱动库使用
  • 如何调试Linux内核?
  • ECharts组件封装教程:Vue3中的实践与探索
  • NAT 代理服务 内网穿透
  • CAN硬件协议详解
  • 网络安全等级保护:网络安全等级保护基本技术
  • 信刻光盘安全隔离与信息交换系统让“数据摆渡”安全高效
  • 数据结构课程设计(java实现)---九宫格游戏,也称幻方
  • [思考记录]AI时代下,悄然的改变
  • JAVA笔记【一】
  • [Java基础] 常用注解
  • uvm中的run_test作用
  • brew search报错,xcrun:error:invalid active developer path CommandLineTools
  • C#内置委托(Action)(Func)
  • kubernetes 部署项目
  • 《几何原本》命题I.2
  • 【我的 PWN 学习手札】House of Kiwi
  • nvm的学习
  • haclon固定相机位标定
  • stm32(hal库)学习笔记-时钟系统
  • 【Java项目】基于SpringBoot的财务管理系统