当前位置: 首页 > news >正文

PyTorch 简单易懂的实现 CosineSimilarity 和 PairwiseDistance - 距离度量的操作

目录

torch.nn子模块Distance Functions解析

nn.CosineSimilarity

功能

主要参数

输入和输出的形状

使用示例

nn.PairwiseDistance

功能

主要参数

输入和输出的形状

使用示例

总结


torch.nn子模块​​​​​​​Distance Functions解析

nn.CosineSimilarity

torch.nn.CosineSimilarity 是 PyTorch 中的一个模块,用于计算两个输入之间的余弦相似度。余弦相似度是一种常用的相似度度量方式,特别适用于高维空间中的向量,如在自然语言处理、推荐系统等领域中用于比较文档或用户偏好的相似性。以下是对 CosineSimilarity 模块的功能、用法和特点的详细说明。

功能

  • 计算余弦相似度:该模块计算两个输入向量在指定维度上的余弦相似度。
  • 多维支持:可以在多维张量上操作,并在指定的维度 dim 上计算相似度。

主要参数

  • dim(int,可选):指定计算相似度的维度。默认值为1。
  • eps(float,可选):为了避免除以零,引入的一个小的数值。默认值为1e-8。

输入和输出的形状

  • 输入:两个输入张量的形状应为 (*1, D, *2),其中 D 是在 dim 维度上的大小。这两个张量在 dim 维度上的大小应该相同,而在其他维度上可以广播。
  • 输出:输出张量的形状为 (*1, *2),不包含 dim 维度。

使用示例

import torch
import torch.nn as nn# 创建输入张量
input1 = torch.randn(100, 128)
input2 = torch.randn(100, 128)# 创建 CosineSimilarity 实例
cos = nn.CosineSimilarity(dim=1, eps=1e-6)# 计算两个输入之间的余弦相似度
output = cos(input1, input2)

在这个示例中,CosineSimilarity 用于计算两个 100x128 维度张量在第一个维度(dim=1)上的余弦相似度。这种方法在比较两组高维数据的相似性时非常有用,如比较不同文档的语义相似度或用户偏好的相似度。

nn.PairwiseDistance

torch.nn.PairwiseDistance 是 PyTorch 中的一个模块,用于计算输入向量对之间的成对距离,或者输入矩阵列之间的成对距离。该模块主要用于计算两组数据之间的距离,例如在聚类、近邻搜索等应用中。接下来,我将详细介绍 PairwiseDistance 模块的功能、用法和特点。

功能

  • 成对距离计算:计算两个输入之间的成对距离,通常使用 p-范数。
  • 适用于多维数据:可以处理高维数据,计算多组数据之间的成对距离。

主要参数

  • p(实数,可选):范数的度数,可以是负数。默认值为2,表示使用欧几里得距离。
  • eps(浮点数,可选):用于避免除零的小数。默认值为1e-6。
  • keepdim(布尔值,可选):确定是否保持向量维度。默认值为 False。

输入和输出的形状

  • 输入:两个输入张量的形状可以是 (N, D)(D),其中 N 是批次维度,D 是向量维度。
  • 输出:基于输入维度的输出形状为 (N)()。如果 keepdim 为 True,则输出形状为 (N,1)(1)

使用示例

import torch
import torch.nn as nn# 创建 PairwiseDistance 实例
pdist = nn.PairwiseDistance(p=2)# 创建两组输入数据
input1 = torch.randn(100, 128)
input2 = torch.randn(100, 128)# 计算成对距离
output = pdist(input1, input2)

 在这个示例中,PairwiseDistance 用于计算两个 100x128 维度张量之间的欧几里得距离(p=2)。这种方法适用于需要比较两组数据之间距离的场景,如在机器学习中的距离度量、近邻搜索或者在计算损失函数时评估预测与实际值之间的距离。

总结

 本篇博客全面探讨了 PyTorch 框架中的两个关键的距离函数模块:nn.CosineSimilaritynn.PairwiseDistancenn.CosineSimilarity 模块专注于计算两个高维数据集之间的余弦相似度,适用于评估文档、用户偏好等在特征空间中的相似性。而 nn.PairwiseDistance 模块提供了一种计算两组数据点之间成对欧几里得距离的有效方式,这在聚类、近邻搜索或预测与实际值之间距离度量的场景中非常有用。这两个模块共同构成了在多种机器学习和数据科学应用中处理和比较数据集的基础工具。

http://www.lryc.cn/news/275891.html

相关文章:

  • app加载不到aar中的so库
  • vue-springboot基于java的实验室安全考试系统
  • mysql+关掉密码过期
  • 实际项目中的环形缓冲区
  • 输出回文数-第11届蓝桥杯选拔赛Python真题精选
  • 内存溢出会导致模块测试正常,植入系统失败
  • 【taro react】 ---- QRCode 二维码生成
  • rk3566 armbian修复usb2.0并挂载U盘
  • 猫头虎博主第9期赠书活动:《YOLO目标检测》计算机AI视觉实战YOLO人工智能目标检测与跟踪图像处理深度学习图像检测书籍
  • python 如何将英语单词翻译成中文
  • Linux_CentOS_7.9_MySQL_5.7配置数据库服务开机自启动之简易记录
  • js实现拖动盒子查看内容 内容拖动
  • [C#]winform利用seetaface6实现C#人脸检测活体检测口罩检测年龄预测性别判断眼睛状态检测
  • c++ execl 执行 重定向
  • uni-app中实现元素拖动
  • Java系列-Class.forName和ClassLoader.loadClass的区别
  • 找不到模块 “path“ 或其相对应的类型声明
  • Linux第17步_安装SSH服务
  • C语言—数据类型
  • 静态网页设计——多彩贵州(HTML+CSS+JavaScript)(dw、sublime Text、webstorm、HBuilder X)
  • unity PDFRender Curved UI3.3
  • 基于深度学习的停车位关键点检测系统(代码+原理)
  • C#,入门教程(09)——运算符的基础知识
  • 企业出海数据合规:GDPR中的个人数据与非个人数据之区分
  • 如何在Ubuntu搭建Emlog博客站点并发布至公网可随时远程访问管理界面——“cpolar内网穿透”
  • 【金猿CIO展】是石科技CIO侯建业:算力产业赋能,促进数字经济建设
  • TypeScript 类
  • Oracle分区表
  • 【leetcode】力扣算法之旋转图像【难度中等】
  • 【Java集合类篇】HashMap的数据结构是怎样的?