当前位置: 首页 > news >正文

Scikit-Learn决策树

Scikit-Learn决策树

    • 1、决策树分类
    • 2、Scikit-Learn决策树分类
      • 2.1、Scikit-Learn决策树API
      • 2.2、Scikit-Learn决策树初体验
      • 2.3、Scikit-Learn决策树实践(葡萄酒分类)



1、决策树分类


2、Scikit-Learn决策树分类

2.1、Scikit-Learn决策树API


官方文档:https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier

中文官方文档:https://scikit-learn.org.cn/view/784.html

2.2、Scikit-Learn决策树初体验


下面我们使用Scikit-Learn提供的API制作两个交错的半圆形状数据集来演示Scikit-Learn决策树

1)制作数据集

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets# 生成两个交错的半圆形状数据集
X, y = datasets.make_moons(noise=0.25, random_state=666)
plt.scatter(X[y == 0, 0], X[y == 0, 1])
plt.scatter(X[y == 1, 0], X[y == 1, 1])
plt.show()

在这里插入图片描述

2)训练决策树分类模型

from sklearn.tree import DecisionTreeClassifier      # 决策树分类器# 使用CART分类树的默认参数
dt_clf = DecisionTreeClassifier()
# dt_clf = DecisionTreeClassifier(max_depth=2, max_leaf_nodes=4)
# 训练拟合
dt_clf.fit(X, y)

3)绘制决策边界

# 绘制决策边界
decision_boundary_fill(dt_clf, axis=[-1.5, 2.5, -1.0, 1.5])
plt.scatter(X[y == 0, 0], X[y == 0, 1])
plt.scatter(X[y == 1, 0], X[y == 1, 1])
plt.show()

其中,使用到的绘制函数详见文章:传送门

当使用CART分类树的默认参数时,其决策边界如图所示:

在这里插入图片描述
由图可见,在不加限制的情况下,一棵决策树会生长到所有的叶子都是纯净的或者或者没有更多的特征可用为止。这样的决策树往往会过拟合,也就是说,它在训练集上表现的很好,而在测试集上却表现的很糟糕

当我们限制决策树的最大深度max_depth=2,并且最大叶子节点数max_leaf_nodes=4时,其决策边界如下图所示:

在这里插入图片描述
通过限制一些参数,对决策树进行剪枝,可以让我们的决策树具有更好的泛化性

2.3、Scikit-Learn决策树实践(葡萄酒分类)


2.3.1、葡萄酒数据集

葡萄酒(Wine)数据集是来自加州大学欧文分校(UCI)的公开数据集,这些数据是对意大利同一地区种植的葡萄酒进行化学分析的结果。数据集共178个样本,包括三个不同品种,每个品种的葡萄酒中含有13种成分(特征)、一个类别标签,分别使是0/1/2来代表葡萄酒的三个分类

数据集的属性信息(13特征+1标签)如下:

from sklearn.datasets import load_winewine = load_wine()
data = pd.DataFrame(data=wine.data, columns=wine.feature_names)
data['class'] = wine.target
print(data.head().to_string())
'''alcohol  malic_acid   ash  alcalinity_of_ash  magnesium  total_phenols  flavanoids  nonflavanoid_phenols  proanthocyanins  color_intensity   hue  od280/od315_of_diluted_wines  proline  class
0    14.23        1.71  2.43               15.6      127.0           2.80        3.06                  0.28             2.29             5.64  1.04                          3.92   1065.0      0
1    13.20        1.78  2.14               11.2      100.0           2.65        2.76                  0.26             1.28             4.38  1.05                          3.40   1050.0      0
2    13.16        2.36  2.67               18.6      101.0           2.80        3.24                  0.30             2.81             5.68  1.03                          3.17   1185.0      0
3    14.37        1.95  2.50               16.8      113.0           3.85        3.49                  0.24             2.18             7.80  0.86                          3.45   1480.0      0
4    13.24        2.59  2.87               21.0      118.0           2.80        2.69                  0.39             1.82             4.32  1.04                          2.93    735.0      0
'''
属性/标签说明
alcohol酒精含量(百分比)
malic_acid苹果酸含量(克/升)
ash灰分含量(克/升)
alcalinity_of_ash灰分碱度(mEq/L)
magnesium镁含量(毫克/升)
total_phenols总酚含量(毫克/升)
flavanoids类黄酮含量(毫克/升)
nonflavanoid_phenols非黄酮酚含量(毫克/升)
proanthocyanins原花青素含量(毫克/升)
color_intensity颜色强度(单位absorbance)
hue色调(在1至10之间的一个数字)
od280/od315_of_diluted_wines稀释葡萄酒样品的光密度比值,用于测量葡萄酒中各种化合物的浓度
proline脯氨酸含量(毫克/升)
class分类标签(class_0(59)、class_1(71)、class_2(48))

数据集的概要信息如下:

# 数据集大小
print(wine.data.shape)      # (178, 13)
# 标签名称
print(wine.target_names)    # ['class_0' 'class_1' 'class_2']
# 分类标签
print(data.groupby('class')['class'].count())
'''
class
0    59
1    71
2    48
Name: class, dtype: int64
'''

数据集的缺失值情况:

# 缺失值:无缺失值
print(data.isnull().sum())

在这里插入图片描述
2.3.2、决策树实践(葡萄酒分类)


未完待续…

http://www.lryc.cn/news/347928.html

相关文章:

  • Python面试题【python基础部分1-50】
  • 鸿蒙内核源码分析(Shell编辑篇) | 两个任务,三个阶段
  • 第Ⅷ章-Ⅱ 组合式API使用
  • stable-diffusion-webui配置
  • 1+X电子商务数据采集渠道及工具选择(二)||电商数据采集API接口
  • apinto OpenAPI
  • XYCTF - web
  • 学习方法的重要性
  • 把现有的 Jenkins 容器推送到一个新的镜像标签,并且重新启动新的容器
  • 难以重现的 Bug如何处理
  • 我与足球的故事 | 10年的热爱 | 伤病 | 悔恨 | 放弃 or 继续 | 小学生的碎碎念罢了
  • js图片回显的方法
  • Java中的maven的安装和配置
  • 轴承制造企业“数智化”突破口
  • UIButton案例之添加动画
  • C#链接数据库、操作sql、选择串口
  • 本地搭建各大直播平台录屏服务结合内网穿透工具实现远程管理录屏任务
  • macos使用yarn创建vite时出现Usage Error: The nearest package directory问题
  • 【JAVA入门】Day04 - 方法
  • 前端报错 SyntaxError: Unexpected number in JSON at position xxxx at JSON.parse
  • Mybatis进阶详细用法
  • Android 系统省电软件分析
  • 了解什么是Docker
  • ChatGPT开源的whisper音频生成字幕
  • 融知财经:期货和现货的区别是什么?哪个风险大?
  • Android Studio开发之路(十)app中使用aar以及报错记录
  • sql-行转列3(转置)
  • MATLAB | 最新版MATLAB绘图速查表来啦!!
  • web安全之登录框渗透骚姿势,新思路
  • 无人机+自组网:空地点对点无人机通信解决方案