Python中的决策树机器学习模型简要介绍和代码示例(基于sklearn)
一、决策树定义
决策树是一种监督学习算法,可用于**分类(Classification)和回归(Regression)**任务。
它的结构类似树状结构:
- 内部节点:特征条件(如
X > 2
) - 叶子节点:输出类别或数值
- 路径:对应一系列条件组合
二、决策树的基本概念
1. 信息熵(Entropy)
衡量样本集合的不确定性:
H(D)=−∑k=1Kpklog2pk H(D) = - \sum_{k=1}^{K} p_k \log_2 p_k H(D)=−k=1∑Kpklog2pk
其中:
- DDD:样本集合
- pkp_kpk:类别 kkk 的概率
2. 信息增益(Information Gain)
衡量某特征对信息熵的降低程度:
IG(D,A)=H(D)−∑v=1V∣Dv∣∣D∣H(Dv) IG(D, A) = H(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|} H(D^v) IG(D,A)=H(D)−v=1∑V∣D∣∣Dv∣H(Dv)
- DvD^vDv:按特征 AAA 值划分的子集
- 常用于 ID3 算法
3. 信息增益率(Gain Ratio)
用于 C4.5 算法,避免信息增益偏好取值多的特征:
GainRatio(D,A)=IG(D,A)IV(A) \text{GainRatio}(D, A) = \frac{IG(D, A)}{IV(A)} GainRatio(D,A)=IV(A)IG(D,A)
- IV(A)=−∑v=1V∣Dv∣∣D∣log2∣Dv∣∣D∣IV(A) = -\sum_{v=1}^V \frac{|D^v|}{|D|} \log_2 \frac{|D^v|}{|D|}IV(A)=−∑v=1V∣D∣∣Dv∣log2∣D∣∣Dv∣
4. Gini系数(Gini Impurity)
CART 分类算法使用的分裂标准:
Gini(D)=1−∑k=1Kpk2 Gini(D) = 1 - \sum_{k=1}^{K} p_k^2 Gini(D)=1−k=1∑Kpk2
越小表示纯度越高。
5. 均方误差(MSE)
用于决策树回归:
MSE=1N∑i=1N(yi−y^)2 MSE = \frac{1}{N} \sum_{i=1}^N (y_i - \hat{y})^2 MSE=N1i=1∑N(yi−y^)2
其中 y^\hat{y}y^ 是某叶子节点上的预测值。
三、决策树的算法流程(以分类为例)
-
选择最优划分特征
- 使用信息增益 / 信息增益率 / Gini 系数
-
划分数据集
- 递归构建子树
-
停止条件
- 数据已纯净 / 特征用尽 / 达到最大深度
-
剪枝(可选)
- 预剪枝 / 后剪枝,防止过拟合
四、实际示例
我们用一个简单的例子说明:
天气 | 温度 | 湿度 | 风 | 打球 |
---|---|---|---|---|
晴 | 高 | 高 | 弱 | 否 |
晴 | 高 | 高 | 强 | 否 |
阴 | 高 | 高 | 弱 | 是 |
雨 | 中 | 高 | 弱 | 是 |
雨 | 低 | 正常 | 弱 | 是 |
雨 | 低 | 正常 | 强 | 否 |
阴 | 低 | 正常 | 强 | 是 |
目标是预测“是否打球”。
五、代码实现
示例使用 Python + scikit-learn 实现
(1)决策树分类 + 可视化
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt# 加载数据
X, y = load_iris(return_X_y=True)# 创建模型
clf = DecisionTreeClassifier(criterion='entropy', max_depth=3)
clf.fit(X, y)# 预测
y_pred = clf.predict(X[:5])
print("预测结果:", y_pred)# 可视化
plt.figure(figsize=(12, 8))
plot_tree(clf, filled=True, feature_names=load_iris().feature_names, class_names=load_iris().target_names)
plt.show()
(2)决策树回归 + 可视化
from sklearn.tree import DecisionTreeRegressor
import numpy as np
import matplotlib.pyplot as plt# 模拟数据
X = np.sort(5 * np.random.rand(80, 1), axis=0)
y = np.sin(X).ravel() + np.random.normal(0, 0.1, X.shape[0])# 模型训练
reg = DecisionTreeRegressor(max_depth=4)
reg.fit(X, y)# 可视化预测
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_pred = reg.predict(X_test)plt.figure(figsize=(10, 6))
plt.scatter(X, y, s=20, edgecolor="black", c="darkorange", label="data")
plt.plot(X_test, y_pred, color="cornflowerblue", label="prediction")
plt.legend()
plt.title("Decision Tree Regression")
plt.show()
六、超参数 & 控制项
参数 | 说明 |
---|---|
criterion | 划分标准(gini , entropy , squared_error ) |
max_depth | 最大深度 |
min_samples_split | 内部节点最小样本数 |
min_samples_leaf | 叶子节点最小样本数 |
max_features | 用于划分的特征数 |
七、剪枝技巧
决策树容易过拟合训练数据,特别是当树结构过深时。剪枝用于控制模型复杂度,提高泛化能力。
1、剪枝类型概览
类型 | 说明 |
---|---|
预剪枝(Pre-Pruning) | 在构建过程中限制深度、叶子样本数等 |
后剪枝(Post-Pruning) | 构建完决策树后,自底向上裁剪不必要的分支 |
本内容聚焦在 后剪枝实现。
2、后剪枝算法思想
基本流程:
- 从叶子节点向上回溯每个子树
- 判断当前子树(有分支) vs 将其剪成叶子节点谁的准确率更高(或损失更小)
- 若剪枝后效果更好 ⇒ 剪枝(用子树上样本的多数类替代)
3、后剪枝代码实现(基于自己实现的决策树)
下面是一个简洁的 ID3 决策树分类 + 后剪枝的完整实现:
1. 数据准备
from collections import Counter# 简化示例数据
data = [['晴', '高', '弱', '否'],['晴', '高', '强', '否'],['阴', '高', '弱', '是'],['雨', '中', '弱', '是'],['雨', '低', '弱', '是'],['雨', '低', '强', '否'],['阴', '低', '强', '是']
]# 特征名称
features = ['天气', '温度', '风']
2. 构建决策树(ID3)
def entropy(data):labels = [row[-1] for row in data]counter = Counter(labels)total = len(data)return -sum((count/total) * (count/total).bit_length() for count in counter.values())def split_dataset(data, axis, value):return [row[:axis] + row[axis+1:] for row in data if row[axis] == value]def choose_best_feature(data):base_entropy = entropy(data)best_gain = 0best_feature = -1num_features = len(data[0]) - 1for i in range(num_features):values = set(row[i] for row in data)new_entropy = 0for val in values:subset = split_dataset(data, i, val)prob = len(subset) / len(data)new_entropy += prob * entropy(subset)gain = base_entropy - new_entropyif gain > best_gain:best_gain = gainbest_feature = ireturn best_featuredef majority_class(data):labels = [row[-1] for row in data]return Counter(labels).most_common(1)[0][0]def build_tree(data, features):labels = [row[-1] for row in data]if labels.count(labels[0]) == len(labels):return labels[0]if len(features) == 0:return majority_class(data)best_feat = choose_best_feature(data)best_feat_name = features[best_feat]tree = {best_feat_name: {}}feat_values = set(row[best_feat] for row in data)sub_features = features[:best_feat] + features[best_feat+1:]for val in feat_values:subset = split_dataset(data, best_feat, val)tree[best_feat_name][val] = build_tree(subset, sub_features)return tree
3. 后剪枝实现
def classify(tree, features, sample):if not isinstance(tree, dict):return treeroot = next(iter(tree))sub_tree = tree[root]idx = features.index(root)value = sample[idx]subtree = sub_tree.get(value)if not subtree:return Nonereturn classify(subtree, features, sample)def accuracy(tree, features, data):correct = 0for row in data:if classify(tree, features, row[:-1]) == row[-1]:correct += 1return correct / len(data)def prune_tree(tree, features, data):if not isinstance(tree, dict):return treeroot = next(iter(tree))idx = features.index(root)new_tree = {root: {}}for val, subtree in tree[root].items():subset = [row for row in data if row[idx] == val]if not subset:new_tree[root][val] = subtreeelse:pruned_subtree = prune_tree(subtree, features[:idx] + features[idx+1:], split_dataset(data, idx, val))new_tree[root][val] = pruned_subtree# 尝试剪枝为单叶节点flat_labels = [row[-1] for row in data]majority = majority_class(data)# 原树精度original_acc = accuracy(new_tree, features, data)# 剪枝后精度(所有预测为多数类)pruned_acc = flat_labels.count(majority) / len(flat_labels)if pruned_acc >= original_acc:return majorityelse:return new_tree
4. 使用示例
# 构建树
tree = build_tree(data, features)
print("原始决策树:", tree)# 后剪枝
pruned = prune_tree(tree, features, data)
print("剪枝后树:", pruned)
4、剪枝效果展示(示意)
原始决策树:
{'天气': {'雨': {'风': {'弱': '是', '强': '否'}},'阴': '是','晴': {'风': {'弱': '否', '强': '否'}}
}}剪枝后树:
{'天气': {'雨': '是','阴': '是','晴': '否'
}}
八、优缺点总结
优点:
- 易理解,树结构直观
- 可处理分类与回归
- 可解释性强
缺点:
- 容易过拟合
- 对小变化敏感
- 对连续变量划分不如 ensemble 方法鲁棒