Python Day10
@浙大疏锦行 Python Day 10
内容:
- 划分数据集:
from sklearn.model_selection import train_test_split
# Separate features from the label; `label_name` holds the label column's name.
X = data.drop(label_name, axis=1)
y = data[label_name]  # index with the variable, not the literal string 'label_name'
# Hold out 20% of rows for testing; random_state fixes the shuffle so the split is reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
- 训练模型:
# Build the estimator (replace `xxx` with a concrete classifier) with a fixed random seed.
model = xxx(random_state=42)
# scikit-learn-style estimators return self from fit(), so training and prediction chain directly.
model_pred = model.fit(X_train, y_train).predict(X_test)
- 评估指标:准确率、召回率、精确率、F1分数、AUC值等(具体采用哪些指标取决于任务类型)
代码:
import pandas as pd # 用于数据处理和分析,可处理表格数据。
import numpy as np # 用于数值计算,提供了高效的数组操作。
import matplotlib.pyplot as plt # 用于绘制各种类型的图表
import seaborn as sns # 基于matplotlib的高级绘图库,能绘制更美观的统计图形。
from sklearn.svm import SVC #支持向量机分类器
from sklearn.neighbors import KNeighborsClassifier #K近邻分类器
from sklearn.linear_model import LogisticRegression #逻辑回归分类器
import xgboost as xgb #XGBoost分类器
import lightgbm as lgb #LightGBM分类器
from sklearn.ensemble import RandomForestClassifier #随机森林分类器
from catboost import CatBoostClassifier #CatBoost分类器
from sklearn.tree import DecisionTreeClassifier #决策树分类器
from sklearn.naive_bayes import GaussianNB #高斯朴素贝叶斯分类器
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score # 用于评估分类器性能的指标
from sklearn.metrics import classification_report, confusion_matrix #用于生成分类报告和混淆矩阵
import warnings #用于忽略警告信息
warnings.filterwarnings("ignore")  # silence all warnings to keep tutorial output clean (note: this also hides genuine issues)
from sklearn.model_selection import train_test_split

# Configure a CJK-capable font so Chinese chart labels render correctly.
plt.rcParams['font.sans-serif'] = ['SimHei']  # SimHei is a commonly available font on Windows
plt.rcParams['axes.unicode_minus'] = False  # keep minus signs rendering correctly with this font

# Load the dataset. Per the original note, this file needs no categorical
# encoding or missing-value handling.
data = pd.read_csv("./data/heart.csv")

# Features: every column except 'target'; labels: the 'target' column.
X = data.drop(['target'], axis=1)
y = data['target']

# Hold out 20% of rows as the test set; random_state=42 makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# SVM
# Support vector classifier: train on the training split, then score on the held-out split.
model_svm = SVC(random_state=42)
# fit() returns the estimator itself, so prediction can follow directly.
svm_pred = model_svm.fit(X_train, y_train).predict(X_test)
svm_accuracy = accuracy_score(y_test, svm_pred)  # fraction of test labels predicted exactly
print(f"SVM Accuracy: {svm_accuracy}")
# XGBoost: gradient-boosted tree classifier, trained and scored on the same train/test split.
model_xgb = xgb.XGBClassifier(random_state=42)
# The scikit-learn wrapper's fit() returns the model, so prediction chains directly.
xgb_pred = model_xgb.fit(X_train, y_train).predict(X_test)
xgb_accuracy = accuracy_score(y_test, xgb_pred)  # fraction of test labels predicted exactly
print(f"XGB Accuracy: {xgb_accuracy}")