当前位置: 首页 > news >正文

pytorch 笔记:dist 和 cdist

1 dist

1.1 基本使用方法

torch.dist(input, other, p=2)

计算两个Tensor之间的p-范数

1.2 主要参数

input输入张量
other另一个输入张量
p范数

input 和 other的形状需要是可广播的

1.3 举例

import torchx=torch.randn(4)
x
#tensor([ 1.2698, -0.1209,  0.0462, -1.3271])y=torch.randn(4)
y
#tensor([ 0.6590, -0.8689, -1.0083,  0.5733])torch.dist(x,y)
#tensor(2.3783)
z=torch.randn((2,4))
z
'''
tensor([[-0.9118,  1.8019, -0.0162, -0.1969],[ 0.2998, -0.1147,  1.1427, -0.9425]])
'''torch.dist(x,z)
#tensor(3.4683)

2 cdist

2.1 基本使用方法

torch.cdist(x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary')

2.2 主要参数

x1B  × P × M大小的tensor
x2B × R × M 大小的tensor
p范数
compute_mode指定计算欧几里得距离(p=2)时的方法。有三个选项:
  • use_mm_for_euclid_dist_if_necessary:如果 P > 25 或 R > 25,则使用矩阵乘法方法计算欧几里得距离。
  • use_mm_for_euclid_dist:总是使用矩阵乘法方法计算欧几里得距离。
  • donot_use_mm_for_euclid_dist:永不使用矩阵乘法方法计算欧几里得距离。

返回的大小是B × P × R

如果p∈(0,∞),那么这个方法和scipy.spatial.distance.cdist(input,’minkowski’, p=p)是一样的

如果p=0,那么这个方法和scipy.spatial.distance.cdist(input,‘hamming’)是一样的

2.4 使用矩阵乘法速度变慢?

  • 如果数据集较大,或者你有访问高性能计算资源(如GPU),则使用 "use_mm_for_euclid_dist" 可能会更快。
  • 相反,如果数据集较小,或者你的计算资源有限(如只使用CPU),那么 "donot_use_mm_for_euclid_dist" 可能是更好的选择
%%timeit
points1 = torch.rand((5120, 2))
points2 = torch.rand((5120, 2))
torch.cdist(points1, points2, p=2.0, compute_mode="donot_use_mm_for_euclid_dist")
#24 ms ± 4.54 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)%%timeit
points1 = torch.rand((5120, 2))
points2 = torch.rand((5120, 2))
torch.cdist(points1, points2, p=2.0)
#36.7 ms ± 2.68 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

http://www.lryc.cn/news/256570.html

相关文章:

  • Java的List中的各种浅拷贝和深拷贝问题
  • 20231207_最新已测_Centos7.4安装nginx1.24.0_安装详细步骤---Linux工作笔记066
  • 前端知识笔记(二十六)———React如何像Vue一样将css和js写在同一文件
  • Photoshop Circular Text
  • 深入解析Spring Boot中的注解@PathVariable、@RequestParam、@RequestBody的正确使用
  • Qt Location中加载地图对象
  • 4-Docker命令之docker ps
  • 你在地铁上修过bug吗?
  • CPU、MCU、MPU、DSP、FPGA各是什么?有什么区别?
  • SpringBoot之logback 在Linux系统上启动的时候,设置日志按日期分割并设置指定时间自动清除日志
  • OpenHarmony北向-让更广泛的应用开发者更容易参与
  • 数据结构之归并排序及排序总结
  • 仿windows12网盘,私有云盘部署教程,支持多种网盘
  • 深度学习 时间序列回归学习笔记
  • 【postgresql】ERROR: INSERT has more expressions than target columns
  • Android Kotlin语言下的文件存储
  • Verilog 入门(八)(验证)
  • vue3 vue-router 导航守卫 (五)
  • Git命令---查看远程仓库
  • 12.8作业
  • 算法:有效的括号(入栈出栈)
  • vxworks常用的指令归纳
  • 线性回归实战
  • stm32 使用18B20 测试温度
  • 【Delphi】一个函数实现ios,android震动功能 Vibrate(包括3D Touch 中 Peek 震动等)
  • 国产Type-C PD芯片—接口快充取电芯片
  • pytorch学习6-非线性变换(ReLU和sigmoid)
  • 详解Keras3.0 Models API: Whole model saving loading
  • Spring Cloud Gateway 网关的基础使用
  • 小米手机锁屏时间设置为永不休眠_手机不息屏_保持亮屏