当前位置: 首页 > news >正文

〖open-mmlab: MMDetection〗解析文件:mmdet/models/losses/cross_entropy_loss.py

目录

  • 深入解析MMDetection中的CrossEntropyLoss及其应用
    • 1. 概述
    • 2. 核心函数
      • 2.1 cross_entropy
        • 2.1.1 函数定义和参数说明
        • 2.1.2 函数体
        • 2.1.3 总结
      • 2.2 binary_cross_entropy
        • 2.2.1 `_expand_onehot_labels`函数
        • 2.2.2 `binary_cross_entropy`函数
        • 2.2.3 总结
      • 2.3 mask_cross_entropy
        • 2.3.1函数定义和参数说明
        • 2.3.2函数体
        • 2.3.3总结
    • 3. CrossEntropyLoss类
      • 3.1 CrossEntropyLoss 类详解
        • 3.1.1 类定义
        • 3.1.2 初始化方法
        • 3.1.3 属性设置
        • 3.1.4 `extra_repr` 方法
        • 3.1.5 前向传播方法
        • 3.1.6 内部逻辑
      • 3.2 示例
    • 4. 应用示例
    • 5. 总结
    • 参考文献

深入解析MMDetection中的CrossEntropyLoss及其应用

在目标检测和分类任务中,交叉熵损失(CrossEntropyLoss)是评估模型预测与真实标签差异的关键指标。MMDetection框架提供了灵活的损失函数实现,以支持不同的训练需求。本文将详细解析CrossEntropyLoss及其在MMDetection中的应用,并探讨其在模型训练中的作用。

1. 概述

CrossEntropyLoss是MMDetection中用于计算交叉熵损失的模块。它支持多种配置选项,包括对数损失的权重、忽略的标签索引等,以适应不同的训练场景。

2. 核心函数

2.1 cross_entropy

cross_entropy函数是计算交叉熵损失的核心函数。它接收模型预测、真实标签和其他可选参数,并返回计算得到的损失值。

cross_entropy函数是用于计算交叉熵损失的函数,通常用于多分类问题。在深度学习中,交叉熵损失是衡量模型预测概率分布与真实标签概率分布差异的常用方法。以下是对这个函数的详细逐行解析。

2.1.1 函数定义和参数说明
def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None, class_weight=None, ignore_index=-100, avg_non_ignore=False):
  • pred: 预测值,形状为 (N, C) 的张量,其中 N 是批次大小,C 是类别数量。
  • label: 真实标签,形状为 (N,) 的张量,每个元素是对应样本的类别索引。
  • weight: 每个样本的损失权重,形状为 (N,) 的张量。
  • reduction: 指定损失计算后的缩减方式,可以是 “none”、“mean” 或 “sum”。
  • avg_factor: 用于平均损失的因子,通常用于在损失中考虑无效(如填充)的样本。
  • class_weight: 每个类别的权重,用于加权损失计算。
  • ignore_index: 忽略的标签索引,对于这个索引的样本不计入损失计算。
  • avg_non_ignore: 是否仅在非忽略目标上平均损失。
2.1.2 函数体
ignore_index = -100 if ignore_index is None else ignore_index
  • 这行代码设置 ignore_index 的默认值为 -100,如果用户没有指定 ignore_index,则使用默认值。
loss = F.cross_entropy(pred,label,weight=class_weight,reduction='none',ignore_index=ignore_index)
  • 使用 PyTorch 的 F.cross_entropy 函数计算每个样本的交叉熵损失。这里设置 reduction='none' 以获取每个样本的损失,而不是直接求平均或求和。
if (avg_factor is None) and avg_non_ignore and reduction == 'mean':avg_factor = label.numel() - (label == ignore_index).sum().item()
  • 如果 avg_factor 未指定,且 avg_non_ignore 为 True,且 reduction 为 “mean”,则计算平均因子。这里,avg_factor 是用来计算损失的平均值时考虑的有效样本数量。
if weight is not None:weight = weight.float()
  • 如果指定了每个样本的损失权重,则将其转换为浮点数。
