当前位置: 首页 > news >正文

深度学习之扩散模型(Diffusion model)

代码解析:正向扩散过程和加噪演示

  1. 引言
    这段代码实现了一个正向扩散过程和加噪演示的功能。通过生成一个特定形状的数据集,并在每个时间步长上应用正向扩散过程和加噪过程,最终展示了数据点在空间中的演变过程。

  2. 数据集生成
    通过 make_swiss_roll 函数生成一个类似瑞士卷的数据集,数据集具有特定的形状和噪声。在这个示例中,数据集被缩放和裁剪,以便更好地展示正向扩散和加噪的效果。

  3. 超参数设定
    设定了一系列超参数,包括时间步数 num_steps 和用于控制正向扩散过程的 alphas 和 betas。这些超参数决定了正向扩散过程中的权重变化,并影响数据点在空间中的演变轨迹。

  4. 正向扩散过程
    定义了一个函数 q_x,用于执行正向扩散过程。该函数接受初始数据点和时间步长作为输入,并根据预先设定的超参数计算出新的数据点。在每个时间步长上,根据权重 alphas 和 betas,将初始数据点与噪声相结合,生成新的数据点。

  5. 加噪演示
    通过循环迭代,每隔一定的时间步长,在图表中展示了数据点的演变过程。在每个演示步骤中,通过调用 q_x 函数生成新的数据点,并在图表中以散点图的形式展示。这样可以清晰地观察到数据点在空间中的变化,从而更好地理解加噪的效果。

  6. 结论
    这段代码展示了如何使用正向扩散过程和加噪过程来生成和演示数据集的变化。通过调整超参数和观察结果,可以更好地理解数据的分布和特征,为后续的数据分析和建模工作提供参考。

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_swiss_roll  # 导入 make_swiss_roll 函数# 构建我们需要的数据集
s_curve, _ = make_swiss_roll(10**4, noise=0.1)
s_curve = s_curve[:, [0, 2]] / 10.0
dataset = torch.Tensor(s_curve).float()# 确定时间步数
num_steps = 100# 确定alpha、beta超参数的值
betas = torch.linspace(-6, 6, num_steps)
betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5
alphas = 1 - betas
alphas_prod = torch.cumprod(alphas, 0)
alphas_prod_p = torch.cat([torch.tensor([1]).float(), alphas_prod[:-1]], 0)
alphas_bar_sqrt = torch.sqrt(alphas_prod)
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)# 正向扩散过程——根据x_0和noise计算出任意时刻的x_t值
def q_x(x_0, t):noise = torch.randn_like(x_0)alphas_t = alphas_bar_sqrt[t]alphas_1_m_t = one_minus_alphas_bar_sqrt[t]return (alphas_t * x_0 + alphas_1_m_t * noise)# 演示加噪过程,每20步展示一次结果
num_shows = 20
fig, axs = plt.subplots(2, 10, figsize=(28, 3))
for i in range(num_shows):j = i // 10k = i % 10q_i = q_x(dataset, torch.tensor([i * num_steps // num_shows]))axs[j, k].scatter(q_i[:, 0], q_i[:, 1], color='red', edgecolor='white')axs[j, k].set_axis_off()axs[j, k].set_title(f'$q(\\mathbf{{x}}_{{{i * num_steps // num_shows}}})$')
plt.show()
http://www.lryc.cn/news/320663.html

相关文章:

  • Tomcat Session ID---会话保持
  • Session会话绑定
  • win7、win10、win11 系统能安装的.net framework 版本以
  • RediSearch比Es搜索还快的搜索引擎
  • mybatis-plus 的saveBatch性能分析
  • python异常:pythonIOError异常python打开文件异常
  • 电话机器人语音识别用哪家更好精准度更高。
  • 【Unity动画】Unity如何导入序列帧动画(GIF)
  • uniapp APP 上传文件
  • arcgis数据导出到excel
  • 吴恩达深度学习环境本地化构建wsl+docker+tensorflow+cuda
  • R语言:microeco:一个用于微生物群落生态学数据挖掘的R包:第七:trans_network class
  • ubuntu下在vscode中配置matplotlibcpp
  • Vue面试题,背就完事了
  • centos创建并运行一个redis容器 并支持数据持久化
  • nvm安装和使用保姆级教程(详细)
  • 跳绳计数,YOLOV8POSE
  • 阿里云ecs服务器配置反向代理上传图片
  • 免费阅读篇 | 芒果YOLOv8改进110:注意力机制GAM:用于保留信息以增强渠道空间互动
  • GetLastError()返回值及含义
  • k8s admin 用户生成token
  • 【vscode】vscode重命名变量后多了很多空白行
  • 深度学习实战模拟——softmax回归(图像识别并分类)
  • vue实现element-UI中table表格背景颜色设置
  • RabbitMQ学习总结-消息的可靠性
  • 2024蓝桥杯每日一题(BFS)
  • 力扣思路题:最长特殊序列1
  • c# 的ref 和out
  • ONLYOFFICE文档8.0全新发布:私有部署、卓越安全的协同办公解决方案
  • Mar 14 | Datawhale 01~04 打卡 | Leetcode面试下