当前位置: 首页 > news >正文

矿物分类案例(二)数据填充后使用6种模型训练

目录

一.多种数据填充方法与分类模型在矿物类型识别中的对比实验

1.代码整体功能概述

2.代码解析

1. 数据准备与配置

2. 循环处理不同填充方法的数据集

3. 多模型训练与评估

4. 结果保存

5.完整代码

3.代码特点与可改进方向

代码优点

可改进方向

总结


一.多种数据填充方法与分类模型在矿物类型识别中的对比实验

在实际数据分析中,缺失值处理是数据预处理的关键步骤之一,不同的缺失值填充策略可能会对模型性能产生显著影响。同时,选择合适的分类算法也是获得良好预测效果的重要因素。本文将解析一段用于对比不同缺失值填充方法和分类模型性能的 Python 代码,该代码针对矿物类型识别任务,系统评估了多种方案的效果。

1.代码整体功能概述

这段代码的核心目标是:

  1. 测试 6 种不同缺失值处理方法(删除不完整数据行、中位数填充、众数填充、平均值填充、线性回归预测填充、随机森林预测填充)对模型性能的影响
  2. 在每种填充方法处理后的数据集上,训练并评估 6 种经典分类模型(逻辑回归、随机森林、支持向量机、AdaBoost、高斯贝叶斯、XGBoost)
  3. 将所有实验结果以 JSON 格式保存,便于后续分析对比

通过这样的对比实验,我们可以找到针对 "矿物类型" 分类任务的最佳数据预处理 + 模型组合方案。

2.代码解析

1. 数据准备与配置

首先,我们需要定义实验中使用的数据集路径,代码通过字典结构清晰地组织了不同填充方法对应的训练集和测试集:

import pandas as pd
directory={'删除不完整数据行': ['训练集[删除不完整数据行].xlsx', '测试集[删除不完整数据行].xlsx'],'中位数填充':['训练集[中位数填充].xlsx','测试集[中位数填充].xlsx'],'众数填充':['训练集[众数填充].xlsx','测试集[众数填充].xlsx'],'平均值填充':['训练集[平均值填充].xlsx','测试集[平均值填充].xlsx'],'线性回归预测填充': ['训练集[线性回归预测填充].xlsx', '测试集[线性回归预测填充].xlsx'],'随机森林预测填充': ['训练集[随机森林预测填充].xlsx', '测试集[随机森林预测填充].xlsx']
}

这种字典结构使得代码具有良好的可扩展性,如需添加新的填充方法,只需在此字典中添加相应条目即可。

2. 循环处理不同填充方法的数据集

代码通过 for 循环遍历每种填充方法,对每种方法对应的数据集进行独立的模型训练和评估:

for filename in directory:# 读取训练集和测试集train_data=pd.read_excel(directory[filename][0])test_data=pd.read_excel(directory[filename][1])# 分割特征和标签(矿物类型为目标变量)train_x=train_data.drop('矿物类型',axis=1)train_y=train_data.矿物类型test_x=test_data.drop('矿物类型',axis=1)test_y=test_data.矿物类型result_data={}  # 用于存储当前填充方法下所有模型的结果

这段代码完成了数据加载和基本的数据分割,将特征数据(train_x, test_x)和目标变量(train_y, test_y)分离,为后续模型训练做准备。

3. 多模型训练与评估

在每个填充方法的循环中,代码依次训练 6 种分类模型,并记录它们的性能指标。这里以逻辑回归为例进行说明:

# 逻辑回归模型
from sklearn.linear_model import LogisticRegression
from sklearn import metricsLR_result={}  # 存储逻辑回归的评估结果lr=LogisticRegression()
lr.fit(train_x,train_y)  # 训练模型# 评估模型性能
self_predict=lr.predict(train_x)  # 训练集上的预测
print('LR自测:'+metrics.classification_report(train_y,self_predict))predicted=lr.predict(test_x)  # 测试集上的预测
print('LR测试:'+metrics.classification_report(test_y,predicted))# 提取关键评估指标
a=metrics.classification_report(test_y,predicted,digits=6).split()
LR_result['recall0']=a[6]
LR_result['recall1']=a[11]
LR_result['recall2']=a[16]
LR_result['recall3']=a[21]
LR_result['accuracy']=a[25]
result_data['LR_result']=LR_result

上述代码的关键步骤包括:

  • 导入模型类和评估工具
  • 初始化模型并使用训练数据拟合
  • 在训练集和测试集上进行预测,评估模型性能
  • 从分类报告中提取关键指标(各类别召回率和总体准确率)
  • 将结果存储到字典中

