当前位置: 首页 > news >正文

混淆矩阵、准确率、查准率、查全率、DSC、IoU、敏感度的计算

1.背景介绍

在训练的模型的时候,需要评价模型的好坏,就涉及到混淆矩阵、准确率、查准率、查全率、DSC、IoU、敏感度的计算。

2、混淆矩阵的概念

所谓的混淆矩阵如下表所示:

TP:真正类,真的正例被预测为正例

FN:假负类,样本为正例,被预测为负类

FP:假正类 ,原本实际为负,但是被预测为正例

TN:真负类,真的负样本被预测为负类。

从混淆矩阵当中,可以得到更高级的分类指标:Accuracy(准确率),Precision(查准率),Recall(查全率),Specificity(特异性),Sensitivity(灵敏度)。

3. 常用的分类指标

3.1 Accuracy(准确率)

不管是哪个类别,只要预测正确,其数量都放在分子上,而分母是全部数据量。常用于表示模型的精度,当数据类别不平衡时,不能用于模型的评价。

Accuracy=\frac{TP+TN}{TP+FN+FN+TN}

3.2 Precision(查准率)

即所有预测为正的样本中,预测正确的样本的所占的比重。

Precision = \frac{TP}{TP+FP}

3.3  Recall(查全率)

真实的为正的样本,被正确检测出来的比重。

Recall=\frac{TP}{TP+FN}

3.4 Specificity(特异性)

特异性指标,也称 负正类率(False Positive Rate, FPR),计算的是模型错识别为正类的负类样本占所有负类样本的比例,一般越低越好。

FPR = \frac{FP}{TN+FP}

3.5 DSC(Dice coefficient)

Dice系数,是一种相似性度量,度量二进制图像分割的准确性。

如图所示红色的框的区域时Groudtruth,而蓝色的框为预测值Prediction。

DSC=\frac{2\left | G\sqcap P \right |}{\left | p \right |+\left | G \right |}

3.6 IoU(交并比)

IoU=\frac{p\sqcap G}{p\bigsqcup G}

3.7 Sensitivity(灵敏度)

反应的时预测正确的区域在Groundtruth中所占的比重。

SEN=\frac{\left | p \left | \sqcap \right |g\right | }{\left | G \right | }

4. 计算程序

ConfusionMatrix 这个类可以直接计算出混淆矩阵

from collections import defaultdict, deque
import datetime
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import errno
import osclass SmoothedValue(object):"""Track a series of values and provide access to smoothed values over awindow or the global series average."""def __init__(self, window_size=20, fmt=None):if fmt is None:fmt = "{value:.4f} ({global_avg:.4f})"self.deque = deque(maxlen=window_size)self.total = 0.0self.count = 0self.fmt = fmtdef update(self, value, n=1):self.deque.append(value)self.count += nself.total += value * ndef synchronize_between_processes(self):"""Warning: does not synchronize the deque!"""if not is_dist_avail_and_initialized():returnt = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')dist.barrier()dist.all_reduce(t)t = t.tolist()self.count = int(t[0])self.total = t[1]@propertydef median(self):d = torch.tensor(list(self.deque))return d.median().item()@propertydef avg(self):d = torch.tensor(list(self.deque), dtype=torch.float32)return d.mean().item()@propertydef global_avg(self):return self.total / self.count@propertydef max(self):return max(self.deque)@propertydef value(self):return self.deque[-1]def __str__(self):return self.fmt.format(median=self.median,avg=self.avg,global_avg=self.global_avg,max=self.max,value=self.value)class ConfusionMatrix(object):def __init__(self, num_classes):self.num_classes = num_classesself.mat = Nonedef update(self, a, b):n = self.num_classesif self.mat is None:# 创建混淆矩阵self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)with torch.no_grad():# 寻找GT中为目标的像素索引k = (a >= 0) & (a < n)# 统计像素真实类别a[k]被预测成类别b[k]的个数(这里的做法很巧妙)inds = n * a[k].to(torch.int64) + b[k]self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)def reset(self):if self.mat is not None:self.mat.zero_()def compute(self):h = self.mat.float()# 计算全局预测准确率(混淆矩阵的对角线为预测正确的个数)acc_global = torch.diag(h).sum() / h.sum()# 计算每个类别的准确率acc = torch.diag(h) / h.sum(1)# 计算每个类别预测与真实目标的iouiu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))return acc_global, acc, iudef reduce_from_all_processes(self):if not torch.distributed.is_available():returnif not torch.distributed.is_initialized():returntorch.distributed.barrier()torch.distributed.all_reduce(self.mat)def __str__(self):acc_global, acc, iu = self.compute()return ('global correct: {:.1f}\n''average row correct: {}\n''IoU: {}\n''mean IoU: {:.1f}').format(acc_global.item() * 100,['{:.1f}'.format(i) for i in (acc * 100).tolist()],['{:.1f}'.format(i) for i in (iu * 100).tolist()],iu.mean().item() * 100)class DiceCoefficient(object):def __init__(self, num_classes: int = 2, ignore_index: int = -100):self.cumulative_dice = Noneself.num_classes = num_classesself.ignore_index = ignore_indexself.count = Nonedef update(self, pred, target):if self.cumulative_dice is None:self.cumulative_dice = torch.zeros(1, dtype=pred.dtype, device=pred.device)if self.count is None:self.count = torch.zeros(1, dtype=pred.dtype, device=pred.device)# compute the Dice score, ignoring backgroundpred = F.one_hot(pred.argmax(dim=1), self.num_classes).permute(0, 3, 1, 2).float()dice_target = build_target(target, self.num_classes, self.ignore_index)self.cumulative_dice += multiclass_dice_coeff(pred[:, 1:], dice_target[:, 1:], ignore_index=self.ignore_index)self.count += 1@propertydef value(self):if self.count == 0:return 0else:return self.cumulative_dice / self.countdef reset(self):if self.cumulative_dice is not None:self.cumulative_dice.zero_()if self.count is not None:self.count.zeros_()def reduce_from_all_processes(self):if not torch.distributed.is_available():returnif not torch.distributed.is_initialized():returntorch.distributed.barrier()torch.distributed.all_reduce(self.cumulative_dice)torch.distributed.all_reduce(self.count)

