当前位置: 首页 > news >正文

Python中的决策树机器学习模型简要介绍和代码示例(基于sklearn)

一、决策树定义

决策树是一种监督学习算法,可用于**分类(Classification)回归(Regression)**任务。

它的结构类似树状结构:

  • 内部节点:特征条件(如X > 2
  • 叶子节点:输出类别或数值
  • 路径:对应一系列条件组合

二、决策树的基本概念

1. 信息熵(Entropy)

衡量样本集合的不确定性:

H(D)=−∑k=1Kpklog⁡2pk H(D) = - \sum_{k=1}^{K} p_k \log_2 p_k H(D)=k=1Kpklog2pk

其中:

  • DDD:样本集合
  • pkp_kpk:类别 kkk 的概率

2. 信息增益(Information Gain)

衡量某特征对信息熵的降低程度:

IG(D,A)=H(D)−∑v=1V∣Dv∣∣D∣H(Dv) IG(D, A) = H(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|} H(D^v) IG(D,A)=H(D)v=1VDDvH(Dv)

  • DvD^vDv:按特征 AAA 值划分的子集
  • 常用于 ID3 算法

3. 信息增益率(Gain Ratio)

用于 C4.5 算法,避免信息增益偏好取值多的特征:

GainRatio(D,A)=IG(D,A)IV(A) \text{GainRatio}(D, A) = \frac{IG(D, A)}{IV(A)} GainRatio(D,A)=IV(A)IG(D,A)

  • IV(A)=−∑v=1V∣Dv∣∣D∣log⁡2∣Dv∣∣D∣IV(A) = -\sum_{v=1}^V \frac{|D^v|}{|D|} \log_2 \frac{|D^v|}{|D|}IV(A)=v=1VDDvlog2DDv

4. Gini系数(Gini Impurity)

CART 分类算法使用的分裂标准:

Gini(D)=1−∑k=1Kpk2 Gini(D) = 1 - \sum_{k=1}^{K} p_k^2 Gini(D)=1k=1Kpk2

越小表示纯度越高。


5. 均方误差(MSE)

用于决策树回归:

MSE=1N∑i=1N(yi−y^)2 MSE = \frac{1}{N} \sum_{i=1}^N (y_i - \hat{y})^2 MSE=N1i=1N(yiy^)2

其中 y^\hat{y}y^ 是某叶子节点上的预测值。


三、决策树的算法流程(以分类为例)

  1. 选择最优划分特征

    • 使用信息增益 / 信息增益率 / Gini 系数
  2. 划分数据集

    • 递归构建子树
  3. 停止条件

    • 数据已纯净 / 特征用尽 / 达到最大深度
  4. 剪枝(可选)

    • 预剪枝 / 后剪枝,防止过拟合

四、实际示例

我们用一个简单的例子说明:

天气温度湿度打球
正常
正常
正常

目标是预测“是否打球”。


五、代码实现

示例使用 Python + scikit-learn 实现

(1)决策树分类 + 可视化

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt# 加载数据
X, y = load_iris(return_X_y=True)# 创建模型
clf = DecisionTreeClassifier(criterion='entropy', max_depth=3)
clf.fit(X, y)# 预测
y_pred = clf.predict(X[:5])
print("预测结果:", y_pred)# 可视化
plt.figure(figsize=(12, 8))
plot_tree(clf, filled=True, feature_names=load_iris().feature_names, class_names=load_iris().target_names)
plt.show()

(2)决策树回归 + 可视化

from sklearn.tree import DecisionTreeRegressor
import numpy as np
import matplotlib.pyplot as plt# 模拟数据
X = np.sort(5 * np.random.rand(80, 1), axis=0)
y = np.sin(X).ravel() + np.random.normal(0, 0.1, X.shape[0])# 模型训练
reg = DecisionTreeRegressor(max_depth=4)
reg.fit(X, y)# 可视化预测
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_pred = reg.predict(X_test)plt.figure(figsize=(10, 6))
plt.scatter(X, y, s=20, edgecolor="black", c="darkorange", label="data")
plt.plot(X_test, y_pred, color="cornflowerblue", label="prediction")
plt.legend()
plt.title("Decision Tree Regression")
plt.show()

六、超参数 & 控制项

参数说明
criterion划分标准(gini, entropy, squared_error
max_depth最大深度
min_samples_split内部节点最小样本数
min_samples_leaf叶子节点最小样本数
max_features用于划分的特征数

七、剪枝技巧

决策树容易过拟合训练数据,特别是当树结构过深时。剪枝用于控制模型复杂度,提高泛化能力。


1、剪枝类型概览

类型说明
预剪枝(Pre-Pruning)在构建过程中限制深度、叶子样本数等
后剪枝(Post-Pruning)构建完决策树后,自底向上裁剪不必要的分支

本内容聚焦在 后剪枝实现


2、后剪枝算法思想

基本流程:

  1. 从叶子节点向上回溯每个子树
  2. 判断当前子树(有分支) vs 将其剪成叶子节点谁的准确率更高(或损失更小)
  3. 若剪枝后效果更好 ⇒ 剪枝(用子树上样本的多数类替代)

3、后剪枝代码实现(基于自己实现的决策树)

下面是一个简洁的 ID3 决策树分类 + 后剪枝的完整实现:

1. 数据准备
from collections import Counter# 简化示例数据
data = [['晴', '高', '弱', '否'],['晴', '高', '强', '否'],['阴', '高', '弱', '是'],['雨', '中', '弱', '是'],['雨', '低', '弱', '是'],['雨', '低', '强', '否'],['阴', '低', '强', '是']
]# 特征名称
features = ['天气', '温度', '风']

2. 构建决策树(ID3)
def entropy(data):labels = [row[-1] for row in data]counter = Counter(labels)total = len(data)return -sum((count/total) * (count/total).bit_length() for count in counter.values())def split_dataset(data, axis, value):return [row[:axis] + row[axis+1:] for row in data if row[axis] == value]def choose_best_feature(data):base_entropy = entropy(data)best_gain = 0best_feature = -1num_features = len(data[0]) - 1for i in range(num_features):values = set(row[i] for row in data)new_entropy = 0for val in values:subset = split_dataset(data, i, val)prob = len(subset) / len(data)new_entropy += prob * entropy(subset)gain = base_entropy - new_entropyif gain > best_gain:best_gain = gainbest_feature = ireturn best_featuredef majority_class(data):labels = [row[-1] for row in data]return Counter(labels).most_common(1)[0][0]def build_tree(data, features):labels = [row[-1] for row in data]if labels.count(labels[0]) == len(labels):return labels[0]if len(features) == 0:return majority_class(data)best_feat = choose_best_feature(data)best_feat_name = features[best_feat]tree = {best_feat_name: {}}feat_values = set(row[best_feat] for row in data)sub_features = features[:best_feat] + features[best_feat+1:]for val in feat_values:subset = split_dataset(data, best_feat, val)tree[best_feat_name][val] = build_tree(subset, sub_features)return tree

3. 后剪枝实现
def classify(tree, features, sample):if not isinstance(tree, dict):return treeroot = next(iter(tree))sub_tree = tree[root]idx = features.index(root)value = sample[idx]subtree = sub_tree.get(value)if not subtree:return Nonereturn classify(subtree, features, sample)def accuracy(tree, features, data):correct = 0for row in data:if classify(tree, features, row[:-1]) == row[-1]:correct += 1return correct / len(data)def prune_tree(tree, features, data):if not isinstance(tree, dict):return treeroot = next(iter(tree))idx = features.index(root)new_tree = {root: {}}for val, subtree in tree[root].items():subset = [row for row in data if row[idx] == val]if not subset:new_tree[root][val] = subtreeelse:pruned_subtree = prune_tree(subtree, features[:idx] + features[idx+1:], split_dataset(data, idx, val))new_tree[root][val] = pruned_subtree# 尝试剪枝为单叶节点flat_labels = [row[-1] for row in data]majority = majority_class(data)# 原树精度original_acc = accuracy(new_tree, features, data)# 剪枝后精度(所有预测为多数类)pruned_acc = flat_labels.count(majority) / len(flat_labels)if pruned_acc >= original_acc:return majorityelse:return new_tree

4. 使用示例
# 构建树
tree = build_tree(data, features)
print("原始决策树:", tree)# 后剪枝
pruned = prune_tree(tree, features, data)
print("剪枝后树:", pruned)

4、剪枝效果展示(示意)

原始决策树:
{'天气': {'雨': {'风': {'弱': '是', '强': '否'}},'阴': '是','晴': {'风': {'弱': '否', '强': '否'}}
}}剪枝后树:
{'天气': {'雨': '是','阴': '是','晴': '否'
}}

八、优缺点总结

优点:

  • 易理解,树结构直观
  • 可处理分类与回归
  • 可解释性强

缺点:

  • 容易过拟合
  • 对小变化敏感
  • 对连续变量划分不如 ensemble 方法鲁棒

http://www.lryc.cn/news/603793.html

相关文章:

  • Unity_SRP Batcher
  • 谷歌采用 Ligero 构建其 ZK 技术栈
  • 【密码学】4. 分组密码
  • ftp加ssl,升级ftps
  • WebRTC(十四):WebRTC源码编译与管理
  • 7月29日星期二今日早报简报微语报早读
  • TCPDump实战手册:协议/端口/IP过滤与组合分析指南
  • Kruskal算法
  • 《林景媚与命运共创者》
  • 暑期算法训练.10
  • Spring Boot中的this::语法糖详解
  • 解锁全球数据:Bright Data MCP 智能解决代理访问难题
  • pnpm 入门与实践指南
  • Element Plus常见基础组件(二)
  • React 图标库发布到 npm 仓库
  • Linux -- 文件【中】
  • 基于深度学习的医学图像分析:使用CycleGAN实现图像到图像的转换
  • tcp通讯学习数据传输
  • DETR 下 Transformer 应用探讨
  • 准大一GIS专业新生,如何挑选电脑?
  • 站点到站点-主模式
  • Java 11 新特性详解与代码示例
  • JAVA中集合的遍历方式
  • 【C++】1. C++基础知识
  • 编辑距离:理论基础、算法演进与跨领域应用
  • taro+react重新给userInfo赋值后,获取的用户信息还是老用户信息
  • ERROR c.a.c.n.c.NacosPropertySourceBuilder
  • react 的 useTransition 、useDeferredValue
  • react中暴露事件useImperativeHandle
  • 【C++】判断语句