其他 5 种模型(随机森林、支持向量机、AdaBoost、高斯贝叶斯、XGBoost)采用完全相同的处理流程,确保评估标准的一致性,便于横向对比。

4. 结果保存

所有模型评估完成后,代码将结果保存为 JSON 文件:

import json
result={}
result['median fill']=result_data
with open(r'结果/'+filename+'result.json','w',encoding='utf-8') as file:json.dump(result,file,ensure_ascii=False,indent=4)

这种保存方式有两个优点:

  1. 结构化存储便于后续的结果分析和可视化
  2. 每个填充方法对应独立的 JSON 文件,避免结果混淆

5.完整代码

import pandas as pd
directory={'删除不完整数据行': ['训练集[删除不完整数据行].xlsx', '测试集[删除不完整数据行].xlsx'],'中位数填充':['训练集[中位数填充].xlsx','测试集[中位数填充].xlsx'],'众数填充':['训练集[众数填充].xlsx','测试集[众数填充].xlsx'],'平均值填充':['训练集[平均值填充].xlsx','测试集[平均值填充].xlsx'],'线性回归预测填充': ['训练集[线性回归预测填充].xlsx', '测试集[线性回归预测填充].xlsx'],'随机森林预测填充': ['训练集[随机森林预测填充].xlsx', '测试集[随机森林预测填充].xlsx']
}
for filename in directory:train_data=pd.read_excel(directory[filename][0])test_data=pd.read_excel(directory[filename][1])train_x=train_data.drop('矿物类型',axis=1)train_y=train_data.矿物类型test_x=test_data.drop('矿物类型',axis=1)test_y=test_data.矿物类型result_data={}#==================逻辑回归=========================from sklearn.linear_model import LogisticRegressionfrom sklearn import metricsLR_result={}lr=LogisticRegression()lr.fit(train_x,train_y)self_predict=lr.predict(train_x)print('LR自测:'+metrics.classification_report(train_y,self_predict))predicted=lr.predict(test_x)print('LR测试:'+metrics.classification_report(test_y,predicted))a=metrics.classification_report(test_y,predicted,digits=6).split()LR_result['recall0']=a[6]LR_result['recall1']=a[11]LR_result['recall2']=a[16]LR_result['recall3']=a[21]LR_result['accuracy']=a[25]result_data['LR_result']=LR_result#==================随机森林=========================from sklearn.ensemble import RandomForestClassifierRF_result={}rf=RandomForestClassifier()rf.fit(train_x,train_y)self_predict=rf.predict(train_x)print('RF自测:'+metrics.classification_report(train_y,self_predict))predicted=rf.predict(test_x)print('RF测试:'+metrics.classification_report(test_y,predicted))a=metrics.classification_report(test_y,predicted,digits=6).split()RF_result['recall0']=a[6]RF_result['recall1']=a[11]RF_result['recall2']=a[16]RF_result['recall3']=a[21]RF_result['accuracy']=a[25]result_data['RF_result']=RF_result#==================支持向量机=========================from sklearn.svm import SVCSVM_result={}svm=SVC()svm.fit(train_x,train_y)self_predict=svm.predict(train_x)print('SVM自测:'+metrics.classification_report(train_y,self_predict))predicted=svm.predict(test_x)print('SVM测试:'+metrics.classification_report(test_y,predicted))a=metrics.classification_report(test_y,predicted,digits=6).split()SVM_result['recall0']=a[6]SVM_result['recall1']=a[11]SVM_result['recall2']=a[16]SVM_result['recall3']=a[21]SVM_result['accuracy']=a[25]result_data['SVM_result']=SVM_result#==================Adaboost=========================from sklearn.ensemble import AdaBoostClassifierAdaboost_result={}Abt=AdaBoostClassifier()Abt.fit(train_x,train_y)self_predict=Abt.predict(train_x)print('Abt自测:'+metrics.classification_report(train_y,self_predict))predicted=Abt.predict(test_x)print('Abt测试:'+metrics.classification_report(test_y,predicted))a=metrics.classification_report(test_y,predicted,digits=6).split()Adaboost_result['recall0']=a[6]Adaboost_result['recall1']=a[11]Adaboost_result['recall2']=a[16]Adaboost_result['recall3']=a[21]Adaboost_result['accuracy']=a[25]result_data['Adaboost_result']=Adaboost_result#==================高斯贝叶斯=========================from sklearn.naive_bayes import GaussianNBGs_result={}Gs=GaussianNB()Gs.fit(train_x,train_y)self_predict=Gs.predict(train_x)print('GS自测:'+metrics.classification_report(train_y,self_predict))predicted=Gs.predict(test_x)print('GS测试:'+metrics.classification_report(test_y,predicted))a=metrics.classification_report(test_y,predicted,digits=6).split()Gs_result['recall0']=a[6]Gs_result['recall1']=a[11]Gs_result['recall2']=a[16]Gs_result['recall3']=a[21]Gs_result['accuracy']=a[25]result_data['Gs_result']=Gs_result#==================XGBoost=========================需要另外pipimport xgboost as xgbXGB_result={}xgb_model=xgb.XGBClassifier()xgb_model.fit(train_x,train_y)self_predict=xgb_model.predict(train_x)print('XGB自测:'+metrics.classification_report(train_y,self_predict))predicted=xgb_model.predict(test_x)print('XGB测试:'+metrics.classification_report(test_y,predicted))a=metrics.classification_report(test_y,predicted,digits=6).split()XGB_result['recall0']=a[6]XGB_result['recall1']=a[11]XGB_result['recall2']=a[16]XGB_result['recall3']=a[21]XGB_result['accuracy']=a[25]result_data['XGB_result']=XGB_result#保存为json文件import jsonresult={}result['median fill']=result_datawith open(r'结果/'+filename+'result.json','w',encoding='utf-8') as file:# 使用json的dump()方法将字典转化为JSON格式并写入文件,JSON一般是字典json.dump(result,file,ensure_ascii=False,indent=4)

