当前位置: 首页 > news >正文

决策树二-泰坦尼克号幸存者

一、决策树算法基础

决策树是一种直观且易于解释的机器学习算法,通过模拟人类决策过程构建树状结构,每个内部节点代表对特征的判断,叶子节点代表最终分类结果。在实际应用中,常用的决策树算法包括 ID3、C4.5 和 CART,它们的核心差异在于特征选择指标的不同。

ID3 算法以信息增益为核心指标,信息增益越大,意味着使用该属性划分数据集获得的 “纯度提升” 越显著。但该算法对可取值数目较多的属性存在偏好,例如 “编号” 这类唯一标识属性,可能被误选为最优划分特征,导致模型泛化能力下降。

C4.5 算法为解决 ID3 的缺陷,引入信息增益率,通过 “信息增益 ÷ 属性自身熵” 的计算方式,平衡了对多取值属性的偏好,使特征选择更合理。而 CART 算法则采用基尼指数衡量数据集纯度,基尼指数越小,数据集类别越集中(纯度越高),其计算公式为(Gini(D)=1-\sum_{k=1}^{n}p_k^2\),其中(p_k)是数据集D中第k类样本的占比。

此外,决策树存在天然的过拟合风险 —— 理论上可通过不断划分节点完全分离训练数据。为缓解这一问题,需采用剪枝策略:预剪枝在树的构建过程中通过限制深度、叶子节点样本数等条件提前停止分支,实用性更强;后剪枝则在完整树构建后,根据 “自身 GINI 系数 +α× 叶子节点数量” 的损失函数判断是否剪枝,α 越大,模型越简洁但可能牺牲精度,α 越小则更侧重拟合效果。

二、泰坦尼克号幸存者预测实践

(一)数据预处理

泰坦尼克号数据集包含 891 名乘客的信息,核心目标是通过乘客的特征(如舱位等级、性别、年龄等)预测其生存状态(Survived,1 表示存活,0 表示遇难)。数据预处理步骤如下:

  1. 导入工具库与数据:使用 pandas 读取数据,sklearn 提供模型训练与评估工具,matplotlib 用于可视化。
  2. 缺失值与冗余特征处理:数据集存在部分缺失值(如 Age 仅 714 条非空值,Cabin 仅 204 条非空值),且 Name、Ticket 等特征与生存状态无直接关联,需进行删除或填充;Age 缺失值采用均值填充,确保数据完整性。
  3. 分类特征数值化:性别(Sex)、登船港口(Embarked)等文本型特征无法直接输入模型,需转换为数值 —— 将 “male” 映射为 1、“female” 映射为 0,Embarked 通过唯一值列表转换为 0、1、2 等标识。

(二)代码实现与解释

1. 数据预处理代码
# 导入必要库
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split, cross_val_score
import matplotlib.pyplot as plt# 读取数据(index_col=0指定PassengerId为索引)
data = pd.read_csv("taitanic_data.csv", index_col=0)# 查看数据基本信息(缺失值、数据类型等)
print("数据初始信息:")
print(data.info())# 1. 删除缺失值过多的列(Cabin)和无关列(Name、Ticket)
data.drop(["Cabin", "Name", "Ticket"], inplace=True, axis=1)# 2. 填充Age缺失值(均值填充)
data["Age"] = data["Age"].fillna(data["Age"].mean())# 3. 删除Embarked列的缺失值行(仅2条,影响较小)
data = data.dropna()# 4. 分类特征数值化:Embarked(三分类)
labels = data["Embarked"].unique().tolist()
data["Embarked"] = data["Embarked"].apply(lambda x: labels.index(x))# 5. 分类特征数值化:Sex(二分类)
data["Sex"] = (data["Sex"] == "male").astype("int")# 查看预处理后的数据
print("\n预处理后数据前5行:")
print(data.head())
print("\n预处理后数据信息:")
print(data.info())
2. 数据集划分

将预处理后的数据集拆分为特征(X)和目标变量(y),再按 7:3 比例划分为训练集(Xtrain、Ytrain)和测试集(Xtest、Ytest),并修正索引以避免后续报错:

# 划分特征(X)和目标变量(y):y为Survived列,X为其余列
X = data.iloc[:, data.columns != "Survived"]
y = data.iloc[:, data.columns == "Survived"]# 划分训练集与测试集(test_size=0.3表示测试集占比30%)
Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, y, test_size=0.3, random_state=25)# 修正索引(避免因原索引不连续导致的问题)
for i in [Xtrain, Xtest, Ytrain, Ytest]:i.index = range(i.shape[0])# 查看划分结果
print("\n训练集前5行:")
print(Xtrain.head())
print("\n测试集前5行:")
print(Xtest.head())
3. 基础决策树模型训练与评估

