Python 数据建模与分析项目实战预备 Day 6 - 多模型对比与交叉验证验证策略
✅ 今日目标
- 引入多种常见分类模型(随机森林、支持向量机、K近邻等)
- 比较不同模型的训练效果
- 使用交叉验证提升评估稳定性
🧾 一、对比模型列表
模型 | 类名(sklearn) | 适用说明 |
---|---|---|
逻辑回归 | LogisticRegression | 基础线、易于解释 |
KNN | KNeighborsClassifier | 基于邻近数据点 |
决策树 | DecisionTreeClassifier | 可视化,易过拟合 |
随机森林 | RandomForestClassifier | 综合表现较优,抗过拟合 |
支持向量机 | SVC | 高维表现好,耗时较久(适合小数据) |
🧪 二、交叉验证策略
使用 cross_val_score
进行 K 折交叉验证,常用 cv=5
:
from sklearn.model_selection import cross_val_score
scores = cross_val_score(model, X, y, cv=5, scoring="accuracy")
还可以比较不同模型的:
accuracy
roc_auc
f1_macro
等指标
🧪 今日练习任务
编写脚本 model_compare_cv.py
,实现:
-
加载
processed_X_train.csv
与标签 -
初始化多个模型
-
对每个模型进行 5 折交叉验证
-
输出每个模型的平均准确率和 AUC
# model_compare_cv.py - 多模型比较与交叉验证(优化版)import pandas as pd import numpy as np import warnings from sklearn.model_selection import cross_val_score from sklearn.preprocessing import StandardScaler from sklearn.linear_model import LogisticRegression from sklearn.neighbors import KNeighborsClassifier from sklearn.tree import DecisionTreeClassifier from sklearn.ensemble import RandomForestClassifier from sklearn.svm import SVC# 忽略数值计算警告 warnings.filterwarnings("ignore", category=RuntimeWarning)# 读取训练数据 X = pd.read_csv("./data/stage4/processed_X_train.csv") y = pd.read_csv("./data/stage4/processed_y_train.csv").values.ravel()# 检查数值问题 print("🔎 是否包含 NaN:", X.isna().sum().sum()) print("🔎 是否包含 Inf:", (~np.isfinite(X)).sum().sum()) print("🔎 特征最大值:", X.max().max()) print("🔎 特征最小值:", X.min().min())# 标准化所有特征 scaler = StandardScaler() X_scaled = pd.DataFrame(scaler.fit_transform(X), columns=X.columns)# 定义模型集合 models = {"Logistic Regression": LogisticRegression(max_iter=1000, random_state=42),"K-Nearest Neighbors": KNeighborsClassifier(),"Decision Tree": DecisionTreeClassifier(random_state=42),"Random Forest": RandomForestClassifier(random_state=42),"SVM": SVC(probability=True, random_state=42), }# 定义评价指标 scoring = ["accuracy", "roc_auc"]# 逐模型评估 for name, model in models.items():print(f"🔍 模型:{name}")for score in scoring:cv_scores = cross_val_score(model, X_scaled, y, cv=5, scoring=score)print(f" [{score}] 平均得分: {cv_scores.mean():.4f} ± {cv_scores.std():.4f}")print("-" * 40)print("✅ 所有模型交叉验证完毕。")
运行输出:
🔎 是否包含 NaN: 0 🔎 是否包含 Inf: 0 🔎 特征最大值: 1.6341648019019988 🔎 特征最小值: -1.6565987890014815 🔍 模型:Logistic Regression[accuracy] 平均得分: 0.7500 ± 0.0988[roc_auc] 平均得分: 0.8409 ± 0.0656 ---------------------------------------- 🔍 模型:K-Nearest Neighbors[accuracy] 平均得分: 0.6875 ± 0.0906[roc_auc] 平均得分: 0.7469 ± 0.0751 ---------------------------------------- 🔍 模型:Decision Tree[accuracy] 平均得分: 0.6438 ± 0.0829[roc_auc] 平均得分: 0.6442 ± 0.0829 ---------------------------------------- 🔍 模型:Random Forest[accuracy] 平均得分: 0.6937 ± 0.0696[roc_auc] 平均得分: 0.7739 ± 0.0815 ---------------------------------------- 🔍 模型:SVM[accuracy] 平均得分: 0.6813 ± 0.0914[roc_auc] 平均得分: 0.7829 ± 0.0867 ---------------------------------------- ✅ 所有模型交叉验证完毕。