loss = weight_reduce_loss(loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
  • 使用 weight_reduce_loss 函数应用样本权重,并根据 reduction 参数指定的方法缩减损失。avg_factor 用于在平均损失时考虑有效样本数量。
return loss
  • 返回计算得到的损失。
2.1.3 总结

cross_entropy函数是计算多分类问题交叉熵损失的关键函数。它支持对损失进行加权、忽略特定标签的样本,并根据指定的方法缩减损失。这些特性使得该函数在处理不平衡数据集或需要特殊处理某些样本时非常有用。

2.2 binary_cross_entropy

binary_cross_entropy函数用于计算二元交叉熵损失,特别适用于目标检测中正负样本的分类。

这两个函数是MMDetection中用于处理标签和计算二元交叉熵损失的关键工具。下面将逐行解析这两个函数。

2.2.1 _expand_onehot_labels函数
def _expand_onehot_labels(labels, label_weights, label_channels, ignore_index):"""Expand onehot labels to match the size of prediction."""
  • 这个函数用于将标签扩展为与预测尺寸相匹配的一位有效编码(one-hot encoding)形式。
    bin_labels = labels.new_full((labels.size(0), label_channels), 0)
  • 创建一个形状为(labels.size(0), label_channels)的全零张量,用于存储一位有效编码的标签。
    valid_mask = (labels >= 0) & (labels != ignore_index)
  • 创建一个有效掩码,其中标签值大于或等于0且不等于ignore_index的位置为True。
    inds = torch.nonzero(valid_mask & (labels < label_channels), as_tuple=False)
  • 找到所有有效且小于label_channels的标签索引。
    if inds.numel() > 0:bin_labels[inds, labels[inds]] = 1
  • 如果存在有效的索引,则在bin_labels张量中将这些索引对应的位置设置为1。
    valid_mask = valid_mask.view(-1, 1).expand(labels.size(0),label_channels).float()
  • valid_mask重塑并扩展到与bin_labels相同的形状,并转换为浮点数。
    if label_weights is None:bin_label_weights = valid_maskelse:bin_label_weights = label_weights.view(-1, 1).repeat(1, label_channels)bin_label_weights *= valid_mask
  • 如果没有给定label_weights,则使用valid_mask作为权重;否则,将label_weights扩展并与valid_mask相乘。
    return bin_labels, bin_label_weights, valid_mask
  • 返回扩展后的一位有效编码标签、标签权重和有效掩码。
2.2.2 binary_cross_entropy函数
def binary_cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None, class_weight=None, ignore_index=-100, avg_non_ignore=False):
  • 这个函数用于计算二元交叉熵损失。
    ignore_index = -100 if ignore_index is None else ignore_index
  • 设置ignore_index的默认值为-100。
    if pred.dim() != label.dim():label, weight, valid_mask = _expand_onehot_labels(label, weight, pred.size(-1), ignore_index)else:valid_mask = ((label >= 0) & (label != ignore_index)).float()if weight is not None:weight = weight * valid_maskelse:weight = valid_mask
  • 如果predlabel的维度不同,则调用_expand_onehot_labels函数扩展标签和权重;否则,创建有效掩码并根据需要调整weight
    if (avg_factor is None) and avg_non_ignore and reduction == 'mean':avg_factor = valid_mask.sum().item()
  • 如果需要在非忽略元素上平均损失,则计算平均因子。
    weight = weight.float()loss = F.binary_cross_entropy_with_logits(pred, label.float(), pos_weight=class_weight, reduction='none')
  • 将权重转换为浮点数,然后使用torch.nn.functional.binary_cross_entropy_with_logits计算未缩减的元素级损失。
    loss = weight_reduce_loss(loss, weight, reduction=reduction, avg_factor=avg_factor)
  • 使用weight_reduce_loss函数根据reduction参数指定的方法对损失进行缩减。
    return loss
  • 返回计算得到的损失。
2.2.3 总结

这两个函数是处理标签和计算损失的关键步骤,特别是在处理不平衡数据集或需要忽略某些标签时非常有用。_expand_onehot_labels函数负责将标签转换为适合损失计算的形式,而binary_cross_entropy函数则负责实际的损失计算。通过这些函数,MMDetection能够灵活地处理各种复杂的训练场景。

2.3 mask_cross_entropy

mask_cross_entropy函数用于计算掩码的交叉熵损失,常用于实例分割任务中。

在目标检测和实例分割任务中,mask_cross_entropy函数用于计算预测掩码与真实掩码之间的交叉熵损失。这个函数特别适用于处理每个目标的二进制掩码。下面是对这个函数的逐行解析。

2.3.1函数定义和参数说明
def mask_cross_entropy(pred, target, label, reduction='mean', avg_factor=None, class_weight=None, ignore_index=None, **kwargs):
  • pred: 预测的掩码,形状为 (N, C, *),其中 N 是样本数量,C 是类别数,* 表示任意维度的形状。
  • target: 真实掩码标签,形状为 (N, *),与 pred 的非类别维度相同。
  • label: 每个目标的类别标签,用于从 pred 中选择对应类别的掩码。
  • reduction: 指定损失计算后的缩减方式,可以是 “none”、“mean” 或 “sum”。
  • avg_factor: 用于平均损失的因子,通常用于在损失中考虑无效(如填充)的样本。
  • class_weight: 每个类别的权重,用于加权损失计算。
  • ignore_index: 忽略的标签索引,在此函数中不支持。
2.3.2函数体
assert ignore_index is None, 'BCE loss does not support ignore_index'
  • 这行代码确保 ignore_index 参数为 None,因为二元交叉熵损失(BCE loss)不支持忽略特定索引。
