当前位置: 首页 > news >正文

F.binary_cross_entropy、nn.BCELoss、nn.BCEWithLogitsLoss与F.kl_div函数详细解读

提示:有关loss损失函数详细解读,并附源码!!!

文章目录

  • 前言
  • 一、F.binary_cross_entropy()函数解读
    • 1.函数表达
    • 2.函数运用
  • 二、nn.BCELoss()函数解读
    • 1.函数表达
    • 2.函数运用
  • 三、nn.BCEWithLogitsLoss()函数解读
    • 1.函数表达
    • 2.函数运用(logit探索)
    • 3.函数运用(pred探索)
  • 四、F.kl_div()函数解读


前言

最近我在构建蒸馏相关模型,我重温了一下交叉熵相关内容,也使用pytorch相关函数接口调用,我将对F.binary_cross_entropy()、nn.BCELoss()与nn.BCEWithLogitsLoss()函数做一个说明,同时也简单介绍相对熵的蒸馏F.kl_div()函数做一个介绍。

一、F.binary_cross_entropy()函数解读

1.函数表达

F.binary_cross_entropy(input: Tensor,  # 预测输入target: Tensor, # 标签weight: Optional[Tensor] = None, # 权重可选项size_average: Optional[bool] = None,  # 可选项,快被弃用了reduce: Optional[bool] = None,reduction: str = "mean",  # 默认均值或求和等形式
) -> Tensor:

该函数实际是交叉熵运算方式,其中input、target与权重有相同维度(batch,),其中表示可以是任何维度。同时,input为模型预测其每个元素取值范围在[0,1]间。

2.函数运用

假设输入input经过sigmoid或softmax等方式将其值转为[0,1]范围预测,target为one-hot标签(也可是教师的软标签形式),其应用代码如下:

import torch
import torch.nn.functional as F
def binary_cross_entropy():input = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])# s = nn.Sigmoid()# pred = s(input)target = torch.tensor([[0, 1.0, 0], [0, 0, 1.0]])weight = torch.tensor([[0.1, 0.9, 0.1],[0.1, 0.1, 0.9]])output_weight = F.binary_cross_entropy(input, target,weight=weight)  # input取值范围[0,1]output = F.binary_cross_entropy(input, target)  # input取值范围[0,1]print('预测数据:',input)print('标签数据:',target)print('\nbinary_cross_entropy-有权重:{}\t无权重:{}\n'.format(output_weight, output))

结果如下:

预测数据: tensor([[0.5000, 1.0000, 0.8000],[0.2000, 0.4000, 0.6000]])
标签数据: tensor([[0., 1., 0.],[0., 0., 1.]])binary_cross_entropy-有权重:0.12723307311534882	无权重:0.5912299752235413

二、nn.BCELoss()函数解读

1.函数表达

torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction='mean')

参数说明:
weight :用于样本加权的权重张量。如果给定,则必须是一维张量,大小等于输入张量的大小。默认值为 None。
reduction :指定如何计算损失值。可选值为 ‘none’、‘mean’ 或 ‘sum’。默认值为 ‘mean’。

此为类,是对F.binary_cross_entropy()函数的调用,也是交叉熵运算方式,其中input、target与权重有相同维度(batch,),其中表示可以是任何维度。同时,input为模型预测其每个元素取值范围在[0,1]间。

2.函数运用

假设输入input经过sigmoid或softmax等方式将其值转为[0,1]范围预测,target为one-hot标签(也可是教师的软标签形式),其应用代码如下:

import torch
import torch.nn.functional as F
def bceloss():s = nn.Sigmoid()  # 输出是pred = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])# pred = s(pred)  # 一般会经过sigmoid或softmax方式将其预测转为[0,1]范围的值target = torch.tensor([[0, 1.0, 0], [0, 0, 1.0]])# nn.BCELoss输入的pred与target的形状必须相同,实际是交叉熵计算,target没有限制bce = nn.BCELoss(reduction='mean')  # size_average参数将被遗弃,使用reduction决定后续操作,有mean sumb = bce(pred, target)  # pred元素取值范围是[0,1]之间,否则会报错print('预测数据:',pred)print('标签数据:',target)print('\nbceloss:{}\n'.format(b))

结果如下:

预测数据: tensor([[0.5000, 1.0000, 0.8000],[0.2000, 0.4000, 0.6000]])
标签数据: tensor([[0., 1., 0.],[0., 0., 1.]])bceloss:0.5912299752235413

可以看出该函数与上面无权重运行结果一致,实际是对上一个函数进行了类包装,其计算方式和上面函数完全一样。

三、nn.BCEWithLogitsLoss()函数解读

1.函数表达

torch.nn.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)

参数说明:
weight:用于对每个样本的损失值进行加权。默认值为 None。
reduction:指定如何对每个 batch 的损失值进行降维。可选值为 ‘none’、‘mean’ 和 ‘sum’。默认值为 ‘mean’。
pos_weight:用于对正样本的损失值进行加权。可以用于处理样本不平衡的问题。例如,如果正样本比负样本少很多,可以设置 pos_weight 为一个较大的值,以提高正样本的权重。默认值为 None。

2.函数运用(logit探索)

假设输入input经过sigmoid或softmax等方式将其值转为[0,1]范围预测,target为one-hot标签(也可是教师的软标签形式),其应用代码如下:

