当前位置: 首页 > news >正文

XGB-11:随机森林

XGBoost通常用于训练梯度提升决策树和其他梯度提升模型。随机森林使用与梯度提升决策树相同的模型表示和推断,但使用不同的训练算法。可以使用XGBoost来训练独立的随机森林,或者将随机森林作为梯度提升的基模型。这里我们专注于训练独立的随机森林。

XGB从早期开始就有用于训练随机森林的API,而Scikit-Learn在0.82版本之后才有封装。

使用XGBoost API训练独立的随机森林

要启用随机森林训练,必须设置以下参数:

  • booster 应设置为 gbtree,因为正在训练森林。由于这是默认值,通常不需要显式设置此参数。

  • subsample 必须设置为小于 1 的值,以启用对训练样本(行)的随机选择。

  • colsample_by 参数之一必须设置为小于 1 的值,以启用对列的随机选择。通常,colsample_bynode 应设置为小于 1 的值,以在每次树分裂时随机抽样列。

  • num_parallel_tree 应设置为正在训练的森林的大小。

  • num_boost_round 应设置为 1,以防止 XGBoost 提升多个随机森林。请注意,这是train() 的关键字参数,不是参数字典的一部分。

  • 在训练随机森林回归时,应将 eta(别名:learning_rate)设置为 1。

  • random_state 可以用于设置随机数生成器的种子。

其他参数应以类似于梯度提升时设置的方式进行设置。例如,对于回归任务,objective 通常将设置为 reg:squarederror,而对于分类任务,将设置为 binary:logisticlambda 应根据所需的正则化权重进行设置,等等。

如果 num_parallel_treenum_boost_round 都大于 1,则训练将使用随机森林和梯度提升策略的组合。它将执行 num_boost_round 轮,在每一轮中提升 num_parallel_tree 棵树的随机森林。如果未启用提前停止,最终模型将由 num_parallel_tree * num_boost_round 棵树组成。

以下是在 GPU 上使用 xgboost 训练随机森林的示例参数字典:

params = {"colsample_bynode": 0.8,"learning_rate": 1,"max_depth": 5,"num_parallel_tree": 100,"objective": "binary:logistic","subsample": 0.8,"tree_method": "hist","device": "cuda",
}

然后可以按如下方式训练随机森林模型:

bst = train(params, dmatrix, num_boost_round=1)
import xgboost as xgb
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_errordiabetes = load_diabetes()
X = diabetes.data
y = diabetes.target# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# Create a DMatrix for XGBoost
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)# Set parameters for random forest training
params = {"booster": "gbtree","subsample": 0.8,"colsample_bynode": 0.8,"num_parallel_tree": 100,"num_boost_round": 1,"eta": 1,"random_state": 42,"objective": "reg:squarederror",
}# Train the random forest model
model = xgb.train(params, dtrain)# Make predictions on the test set
y_pred = model.predict(dtest)# Evaluate the model
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse}")

基于 Scikit-Learn-Like API 实现随机森林

XGBRFClassifierXGBRFRegressor 是类似于 Scikit-Learn 的类,提供了随机森林的功能。 它们基本上是 XGBClassifierXGBRegressor 的版本,用于训练随机森林而不是梯度提升, 并相应地调整了一些参数的默认值和含义。具体来说:

  • n_estimators 指定要训练的森林的大小;它被转换为 num_parallel_tree,而不是 boosting 轮数的数量
  • learning_rate 默认设置为 1
  • colsample_bynodesubsample 默认设置为 0.8
  • booster 始终为 gbtree

例如,可以使用以下代码训练一个随机森林回归器:

from sklearn.model_selection import KFold# Your code ...kf = KFold(n_splits=2)
for train_index, test_index in kf.split(X, y):xgb_model = xgb.XGBRFRegressor(random_state=42).fit(X[train_index], y[train_index])

注意,与使用 train() 相比,这些类的参数选择较少。特别是,使用此 API 无法将随机森林与梯度提升结合起来。

import xgboost as xgb
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from xgboost import XGBRFRegressor
from sklearn.model_selection import KFolddiabetes = load_diabetes()
X = diabetes.data
y = diabetes.targetkf = KFold(n_splits=2)
for train_index, test_index in kf.split(X, y):xgb_model = xgb.XGBRFRegressor(random_state=42).fit(X[train_index], y[train_index])# Make predictions on the test set
y_pred = xgb_model.predict(X_test)# Evaluate the model
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse}")

注意事项

  • XGBoost 使用二阶逼近来近似目标函数。这可能导致与使用目标函数的精确值的随机森林实现不同的结果
  • 在子采样训练样本时,XGBoost 不执行替换操作。每个训练案例在子采样集中可能出现 0 次或 1 次

参考

  • https://xgboost.readthedocs.io/en/latest/tutorials/rf.html
http://www.lryc.cn/news/305161.html

相关文章:

  • 超平面介绍
  • 【苍穹外卖】一些开发总结
  • Python 3 中,`asynchat`异步通信
  • RAW 编程接口 TCP 简介
  • Oracle EBS FA折旧回滚的分录追溯
  • sql注入 [极客大挑战 2019]FinalSQL1
  • 持续集成,持续交付和持续部署的概念,以及GitLab CI / CD的介绍
  • [Java 项目亮点] 三层限流设计
  • GPT-SoVITS 快速声音克隆使用案例:webui、api接口
  • 高速自动驾驶智慧匝道(HIC)系统功能规范
  • SQL Server——建表时为字段添加注释
  • 【明道云】导入Excel数据时的默认顺序
  • 几种后端开发中常用的语言。
  • Sora——探索AI视频模型的无限可能
  • [NCTF2019]True XML cookbook --不会编程的崽
  • Qt 应用程序中指定使用桌面版本的 OpenGL或嵌入式系统OpenGL ES的 API 进行渲染
  • 大数据软件,待补充
  • 深入探索pdfplumber:从PDF中提取信息到实际项目应用【第94篇—pdfplumbe】
  • 实现linux platform tree框架下ICM20608驱动开发(SPI)
  • 在前端开发中需要考虑的常见web安全问题和攻击原理以及防范措施
  • 年关将至送大礼 社区适时献爱心
  • singularity容器的技术基础
  • jax可微分编程的笔记(2)
  • 在Linux服务器上部署一个单机项目
  • HTTP概要
  • 128 Linux 系统编程6 ,C++程序在linux 上的调试,GDB调试
  • vue2的ElementUI的form表单报错“Error: [ElementForm]unpected width”修复
  • Linux 网络命令指南
  • vue3组件间的通信,通过props,emit,provide和inject把数据传递N个层级,expose和ref实现父组件调用子组件方法
  • 开源免费的NTFS for mac工具mounty