当前位置: 首页 > news >正文

《机器学习核心技术》分类算法 - 决策树

「作者主页」:士别三日wyx
「作者简介」:CSDN top100、阿里云博客专家、华为云享专家、网络安全领域优质创作者
「推荐专栏」:小白零基础《Python入门到精通》

在这里插入图片描述

决策树

  • 1、决策树API
  • 2、决策时实际应用
    • 2.1、获取数据集
    • 2.2、划分数据集
    • 2.3、决策树处理
    • 2.4、模型评估

决策树是一种 「二叉树形式」的预测模型,每个 「节点」对应一个 「判断条件」「满足」上一个条件才能 「进入下一个」判断条件。

就比如找对象,第一个条件肯定是长得帅,长得帅的才考虑下一个条件;长得不帅就直接pass,不往下考虑了。

在这里插入图片描述

决策树的「核心」在于:如何找到「最高效」「决策顺序」

1、决策树API

sklearn.tree.DecisionTreeClassifier() 是决策树分类算法的API

参数

  • criterion:(可选)衡量分裂的质量,可选值有ginientropylog_loss,默认值 gini
  • splitter:(可选)给每个节点选择分割的策略,可选值有bestrandom,默认值 best
  • max_depth:(可选)树的最大深度,默认值 None
  • min_samples_split:(可选)分割节点所需要的的最小样本数,默认值 2
  • min_samples_leaf:(可选)叶节点上所需要的的最小样本数,默认值 1
  • min_weight_fraction_leaf:(可选)叶节点的权重总和的最小加权分数,默认值 0.0
  • max_features:(可选)寻找最佳分割时要考虑的特征数量,默认值 None
  • random_state:(可选)控制分裂特征的随机数,默认值 None
  • max_leaf_nodes:(可选)最大叶子节点数,默认值 None
  • min_impurity_decrease:(可选)如果分裂指标的减少量大于该值,就进行分裂,默认值 0.0
  • class_weight:(可选)每个类的权重,默认值 None
  • ccp_alpha:(可选)将选择成本复杂度最大且小于ccp_alpha的子树。默认情况下,不执行修剪。

函数

  • fit( x_train, y_train ):接收训练集特征 和 训练集目标
  • predict( x_test ):接收测试集特征,返回数据的类标签。
  • score( x_test, y_test ):接收测试集特征 和 测试集目标,返回准确率。
  • predict_log_proba():预测样本的类对数概率

属性

  • classes_:类标签
  • feature_importances_:特征的重要性
  • max_features_:最大特征推断值
  • n_classes_:类的数量
  • n_features_in_:特征数
  • feature_names_in_:特征名称
  • n_outputs_:输出的数量
  • tree_:底层的tree对象

2、决策时实际应用

2.1、获取数据集

这里使用sklearn自带的鸢尾花数据集进行演示。

from sklearn import datasets# 1、获取数据集
iris = datasets.load_iris()

2.2、划分数据集

传入数据集的特征值和目标值,按照默认的比例划分数据集。

from sklearn import datasets
from sklearn import model_selection# 1、获取数据集
iris = datasets.load_iris()
# # 2、划分数据集
x_train, x_test, y_train, y_test = model_selection.train_test_split(iris.data, iris.target)

2.3、决策树处理

实例化对象,传入训练集特征值和目标值,开始训练。

from sklearn import datasets
from sklearn import model_selection
from sklearn import tree# 1、获取数据集
iris = datasets.load_iris()
# # 2、划分数据集
x_train, x_test, y_train, y_test = model_selection.train_test_split(iris.data, iris.target)
# # 3、决策树处理
estimator = tree.DecisionTreeClassifier()
estimator.fit(x_train, y_train)

2.4、模型评估

对比测试集,验证准确率。

from sklearn import datasets
from sklearn import model_selection
from sklearn import tree# 1、获取数据集
iris = datasets.load_iris()
# # 2、划分数据集
x_train, x_test, y_train, y_test = model_selection.train_test_split(iris.data, iris.target)
# # 3、决策树处理
estimator = tree.DecisionTreeClassifier()
estimator.fit(x_train, y_train)
# # 4、模型评估
y_predict = estimator.predict(x_test)
print('对比真实值和预测值', y_test == y_predict)
score = estimator.score(x_test, y_test)
print('准确率:', score)

输出:

对比真实值和预测值 [ True  True  True  True  True False  True  True  True  True  True  TrueFalse  True  True  True  True  True  True  True  True  True  True  TrueTrue  True  True  True  True  True  True  True  True  True  True  TrueTrue  True]
准确率: 0.9473684210526315

从结果可以看到,准确率达到了94%

http://www.lryc.cn/news/142802.html

相关文章:

  • aws PinPoint发附件demo
  • 边写代码边学习之Bidirectional LSTM
  • Django学习笔记-实现联机对战
  • nacos总结1
  • Web安全测试(三):SQL注入漏洞
  • Webstorm 入门级玩转uni-app 项目-微信小程序+移动端项目方案
  • 从零开始的Hadoop学习(三)| 集群分发脚本xsync
  • golang http transport源码分析
  • spring boot 项目整合 websocket
  • 统计学补充概念-17-线性决策边界
  • 指针变量、指针常量与常量指针的区别
  • mq与mqtt的关系
  • 代码大全阅读随笔 (二)
  • vue 项目的屏幕自适应方案
  • 23软件测试高频率面试题汇总
  • PHP8的匿名函数-PHP8知识详解
  • Redis—Redis介绍(是什么/为什么快/为什么做MySQL缓存等)
  • C语言链表梳理-2
  • 【深度学习】实验03 特征处理
  • 基于Dpabi的功能连接
  • 在React项目是如何捕获错误的?
  • 基于内存池的 简单高效的数据库 SDK简介
  • python实例方法,类方法和静态方法区别
  • Pyecharts教程(四):使用pyecharts绘制3D折线图
  • 【stable-diffusion使用扩展+插件和模型资源(下)】
  • 一文了解SpringBoot中的Aop
  • android系统启动流程之zygote如何创建SystemServer进程
  • 【awd系列】Bugku S3 AWD排位赛-9 pwn类型
  • vcomp140.dll丢失的修复方法分享,电脑提示vcomp140.dll丢失修复方法
  • Docker file解析