import torch
import torch.nn.functional as F
def bce_logit_loss():s = nn.Sigmoid()  # 输出是pred = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])target = torch.tensor([[0, 1.0, 0], [0, 0, 1.0]])bce_logit = nn.BCEWithLogitsLoss(reduction='mean')b_logit = bce_logit(pred, target)  # pred元素取值范围是[0,1]之间,否则会报错pred = s(pred)# nn.BCELoss输入的pred与target的形状必须相同,实际是交叉熵计算,target没有限制bce = nn.BCELoss(reduction='mean')  # size_average参数将被遗弃,使用reduction决定后续操作,有mean sumb = bce(pred, target)  # pred元素取值范围是[0,1]之间,否则会报错print('预测数据:', pred)print('标签数据:', target)print('\nbceloss:{}\t bce_with_logit:{} \n'.format(b, b_logit))

结果如下:

预测数据: tensor([[0.6225, 0.7311, 0.6900],[0.5498, 0.5987, 0.6457]])
标签数据: tensor([[0., 1., 0.],[0., 0., 1.]])bceloss:0.7678468823432922	 bce_with_logit:0.7678468823432922

可以看出,nn.BCELoss只需多一个nn.Sigmoid()得到的结果和nn.BCEWithLogitsLoss是一致的,说明该类只是多了一个logit过程。

3.函数运用(pred探索)

import torch
import torch.nn.functional as F
def bce_logit_loss():pred = torch.tensor([[5, 1, 8.0], [2, 4, 6.0]])target = torch.tensor([[0, 1.0, 0], [0, 0, 1.0]])bce_logit = nn.BCEWithLogitsLoss(reduction='mean')b_logit = bce_logit(pred, target)  # pred元素取值范围是[0,1]之间,否则会报错print('预测数据:', pred)print('标签数据:', target)print(' bce_with_logit:{} \n'.format( b_logit))

结果如下:

预测数据: tensor([[5., 1., 8.],[2., 4., 6.]])
标签数据: tensor([[0., 1., 0.],[0., 0., 1.]])bce_with_logit:3.2446444034576416 

可以看出nn.BCEWithLogitsLoss的输入是可以为实数,它先进行sigmoid处理,将其输入变为[0,1]范围,在进行交叉熵运算,然上面nn.BCELoss与F.binary_cross_entropy则不行。

四、F.kl_div()函数解读

该函数为蒸馏模型使用的函数,我直接给出示列,如下:

def kl_func():logits = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])probs = torch.nn.functional.softmax(logits, dim=1)  # 预测学生模型target_probs = torch.tensor([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]])  # 教师模型loss = F.kl_div(torch.log(probs), target_probs, reduction='batchmean')print('模型输出数据:', logits)print('预测数据:',probs)print('标签数据:',target_probs)print('\nkl_loss:{}\n'.format(loss))

输出结果:

模型输出数据: tensor([[0.5000, 1.0000, 0.8000],[0.2000, 0.4000, 0.6000]])
预测数据: tensor([[0.2501, 0.4123, 0.3376],[0.2693, 0.3289, 0.4018]])
标签数据: tensor([[0.3000, 0.4000, 0.3000],[0.1000, 0.5000, 0.4000]])kl_loss:0.057796258479356766

参考文章:点击这里

http://www.lryc.cn/news/231765.html

相关文章:

  • 后端接口性能优化分析
  • 【ceph】ceph集群中使用多路径(Multipath)方法
  • Xshell+Xftp通过代理的方式访问局域网内网服务器
  • 对盒子中的材料进行计数
  • 科技驱动固定资产管理变革:RFID技术的前沿应用
  • Django路由层之有名分组和无名分组、反向解析、路由分发、伪静态的概念、名称空间、虚拟环境、Django1和Django2的区别
  • 【nlp】2.5 人名分类器实战项目(对比RNN、LSTM、GRU模型)
  • 海康Visionmaster-环境配置:MFC 二次开发环境配置方法
  • 利用EXCEL中的VBA对同一文件夹下的多个数据文件进行特定提取
  • FPGA时序约束(七)文献时序约束实验测试
  • 【数据库开发】DataX开发环境的安装部署(Python、Java)
  • Flutter实践一:package组织
  • SpringCloud微服务:Ribbon负载均衡
  • 【教程】大气化学在线耦合模式WRF/Chem
  • GDS 命令的使用 srvctl service TAF application continuity
  • go 语言之 select
  • 23款奔驰GLC260L升级小柏林音响 全新15个扬声器
  • ssh 免密码登录
  • 小程序使用腾讯位置插件获取当前位置
  • 零基础学Python怎么学习?我来告诉你
  • 开源软件 FFmpeg 生成模型使用图片数据集
  • Linux Shell 通配符 / glob 模式
  • 深入了解域名与SSL证书的关系
  • 计算属性与watch的区别,fetch与axios在vue中的异步请求,单文本组件使用,使用vite创建vue项目,组件的使用方法
  • 2023.11.14 hivesql的容器,数组与映射
  • Android Glide照片宫格RecyclerView,点击SharedElement共享元素动画查看大图,Kotlin(1)
  • SELinux零知识学习八、SELinux策略语言之客体类别和许可(2)
  • deepstream-测试发送AMQP
  • LLMs可以遵循简单的规则吗?
  • 如何挑选护眼灯?光照均匀度、色温、眩光这3点!