分类指标的计算

import torch# SR : Segmentation Result
# GT : Ground Truthdef get_accuracy(SR,GT,threshold=0.5):SR = SR > thresholdGT = GT == torch.max(GT)corr = torch.sum(SR==GT)tensor_size = SR.size(0)*SR.size(1)*SR.size(2)*SR.size(3)acc = float(corr)/float(tensor_size)return accdef get_sensitivity(SR,GT,threshold=0.5):# Sensitivity == RecallSR = SR > thresholdGT = GT == torch.max(GT)# TP : True Positive# FN : False NegativeTP = ((SR==1)+(GT==1))==2FN = ((SR==0)+(GT==1))==2SE = float(torch.sum(TP))/(float(torch.sum(TP+FN)) + 1e-6)     return SEdef get_specificity(SR,GT,threshold=0.5):SR = SR > thresholdGT = GT == torch.max(GT)# TN : True Negative# FP : False PositiveTN = ((SR==0)+(GT==0))==2FP = ((SR==1)+(GT==0))==2SP = float(torch.sum(TN))/(float(torch.sum(TN+FP)) + 1e-6)return SPdef get_precision(SR,GT,threshold=0.5):SR = SR > thresholdGT = GT == torch.max(GT)# TP : True Positive# FP : False PositiveTP = ((SR==1)+(GT==1))==2FP = ((SR==1)+(GT==0))==2PC = float(torch.sum(TP))/(float(torch.sum(TP+FP)) + 1e-6)return PCdef get_F1(SR,GT,threshold=0.5):# Sensitivity == RecallSE = get_sensitivity(SR,GT,threshold=threshold)PC = get_precision(SR,GT,threshold=threshold)F1 = 2*SE*PC/(SE+PC + 1e-6)return F1def get_JS(SR,GT,threshold=0.5):# JS : Jaccard similaritySR = SR > thresholdGT = GT == torch.max(GT)Inter = torch.sum((SR+GT)==2)Union = torch.sum((SR+GT)>=1)JS = float(Inter)/(float(Union) + 1e-6)return JSdef get_DC(SR,GT,threshold=0.5):# DC : Dice CoefficientSR = SR > thresholdGT = GT == torch.max(GT)Inter = torch.sum((SR+GT)==2)DC = float(2*Inter)/(float(torch.sum(SR)+torch.sum(GT)) + 1e-6)return DC

参考文献:

混淆矩阵的概念-CSDN博客

http://www.lryc.cn/news/288983.html

相关文章:

  • ChatGPT目前的AI一哥
  • 认识思维之熵
  • 蓝桥杯备战——1.点亮LED灯
  • 【网络协议测试】畸形数据包——圣诞树攻击(DOS攻击)
  • Java基础面试题-5day
  • 软通智慧启动鲲鹏原生应用开发合作
  • 【STM32】STM32F4中USB的CDC虚拟串口(VCP)使用方法
  • 网络协议与攻击模拟_06攻击模拟SYN Flood
  • CPU,内存和硬盘之间的关系
  • Java面试题之基础篇
  • Bitbucket第一次代码仓库创建/提交/创建新分支/合并分支/忽略ignore
  • c#反射用法
  • WPF行为
  • N-141基于springboot,vue网上拍卖平台
  • Unity之Cinemachine教程
  • java面面试面经(面试过程)
  • 大语言模型-大模型基础文献
  • 【RH850U2A芯片】Reset Vector和Interrupt Vector介绍
  • Zabbix交换分区使用率过高排查
  • ‘HEAD‘ 是 HTTP 请求的一种方法
  • go语言(十七)----json
  • Java笔记 --- 四、异常
  • Ubuntu20.04配置grub ,不必每次都输入 nomodeset
  • PBM模型学习(七)核化模型
  • 蓝桥小白赛4 乘飞机 抽屉原理 枚举
  • HTML新手教程
  • P1226 【模板】快速幂题解
  • 文旅游戏的多元应用场景
  • 小波变化最通俗的解释,小波变换是用来干什么的,类似小波变换功能的算法有哪些?
  • Servlet 与 MVC