当前位置: 首页 > news >正文

深入学习 torch.distributions

0. 引言

前几天分几篇博文精细地讲述了《von Mises-Fisher 分布》, 以及相应的 PyTorch 实现《von Mises-Fisher Distribution (代码解析)》, 其中以 Uniform 分布为例简要介绍了 torch.distributions 包的用法. 本以为已经可以了, 但这两天看到论文 The Power Spherical distribution 的代码, 又被其实现分布的方式所吸引.

Power Spherical 分布与 von Mises Fisher 分布类似, 只不过将后者概率密度函数中的指数函数换成了多项式函数: f p ( x ; μ , κ ) ∝ e x p ( κ μ ⊺ x ) ⇓ f p ( x ; μ , κ ) ∝ ( 1 + μ ⊺ x ) κ \begin{aligned} f_p(\bm{x}; \bm{\mu}, \kappa) &\propto exp(\kappa \bm{\mu}^\intercal \bm{x}) \\ &\Downarrow\\ f_p(\bm{x}; \bm{\mu}, \kappa) &\propto (1+\bm{\mu}^\intercal \bm{x})^\kappa \\ \end{aligned} fp(x;μ,κ)fp(x;μ,κ)exp(κμx)(1+μx)κ 采样框架基本一致, 且这么做可以使边缘 t t t 的线性变换 t + 1 2 ∼ B e t a ( p − 1 2 + κ , p − 1 2 ) \frac{t+1}{2} \sim Beta(\frac{p-1}{2}+\kappa, \frac{p-1}{2}) 2t+1Beta(2p1+κ,2p1), 从而避免了接受-拒绝采样过程.

当然, 按照之前的 VonMisesFisher 的写法, 这个 t 的采样大概是这样:

z = beta.sample(sample_shape)
t = 2 * z - 1

但现在我遇到了这种写法:

class MarginalTDistribution(tds.TransformedDistribution):arg_constraints = {'dim': constraints.positive_integer,'scale': constraints.positive,}has_rsample = Truedef __init__(self, dim, scale, validate_args=None):self.dim = dimself.scale = scalesuper().__init__(tds.Beta(  # 用 Beta 分布转换, z 服从 Beta(α+κ,β)(dim - 1) / 2 + scale, (dim - 1) / 2, validate_args=validate_args),transforms=tds.AffineTransform(loc=-1, scale=2),  # t=2z-1 是想要的边缘分布随机数)

然后就可以进行对 t t t 的采样了.

上述代码的解构图; 浅蓝色代表抽象基类, 绿色代表实类; 虚线代表继承, 实线代表参数输入

我们可以看到其基本架构, 本文将详细解析其内部的具体细节.

http://www.lryc.cn/news/352415.html

相关文章:

  • Java中的判断校验非空问题
  • webman使用summernote富文本编辑器
  • jQuery里添加事件 (代码)
  • Java数组的使用
  • 如何参与github开源项目并提交PR
  • 拼多多携手中国农业大学,投建陕西佛坪山茱萸科技小院
  • 技术前沿 |【自回归视觉模型ImageGPT】
  • Manjaro linux install RedisGUI (RedisInsight)亲测2024-5-25
  • debian/control文件中常见字段的介绍
  • c++题目_农场和奶牛
  • DDD领域设计在“图生代码”中的应用实践
  • LabVIEW舱段测控系统开发
  • [leetcode]第 n个丑数
  • STM32-电灯,仿真
  • 《SpringBoot》系列文章目录
  • 牛客小白月赛94VP
  • php 亚马逊AWS-S3对象存储上传文件
  • electron-01 基础及NPM相关配置
  • Foxit PDF Editor Pro福昕PDF编辑器Pro:重塑您的文档编辑体验
  • VUE 页面生命周期基本知识点
  • windows查看mysql的版本(三种方法)
  • Redis批量删除指定前缀的key
  • 机器学习实验------Adaboost算法
  • 点云处理中阶 Octree模块
  • Nginx实现负载均衡与故障检查自动切换
  • 2024年学浪视频怎么下载到手机相册
  • 【北京市政府网_注册安全分析报告】
  • 工作中的冲突,职场人士应如何化解
  • 企业级大数据平台建设方案
  • HTML语义化标签:为何它们如此重要?