当前位置: 首页 > news >正文

技巧|SwanLab记录混淆矩阵攻略

绘制混淆矩阵(Confusion Matrix),用于评估分类模型的性能。混淆矩阵展示了模型预测结果与真实标签之间的对应关系,能够直观地显示各类别的预测准确性和错误类型。

混淆矩阵是评估分类模型性能的基础工具,特别适用于多分类问题。

你可以使用swanlab.confusion_matrix来记录混淆矩阵。

Demo链接:ComputeMetrics - SwanLab

在这里插入图片描述

基本用法

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import xgboost as xgb
import swanlab# 加载鸢尾花数据集
iris_data = load_iris()
X = iris_data.data
y = iris_data.target
class_names = iris_data.target_names.tolist()# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 训练模型
model = xgb.XGBClassifier(objective='multi:softmax', num_class=len(class_names))
model.fit(X_train, y_train)# 获取预测结果
y_pred = model.predict(X_test)# 初始化SwanLab
swanlab.init(project="Confusion-Matrix-Demo", experiment_name="Confusion-Matrix-Example")# 记录混淆矩阵
swanlab.log({"confusion_matrix": swanlab.confusion_matrix(y_test, y_pred, class_names)
})swanlab.finish()

使用自定义类别名称

# 定义自定义类别名称
custom_class_names = ["类别A", "类别B", "类别C"]# 记录混淆矩阵
confusion_matrix = swanlab.confusion_matrix(y_test, y_pred, custom_class_names)
swanlab.log({"confusion_matrix_custom": confusion_matrix})

不使用类别名称

# 不指定类别名称,将使用数字索引
confusion_matrix = swanlab.confusion_matrix(y_test, y_pred)
swanlab.log({"confusion_matrix_default": confusion_matrix})

二分类示例

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import xgboost as xgb
import swanlab# 生成二分类数据
X, y = make_classification(n_samples=1000, n_features=20, n_informative=2, n_redundant=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 训练模型
model = xgb.XGBClassifier(use_label_encoder=False, eval_metric='logloss')
model.fit(X_train, y_train)# 获取预测结果
y_pred = model.predict(X_test)# 记录混淆矩阵
swanlab.log({"confusion_matrix": swanlab.confusion_matrix(y_test, y_pred, ["负类", "正类"])
})

注意事项

  1. 数据格式: y_truey_pred可以是列表或numpy数组
  2. 多分类支持: 此函数支持二分类和多分类问题
  3. 类别名称: class_names的长度应该与类别数量一致
  4. 依赖包: 需要安装scikit-learnpyecharts
  5. 坐标轴: sklearn的confusion_matrix左上角为(0,0),在pyecharts的heatmap中是左下角,函数会自动处理坐标转换
  6. 矩阵解读: 混淆矩阵中,行表示真实标签,列表示预测标签
http://www.lryc.cn/news/607859.html

相关文章:

  • 解决忘记修改配置密码而无法连接nacos的问题
  • DockerFile文件执行docker bulid自动构建镜像
  • Android 15 限制APK包手动安装但不限制自升级的实现方案
  • 20250802让飞凌OK3576-C开发板在飞凌的Android14下【rk3576_u选项】适配NXP的WIFIBT模块88W8987A的蓝牙
  • 【Android】通知
  • React ahooks——副作用类hooks之useDebounceFn
  • linux eval命令的使用方法介绍
  • 【vue】创建响应式数据ref和reactive的区别
  • 防火墙配置实验2(DHCP,用户认证,安全策略)
  • C语言---函数的递归与迭代
  • 【DL学习笔记】DL入门指南
  • 《深潜React列表渲染:调和算法与虚拟DOM Diff的优化深解》
  • 2024年网络安全案例
  • rag学习-以项目为基础快速启动掌握rag
  • 建筑施工场景安全帽识别误报率↓79%:陌讯动态融合算法实战解析
  • WordPress AI写作插件开发实战:从GPT集成到企业级部署
  • retro-go 1.45 编译及显示中文
  • 浏览器及java读取ros1的topic
  • 在 Elasticsearch 中落地 Learning to Rank(LTR)
  • sqli-labs通关笔记-第28a关GET字符注入(关键字过滤绕过 手注法)
  • 关于Web前端安全防御CSRF攻防的几点考虑
  • MFC 实现托盘图标菜单图标功能
  • 【相机】曝光时间长-->拖影
  • Effective C++ 条款17:以独立语句将newed对象置入智能指针
  • 易华路副总经理兼交付管理中心部门经理于江平受邀PMO大会主持人
  • Elasticsearch+Logstash+Filebeat+Kibana单机部署
  • RabbitMQ面试精讲 Day 7:消息持久化与过期策略
  • 用Unity结合VCC更改人物模型出现的BUG
  • 个人笔记UDP
  • 内存、硬盘与缓存的技术原理及特性解析