3.代码特点与可改进方向

代码优点

  1. 模块化设计:每种模型的处理逻辑一致,便于理解和维护
  2. 可扩展性强:新增填充方法或模型只需添加对应模块
  3. 结果标准化:统一的指标提取和保存格式,便于对比分析

可改进方向

  1. 添加特征预处理:对逻辑回归、SVM 等对特征尺度敏感的模型,应增加标准化 / 归一化步骤
  2. 模型参数调优:当前使用模型默认参数,可引入网格搜索进行超参数优化
  3. 完善评估指标:可增加 precision、F1-score 等更多评估指标
  4. 异常处理:增加 try-except 块处理可能的文件读取或模型训练错误
  5. 结果可视化:添加自动生成对比图表的功能,直观展示各方案优劣
  6. 交叉验证:对训练过程引入交叉验证,提高结果可靠性

总结

这段代码通过系统化的实验设计,全面对比了不同缺失值处理方法和分类模型在矿物类型识别任务上的表现。实验结果能够帮助我们选择最优的数据预处理策略和分类算法,为实际应用提供指导。

http://www.lryc.cn/news/624915.html

相关文章:

  • Android中flavor的使用
  • PostgreSQL中的json_agg()
  • 初始向量数据库之Milvus
  • milvus如何存储特殊类型的数据
  • Milvus向量数据库安装步骤
  • 大厂 | 华为半导体业务部2026届秋招启动
  • 【大模型】RAG
  • 基于nvm安装管理多个node.js版本切换使用(附上详细安装使用图文教程+nvm命令大全)
  • ANSI终端色彩控制知识散播(I):语法封装(Python)——《彩色终端》诗评
  • 楼宇自控系统深化设计需关注哪些核心要点?技术与应用解析
  • 第一阶段C#-14:委托,事件
  • ReactNative开发实战——React Native开发环境配置指南
  • 机器翻译论文阅读方法:顶会(ACL、EMNLP)论文解析技巧
  • ADC的实现(单通道,多通道,DMA)
  • 如何编写自己的Spring容器
  • 【EI会议征稿】2025第四届健康大数据与智能医疗国际会议(ICHIH 2025)
  • VS Code Copilot 完整使用教程(含图解)
  • 全局锁应用场景理解
  • 深度学习——R-CNN及其变体
  • 04 类型别名type + 检测数据类型(typeof+instanceof) + 空安全+剩余和展开(运算符 ...)简单类型和复杂类型 + 模块化
  • Spark 运行流程核心组件(三)任务执行
  • 实习两个月总结
  • [系统架构设计师]软件架构的演化与维护(十)
  • SpringBoot--JWT
  • 大数据计算引擎(四)—— Impala
  • React diff——差异协调算法简介
  • 深入解析 Qwen3 GSPO:一种稳定高效的大语言模型强化学习算法
  • 整体设计 之“凝聚式中心点”原型 --整除:智能合约和DBMS的深层融合 之2
  • LLM - MCP传输协议解读:从SSE的单向奔赴到Streamable HTTP的双向融合
  • 【软考架构】第4章 信息安全的抗攻击技术