使用 CART 算法(默认 criterion="gini")构建模型,通过测试集得分和 10 折交叉验证评估泛化能力:

# 1. 创建决策树实例(random_state固定随机种子,确保结果可复现)
clf = DecisionTreeClassifier(random_state=25)# 2. 训练模型(使用训练集)
clf.fit(Xtrain, Ytrain)# 3. 测试集得分(准确率)
test_score = clf.score(Xtest, Ytest)
print(f"\n测试集准确率:{test_score:.4f}")# 4. 10折交叉验证(更全面评估泛化能力)
cv_score = cross_val_score(clf, X, y, cv=10).mean()
print(f"10折交叉验证平均准确率:{cv_score:.4f}")

输出结果:测试集准确率约 0.7341,10 折交叉验证平均准确率约 0.7739,说明基础模型具备一定预测能力,但可能存在过拟合。

4. 优化决策树深度(缓解过拟合)

决策树深度是影响过拟合的关键因素 —— 深度过大易 memorize 训练数据,深度过小则欠拟合。通过遍历深度 1-10,对比训练集与交叉验证集得分,选择最优深度:

# 存储不同深度的训练集与交叉验证集得分
train_scores = []
cv_scores = []# 遍历深度1到10(criterion="entropy"改用信息熵指标)
for depth in range(1, 11):clf = DecisionTreeClassifier(random_state=20,max_depth=depth,criterion="entropy"  # 用信息熵替代基尼指数)clf.fit(Xtrain, Ytrain)# 训练集得分train_score = clf.score(Xtrain, Ytrain)# 10折交叉验证得分cv_score = cross_val_score(clf, X, y, cv=10).mean()train_scores.append(train_score)cv_scores.append(cv_score)# 输出最优得分
print(f"\n最优训练集准确率:{max(train_scores):.6f}")
print(f"最优交叉验证准确率:{max(cv_scores):.6f}")
print(f"最优深度(对应最优交叉验证得分):{cv_scores.index(max(cv_scores)) + 1}")# 可视化深度与得分关系
plt.figure(figsize=(10, 6))
plt.plot(range(1, 11), train_scores, color="red", label="训练集得分")
plt.plot(range(1, 11), cv_scores, color="blue", label="交叉验证集得分")
plt.xticks(range(1, 11))  # x轴刻度为1-10
plt.xlabel("决策树深度")
plt.ylabel("准确率")
plt.title("决策树深度对模型性能的影响")
plt.legend()  # 显示图例
plt.show()

可视化结果分析:随着深度增加,训练集得分(红色曲线)持续上升(最高约 0.9132),但交叉验证集得分(蓝色曲线)在深度为 3-5 时达到峰值(约 0.8200),之后逐渐下降,说明深度超过 5 后模型开始过拟合。因此,最优决策树深度可选择 5,此时模型在拟合效果与泛化能力间达到平衡.

http://www.lryc.cn/news/625595.html

相关文章:

  • 决策树(2)
  • FPGA入门-多路选择器
  • 决策树1.1
  • 机器学习(决策树2)
  • Leetcode 深度优先搜索 (7)
  • Python爬虫第二课:爬取HTML静态网页之《某某小说》 小说章节和内容完整版
  • 【LeetCode】3655. 区间乘法查询后的异或 II (差分/商分 + 根号算法)
  • Mybatis执行SQL流程(四)之MyBatis中JDK动态代理
  • 【HTML】3D动态凯旋门
  • Leetcode 343. 整数拆分 动态规划
  • C++入门自学Day14-- Stack和Queue的自实现(适配器)
  • 神经网络中的那些关键设计:从输入输出到参数更新
  • 面试题储备-MQ篇 3-说说你对Kafka的理解
  • 图论\dp 两题
  • 设计模式笔记_行为型_命令模式
  • 【React】事件绑定和组件基础使用
  • 从线性回归到神经网络到自注意力机制 —— 激活函数与参数的演进
  • java基础(十二)redis 日志机制以及常见问题
  • 2025年12大AI测试自动化工具
  • 多模态大模型应用落地:从图文生成到音视频交互的技术选型与实践
  • 【模块系列】STM32W25Q64
  • TDengine IDMP 运维指南(4. 使用 Docker 部署)
  • 第六天~提取Arxml中CAN物理通道信息CANChannel--Physical Channel
  • 5. Dataloader 自定义数据集制作
  • C语言基础:(十八)C语言内存函数
  • java17学习笔记-Deprecate the Applet API for Removal
  • 算法——质数筛法
  • yolov5s.onnx转rk模型以及相关使用详细教程
  • 假设检验的原理
  • python的社区互助养老系统