assert reduction == 'mean' and avg_factor is None
  • 这行代码确保 reduction 参数为 “mean” 且 avg_factorNone。这是为了简化实现,避免处理复杂的缩减逻辑。
num_rois = pred.size()[0]
  • 获取预测掩码的数量(即样本数量)。
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
  • 创建一个从0到 num_rois 的整数序列,用于索引 pred
pred_slice = pred[inds, label].squeeze(1)
  • pred 中选择每个样本对应类别的预测掩码。label 包含了每个样本的类别索引,squeeze(1) 用于去除长度为1的维度,使 pred_slice 的形状为 (N, *)
return F.binary_cross_entropy_with_logits(pred_slice, target, weight=class_weight, reduction='mean')[None]
  • 使用 torch.nn.functional.binary_cross_entropy_with_logits 计算二元交叉熵损失。这里 pred_slice 是预测值,target 是真实值,weight 是类别权重(如果有的话)。
  • reduction='mean' 指定损失的平均方式。
  • [None] 用于将输出转换为形状 (1,) 的张量,以符合 MMDetection 中损失函数的输出格式。
2.3.3总结

mask_cross_entropy 函数是一个专门用于计算预测掩码与真实掩码之间交叉熵损失的函数。它通过选择每个样本对应类别的预测掩码,并计算与真实掩码的二元交叉熵损失。这个函数在实例分割任务中非常有用,尤其是在需要对每个目标的掩码进行分类的场景中。

3. CrossEntropyLoss类

CrossEntropyLoss类封装了交叉熵损失的计算逻辑,使其可以作为模型的一个组件被轻松集成。

@MODELS.register_module()
class CrossEntropyLoss(nn.Module):def __init__(self, use_sigmoid=False, use_mask=False, reduction='mean', class_weight=None, ignore_index=None, loss_weight=1.0, avg_non_ignore=False):...

3.1 CrossEntropyLoss 类详解

CrossEntropyLoss 类是用于计算交叉熵损失的一个自定义PyTorch模块。它可以适应多种场景下的分类任务,包括多分类、二分类以及掩码损失。下面我们将通过具体的例子来详细解析该类的实现及其工作原理。

3.1.1 类定义
@MODELS.register_module()
class CrossEntropyLoss(nn.Module):

这里 @MODELS.register_module() 是一个装饰器,通常用于注册自定义模块,使得该模块可以在配置文件中被方便地引用和实例化。

3.1.2 初始化方法
def __init__(self,use_sigmoid=False,use_mask=False,reduction='mean',class_weight=None,ignore_index=None,loss_weight=1.0,avg_non_ignore=False):
  • use_sigmoid: 布尔值,指示输出是否使用sigmoid激活函数,默认为False,意味着使用softmax。
  • use_mask: 布尔值,指示是否使用掩码交叉熵损失,默认为False。
  • reduction: 字符串,指定损失函数的缩减方式,默认为’mean’,即计算平均损失。可选值还有’none’(不缩减)和’sum’(求和)。
  • class_weight: 列表或数组,每个类别的权重,默认为None,即所有类别权重相同。
  • ignore_index: 整数或None,需要忽略的标签索引,默认为None。
  • loss_weight: 浮点数,损失函数的整体权重,默认为1.0。
  • avg_non_ignore: 布尔值,决定是否只在非忽略的目标上平均损失,默认为False。
3.1.3 属性设置

在初始化方法中,设置了多个类属性,并根据传入的参数选择不同的交叉熵损失计算方法:

  • cls_criterion: 根据 use_sigmoiduse_mask 的值来选择具体的损失计算函数。可能的函数有 binary_cross_entropymask_cross_entropycross_entropy
3.1.4 extra_repr 方法
def extra_repr(self):s = f'avg_non_ignore={self.avg_non_ignore}'return s

该方法返回一个额外的表示字符串,通常用于打印类的额外信息。在这个例子中,返回的是 avg_non_ignore 的状态。

3.1.5 前向传播方法
def forward(self,cls_score,label,weight=None,avg_factor=None,reduction_override=None,ignore_index=None,**kwargs):
  • cls_score: 模型的预测输出。
  • label: 真实标签。
  • weight: 样本权重,默认为None。
  • avg_factor: 平均因子,默认为None。
  • reduction_override: 指定的缩减方式,默认为None,表示使用初始化时指定的方式。
  • ignore_index: 需要忽略的标签索引,默认为None,表示使用初始化时指定的值。
  • **kwargs: 其他关键字参数。
