当前位置: 首页 > news >正文

PyTorch学习笔记:data.WeightedRandomSampler——数据权重概率采样

PyTorch学习笔记:data.WeightedRandomSampler——数据权重概率采样

torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)

功能:按给定的权重(概率)[p0,p1,…,pn−1][p_0,p_1,\dots,p_{n-1}][p0,p1,,pn1]样本索引[0,1,…,n−1][0,1,\dots,n-1][0,1,,n1]采样

输入:

  • weights:采样权重,权重之和不要求为1,该权重需要与每个样本对应起来,即权重数量等于样本数量
  • num_samples:所采样本的数量,可以小于weights的数量
  • replacement:采样策略,如果为True,则代表使用替换采样策略,即可重复对一个样本进行采样;如果为False,则表示不用替换采样策略,即一个样本最多只能被采一次
  • generator:采样过程中的生成器

代码案例

一般用法

from torch.utils.data import WeightedRandomSamplersampler = WeightedRandomSampler([0.1, 0.6, 1.2, 2.9, 0.8, 0.4, 0.8, 1.0, 0.9], 8)
print([i for i in sampler])

输出

这里采样得到的都是样本的索引

[5, 4, 6, 7, 0, 4, 4, 6]

replacement设为TrueFalse的区别

from torch.utils.data import WeightedRandomSamplersampler_t = WeightedRandomSampler([0.1, 0.6, 1.2, 2.9, 0.8, 0.4, 0.8, 1.0, 0.9], 8, replacement=True)
sampler_f = WeightedRandomSampler([0.1, 0.6, 1.2, 2.9, 0.8, 0.4, 0.8, 1.0, 0.9], 8, replacement=False)
print('sampler_t:', [i for i in sampler_t])
print('sampler_f:', [i for i in sampler_f])

输出

# replacement设为True时,会对同一样本多次采样
sampler_t: [6, 1, 6, 6, 3, 3, 8, 4]
# 否则每个样本只采样一次
sampler_f: [7, 0, 2, 4, 1, 3, 8, 5]

官方文档

torch.utils.data.WeightedRandomSampler:https://pytorch.org/docs/stable/data.html?highlight=sampler#torch.utils.data.WeightedRandomSampler

初步完稿于:2022年2月22日

http://www.lryc.cn/news/2240.html

相关文章:

  • SpringMVC对请求参数的处理
  • 12年老外贸的经验分享
  • 电子电路中的各种接地(接地保护与GND)
  • php实现农历公历日期的相互转换
  • 基于SpringBoot的房屋租赁管理系统的设计与实现
  • 一文带你为PySide6编译MySQL插件驱动
  • 图论算法:树上倍增法解决LCA问题
  • Java线程池中submit() 和 execute()方法有什么区别
  • Vue.extend和VueComponent的关系源码解析
  • 【动态规划】01背包问题(滚动数组 + 手画图解)
  • javaEE 初阶 — 超时重传机制
  • 小米5x wlan无法打开解决
  • 负载均衡之最小活跃数算法
  • JavaScript 评测代码运行速度的几种方法
  • Linux 编译器 gcc/g++
  • 2.Java基础【Java面试第三季】
  • Java高级-多线程
  • mysql高级(事务、存储引擎、索引、锁、sql优化、MVCC)
  • Java后端开发功能模块思路
  • CAPL(vTESTStudio) - DoIP - TCP发送_05
  • 使用IntelliJ IDEA搭建datax-web开发环境
  • [SSD固态硬盘技术 14] GC垃圾回收太重要了
  • lamada表达式、stream、collect整理
  • Nacos 入门微服务项目实战
  • 【c++】类和对象:让你明白“面向一个对象有多重要”:构造函数,析构函数,拷贝构造函数的深入学习
  • 职场IT老手教你3步教你玩转可视化大屏设计,让领导眼前一亮!
  • 【光伏功率预测】基于EMD-PCA-LSTM的光伏功率预测模型(Matlab代码实现)
  • 大数据Kylin(二):Kylin安装使用
  • 我们的微服务中为什么需要网关?
  • 互联网医院源码 线上问诊 智慧医院源码 C#源码