当前位置: 首页 > news >正文

pytorch 笔记:KLDivLoss

1 介绍

对于具有相同形状的张量 ypred​ 和 ytrue(ypred​ 是输入,ytrue​ 是目标),定义逐点KL散度为:

为了在计算时避免下溢问题,此KLDivLoss期望输入在对数空间中。如果log_target=True,则目标也在对数空间。

2 参数

reduction

reduction= “mean”不返回真正的KL散度值,reduction= “batchmean”才是

log_target指定目标是否在对数空间中

3 举例

import torch
import torch.nn as nninput = torch.tensor([[0.5, -0.5, 0.1], [0.1, -0.2, 0.3]], requires_grad=True)target = torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.5, 0.4]])loss_function = nn.KLDivLoss(reduction='batchmean')
loss = loss_function(input, target)
print(loss)
#tensor(-1.0176, grad_fn=<DivBackward0>)

等价手动形式:

target*(target.log()-input)
'''
tensor([[-0.5997, -0.2219, -0.2403],[-0.2403, -0.2466, -0.4865]], grad_fn=<MulBackward0>)
'''#这里的每个元素计算方式为:
'''
tensor([[-0.5997, -0.2219, -0.2403],[-0.2403, -0.2466, -0.4865]], grad_fn=<MulBackward0>)
'''torch.sum(target*(target.log()-input))/2
#tensor(-1.0176, grad_fn=<DivBackward0>)

http://www.lryc.cn/news/211192.html

相关文章:

  • 父子项目打包发布至私仓库
  • 汽车网络安全--ECU的安全更新
  • NLP之搭建RNN神经网络
  • Android问题笔记四十三:JNI 开发如何快速定位崩溃问题
  • 机器学习 | 决策树算法
  • javascript中各种风骚的代码
  • el-tree横向纵向滚动条
  • STM32G030F6P6 芯片实验 (一)
  • Wpf 使用 Prism 实战开发Day01
  • 6G关键新兴技术- 智能超表面(RIS)技术演进
  • 【redhat9.2】搭建Discuz-X3.5网站
  • 算法篇 : 并查集
  • AM@微积分基本定理@微积分第二基本定理
  • goland常用快捷键
  • CSDN写文章时常见问题及技巧
  • JVM虚拟机详解
  • Go 怎么操作 OSS 阿里云对象存储
  • vue3 Suspense组件
  • NlogPrismWPF
  • 文件上传漏洞(2), 文件上传实战绕过思路, 基础篇
  • 论文阅读 - Hidden messages: mapping nations’ media campaigns
  • [AutoSAR系列] 1.3 AutoSar 架构
  • 迁移学习 - 微调
  • 09 用户态跟踪:如何使用eBPF排查应用程序?
  • 深入浅出排序算法之堆排序
  • Linux 命令(11)—— tcpdump
  • 8.自定义组件布局和详解Context上下文
  • 几个Web自动化测试框架的比较:Cypress、Selenium和Playwright
  • Android Studio中配置aliyun maven库
  • 记录使用阿里 ARoute 遇到的坑