3.1.6 内部逻辑
  1. 参数验证

    assert reduction_override in (None, 'none', 'mean', 'sum')
    

    确保 reduction_override 参数的有效性。

  2. 确定缩减方式

    reduction = (reduction_override if reduction_override else self.reduction)
    

    根据 reduction_override 参数或初始化时设置的 reduction 属性确定实际使用的缩减方式。

  3. 确定忽略索引

    if ignore_index is None:ignore_index = self.ignore_index
    

    如果 ignore_index 未被显式指定,则使用初始化时设置的值。

  4. 处理类别权重

    if self.class_weight is not None:class_weight = cls_score.new_tensor(self.class_weight, device=cls_score.device)
    else:class_weight = None
    

    如果提供了类别权重,将其转换为与 cls_score 相同设备上的张量。

  5. 计算损失

    loss_cls = self.loss_weight * self.cls_criterion(cls_score,label,weight,class_weight=class_weight,reduction=reduction,avg_factor=avg_factor,ignore_index=ignore_index,avg_non_ignore=self.avg_non_ignore,**kwargs)
    

    根据选择的交叉熵损失计算函数 cls_criterion 计算损失,并应用损失权重。

  6. 返回损失

    return loss_cls
    

    返回计算得到的损失值。

3.2 示例

假设我们有一个简单的二分类问题,其中 cls_score 是模型的输出,label 是真实的标签,我们想要计算交叉熵损失。

import torch
import torch.nn.functional as F# 假设的模型输出
cls_score = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
# 假设的真实标签
label = torch.tensor([0, 1])# 实例化 CrossEntropyLoss
loss_func = CrossEntropyLoss(use_sigmoid=True)# 计算损失
loss = loss_func(cls_score, label)
print("Loss:", loss.item())

在这个例子中,我们使用了sigmoid激活函数,并假设 cls_score 是两个样本的二分类概率预测。label 表示这两个样本的真实类别分别为0和1。loss_func 实例化了一个使用sigmoid的交叉熵损失函数,最后我们计算并打印了损失值。

通过以上分析,我们可以看到 CrossEntropyLoss 类是如何灵活地适应不同场景下的交叉熵损失计算需求的。

4. 应用示例

以下是如何在实际模型训练中使用CrossEntropyLoss的示例。

# 初始化损失函数
criterion = CrossEntropyLoss(use_sigmoid=True, reduction='mean')# 假设有预测和标签
pred = torch.randn(10, 2)  # 10个样本,2个类别
label = torch.empty(10, dtype=torch.long).random_(2)# 计算损失
loss = criterion(pred, label)

5. 总结

CrossEntropyLoss是MMDetection中实现交叉熵损失的关键组件,它支持多种配置选项,以适应不同的训练需求。通过合理配置,可以有效地优化模型在目标检测和分类任务中的表现。

参考文献

  1. PyTorch官方文档

本文详细介绍了CrossEntropyLoss的实现及其在MMDetection中的应用,希望对目标检测和分类任务的研究者和开发者有所帮助。

http://www.lryc.cn/news/435652.html

相关文章:

  • 【PyTorch单点知识】torch.nn.Embedding模块介绍:理解词向量与实现
  • Jedis 操作 Redis 数据结构全攻略
  • ctf.show靶场ssrf攻略
  • 在 PyTorch 中,如何使用 `pack_padded_sequence` 来提高模型训练的效率?
  • Gossip协议
  • 数据结构————双向链表
  • 55 - I. 二叉树的深度
  • Redis——初识Redis
  • Xshell or Xftp提示“要继续使用此程序,您必须应用最新的更新或使用新版本”
  • table用position: sticky固定多层表头,滑动滚动条border边框透明解决方法
  • 基于飞桨paddle2.6.1+cuda11.7+paddleRS开发版的目标提取-道路数据集训练和预测代码
  • 数学建模笔记—— 整数规划和0-1规划
  • [001-03-007].第26节:分布式锁迭代3->优化基于setnx命令实现的分布式锁-防锁的误删
  • 【Unity踩坑】为什么有Rigidbody的物体运行时位置会变化
  • NGINX开启HTTP3,给web应用提个速
  • 秋招季!别浮躁!
  • Java的时间复杂度和空间复杂度和常见排序
  • Qt 学习第十天:标准对话框 页面布局
  • 体育数据API纳米足球数据API:足球数据接口文档API示例⑩
  • [数据集][目标检测]高铁受电弓检测数据集VOC+YOLO格式1245张2类别
  • Vuex:深入理解所涉及的几个问题
  • vue原理分析(六)研究new Vue()
  • 滑动窗口+动态规划
  • vscode配置django环境并创建django项目
  • WebGL系列教程四(绘制彩色三角形)
  • 通过mxGraph在ARMxy边缘计算网关上实现工业物联网
  • GEE案例:利用sentinel-1数据进行洪水监测分析(直方图统计)
  • QT 联合opencv 易错点
  • 例如/举例的使用方法 ,e.g., 以及etc的使用方法
  • 20240902-VSCode-1.19.1-部署vcpkg-win10-22h2