当前位置: 首页 > news >正文

用sklearn运行分类模型,选择AUC最高的模型保存模型权重并绘制AUCROC曲线(以逻辑回归、随机森林、梯度提升、MLP为例)

诸神缄默不语-个人CSDN博文目录

文章目录

  • 1. 导入包
  • 2. 初始化分类模型
  • 3. 训练、测试模型,绘图,保存指标

1. 导入包

from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import roc_auc_score,accuracy_score,roc_curve,auc
import joblib
import matplotlib.pyplot as plt

2. 初始化分类模型

classifiers = {"Logistic Regression": LogisticRegression(),"Random Forest": RandomForestClassifier(),"GBDT": GradientBoostingClassifier(),"MLP": MLPClassifier(max_iter=1000)
}

3. 训练、测试模型,绘图,保存指标

在这里省略了数据处理部分,总之X/Y都是np.ndarray对象。f反正你创建一个可写的文件流就行,如果连这个都不会的话参考我写的这篇博文:Python3对象序列化,即处理JSON、XML和文件(持续更新ing…)。
f.close()没写,根据你的需要如果想加就加。

这个逻辑是每次得到AUC最高的模型就画图,其实感觉把模型权重储存下来然后再joblib.load()再画图会更合适……
如果想对每个模型画ROC曲线叠在一张图上的话,在最前面新建画布(plt.figure()),每个模型运行完后都运行一次plt.plot(),不close()就行。

max_auc = 0
max_acc = 0
best_classifier = ""
# 训练模型
for lr_name, lr in classifiers.items():lr.fit(X_train, y_train)# 预测y_pred = lr.predict(X_test)y_pred_proba = lr.predict_proba(X_test)[:, 1]# 评估auc_score = roc_auc_score(y_test, y_pred_proba)acc = accuracy_score(y_test, y_pred)if auc_score > max_auc:max_auc = auc_scoremax_acc = accbest_classifier = lr_namejoblib.dump(lr, f"model.pkl")fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba)roc_auc = auc(fpr, tpr)plt.figure()plt.plot(fpr,tpr,color="darkorange",lw=2,label=f"ROC curve (AUC = {roc_auc:.2f})",)plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")  # 随机猜测基线plt.xlim([0.0, 1.0])plt.ylim([0.0, 1.05])plt.xlabel("False Positive Rate")plt.ylabel("True Positive Rate")plt.title("Receiver Operating Characteristic")plt.legend(loc="lower right")plt.grid()plt.savefig("roc.png")plt.close()f.write(f"{lr_name} AUC: {auc_score:.4f}, ACC: {acc:.4f}"+ "\n")f.flush()f.write(f"best_classifier: {best_classifier} AUC: {max_auc:.4f}, ACC: {max_acc:.4f}"+ "\n"
)
f.flush()
http://www.lryc.cn/news/523552.html

相关文章:

  • 动手学大数据-3社区开源实践
  • 使用Pydantic驾驭大模型
  • 【HarmonyOS之旅】基于ArkTS开发(二) -> UI开发之常见布局
  • 【论文投稿】Python 网络爬虫:探秘网页数据抓取的奇妙世界
  • 队列的基本用法
  • 网络安全VS数据安全
  • Linux(NFS服务)
  • python编程-OpenCV(图像读写-图像处理-图像滤波-角点检测-边缘检测)边缘检测
  • SSM课设-学生管理系统
  • 【Pytorch实用教程】TCN(Temporal Convolutional Network,时序卷积网络)简介
  • 网络安全 | 什么是正向代理和反向代理?
  • 3 前端(中):JavaScript
  • VIT论文阅读与理解
  • JavaScript笔记APIs篇01——DOM获取与属性操作
  • SQL表间关联查询详解
  • select函数
  • 建造者模式(或者称为生成器(构建器)模式)
  • 【深度学习】Huber Loss详解
  • A5.Springboot-LLama3.2服务自动化构建(二)——Jenkins流水线构建配置初始化设置
  • 李宏毅机器学习HW1: COVID-19 Cases Prediction
  • MySQL下载安装DataGrip可视化工具
  • 多平台下Informatica在医疗数据抽取中的应用
  • 用公网服务器实现内网穿透
  • 为什么mysql更改表结构时,varchar超过255会锁表
  • ASP.NET Core中 JWT 实现无感刷新Token
  • 函数(函数的概念、库函数、自定义函数、形参和实参、return语句、数组做函数参数、嵌套调用和链式访问、函数的声明和定义、static和extern)
  • 物联网在烟草行业的应用
  • 第6章:Python TDD实例变量私有化探索
  • Java操作Excel导入导出——POI、Hutool、EasyExcel
  • BUUCTF_Web([GYCTF2020]Ezsqli)