当前位置: 首页 > news >正文

【机器学习第二期(Python)】优化梯度提升决策树 XGBoost

优化梯度提升决策树XGBoost

  • 📌 一、XGBoost 简介
  • 🧠 二、原理详解
    • 2.1 基础思想:改进版 GBDT
    • 2.2 目标函数
    • 2.3 二阶泰勒展开优化
    • 2.4 树结构优化
  • 🔧 三、XGBoost 实现步骤(Python)
    • 🧪 可调参数推荐
    • 完整案例代码(回归任务 + 可视化)
  • 参考

梯度提升决策树 GBDT的原理及Python代码实现可参考另一博客-【机器学习第一期(Python)】梯度提升决策树 GBDT。

XGBoost(Extreme Gradient Boosting) 是一种高效的 梯度提升决策树(Gradient Boosting Decision Tree, GBDT) 算法的优化版本。它通过迭代方式构建多个弱学习器(通常是回归树),每一棵新树都试图纠正前一棵树的预测误差。

📌 一、XGBoost 简介

XGBoost 是一种高效、灵活、可扩展的提升树算法,由陈天奇博士开发。它在多个机器学习竞赛中表现优异,是 Kaggle 等平台上最常用的模型之一。

XGBoost vs. GBDT 区别总结

特性GBDT(传统)XGBoost(优化)
损失优化一阶导数(残差)一阶 + 二阶导数(更精确)
正则化无或简单显式正则项控制模型复杂度
剪枝策略贪心构建后不剪枝支持后剪枝(loss-guided)
并行计算列块结构支持特征并行
缓存优化支持高效缓存、内存优化
缺失值处理手动处理自动学习缺失方向
自定义损失函数较难支持一阶二阶导实现定制损失

🧠 二、原理详解

2.1 基础思想:改进版 GBDT

XGBoost 本质上仍是基于梯度提升(Gradient Boosting)的思想,但它在 模型正则化、树结构构建、并行化、缓存优化 等方面做了大量改进。

2.2 目标函数

XGBoost 的目标函数包含两部分:

Obj = \sum_{i=1}^{n} l(y_i, \hat{y}_i^{(t)}) + \sum_{k=1}^{t} \Omega(f_k)

其中:
l l l:损失函数(如平方误差、对数损失等)

Ω ( f k ) = γ T + 1 2 λ ∑ w j 2 \Omega(f_k) = \gamma T + \frac{1}{2} \lambda \sum w_j^2 Ω(fk)=γT+21λwj2 为正则项

  • 控制模型复杂度,防止过拟合
  • T T T 为叶子数, w j w_j wj 为第 j j j 个叶子的得分

2.3 二阶泰勒展开优化

为了计算高效,XGBoost 不仅使用一阶梯度(残差),还使用 二阶导数(Hessian),即:

Obj^{(t)} \approx \sum_{i=1}^{n} \left[ g_i f_t(x_i) + \frac{1}{2} h_i f_t^2(x_i) \right] + \Omega(f_t)
  • g i = ∂ y ^ ( t − 1 ) l ( y i , y ^ ( t − 1 ) ) g_i = \partial_{\hat{y}^{(t-1)}} l(y_i, \hat{y}^{(t-1)}) gi=y^(t1)l(yi,y^(t1)):一阶导数
  • h i = ∂ y ^ ( t − 1 ) 2 l ( y i , y ^ ( t − 1 ) ) h_i = \partial^2_{\hat{y}^{(t-1)}} l(y_i, \hat{y}^{(t-1)}) hi=y^(t1)2l(yi,y^(t1)):二阶导数

这让模型更新更稳健,支持更多自定义损失函数。

2.4 树结构优化

使用贪心算法选择最优划分点
使用特征列块缓存加速多线程训练
支持剪枝(预剪枝和后剪枝)
使用 正则项 控制每棵树的复杂度(叶子数与叶子权重)

🔧 三、XGBoost 实现步骤(Python)

我们使用 xgboost 包进行实现,保持与前面 GBDT 案例一致。

采用以下命令安装库包:

pip install xgboostconda install myenv3.10
conda install xgboost

🧪 可调参数推荐

参数含义建议
n_estimators弱学习器数量100~500
learning_rate学习率0.01~0.3,越小越稳
max_depth树深度控制模型复杂度
subsample子样本比例防止过拟合,0.5~1.0
colsample_bytree每棵树用的特征比例防止过拟合
reg_alpha, reg_lambdaL1/L2 正则项控制过拟合(尤其重要)

完整案例代码(回归任务 + 可视化)

绘制的效果图如下:
在这里插入图片描述

左图:拟合效果:拟合曲线很好地捕捉了数据的非线性趋势。

  • 蓝点:训练数据
  • 红点:测试数据
  • 黑线:GBDT 拟合曲线

右图:残差图:残差应随机分布在 y=0 附近,没有明显模式,表明模型拟合良好。

输出结果为:(相同案例,比梯度提升决策树 GBDT效果差)

XGBoost Train MSE: 0.0147
XGBoost Test MSE: 0.0509

完整Python实现代码如下:

import numpy as np
import matplotlib.pyplot as plt
from xgboost import XGBRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error# 1. 生成数据(与 GBDT 案例相同)
np.random.seed(42)
X = np.linspace(0, 10, 200).reshape(-1, 1)
y = np.sin(X).ravel() + np.random.normal(0, 0.2, X.shape[0])# 2. 划分训练/测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 3. 训练 XGBoost 模型
model = XGBRegressor(n_estimators=100,learning_rate=0.1,max_depth=3,subsample=0.8,colsample_bytree=0.8,objective='reg:squarederror',  # 回归任务random_state=42
)model.fit(X_train, y_train)# 4. 预测与评估
y_train_pred = model.predict(X_train)
y_test_pred = model.predict(X_test)train_mse = mean_squared_error(y_train, y_train_pred)
test_mse = mean_squared_error(y_test, y_test_pred)print(f"XGBoost Train MSE: {train_mse:.4f}")
print(f"XGBoost Test MSE: {test_mse:.4f}")# 5. 可视化
plt.figure(figsize=(12, 6))# 拟合曲线图
plt.subplot(1, 2, 1)
plt.scatter(X_train, y_train, color='lightblue', label='Train Data', alpha=0.6)
plt.scatter(X_test, y_test, color='lightcoral', label='Test Data', alpha=0.6)X_all = np.linspace(0, 10, 1000).reshape(-1, 1)
y_all_pred = model.predict(X_all)
plt.plot(X_all, y_all_pred, color='black', label='XGBoost Prediction', linewidth=2)plt.title("XGBoost Model Fit")
plt.xlabel("X")
plt.ylabel("y")
plt.legend()
plt.grid(True)# 残差图
plt.subplot(1, 2, 2)
train_residuals = y_train - y_train_pred
test_residuals = y_test - y_test_predplt.scatter(y_train_pred, train_residuals, color='blue', alpha=0.6, label='Train Residuals')
plt.scatter(y_test_pred, test_residuals, color='red', alpha=0.6, label='Test Residuals')
plt.axhline(y=0, color='black', linestyle='--')
plt.xlabel("Predicted y")
plt.ylabel("Residuals")
plt.title("Residual Plot")
plt.legend()
plt.grid(True)plt.tight_layout()
plt.show()

参考

http://www.lryc.cn/news/575400.html

相关文章:

  • Linux命令-Searching-locate
  • Docker compoes与私有仓库部署
  • 基于vue3+ByteMD快速搭建自己的Markdown文档编辑器
  • Midscene.js:使用 LLMs.txt 快速生成 AI 自动化测试用例「喂饭教程」
  • [Andrej Karpathy] 大型语言模型作为新型操作系统
  • 华为OD 机试 2025-黑板上色
  • 【25软考网工】第十章 网络规划与设计(2)网络规划与分析、网络结构与功能
  • 如何进行 iOS App 混淆加固?IPA 加壳与资源保护实战流程
  • 如何将视频从 iPhone 发送到 Android 设备
  • 数字孪生技术驱动UI前端变革:从静态展示到动态交互的飞跃
  • uniapp 和原生插件交互
  • 小程序入门:理解小程序页面配置
  • vue + vue-router写登陆验证的同步方法和异步方法,及页面组件的分离和后端代码
  • 命名数据网络 | 数据包(Data Packet)
  • chili3d笔记23 正交投影3d重建笔记4 点到线2
  • 【NLP】使用 LangGraph 构建 RAG 的Research Multi-Agent
  • house of apple2
  • Linux系统(信号篇):信号的产生
  • 【Pandas】pandas DataFrame shift
  • Ubuntu下布署mediasoup-demo
  • 黑马JVM解析笔记(四):Javap图解指令流程,深入理解Java字节码执行机制
  • Redis 为什么选用跳跃表,而不是红黑树
  • 《聊一聊ZXDoc》之汽车标定、台架标定、三高标定
  • 【STM32】外部中断
  • 【C++11】右值引用和移动语义
  • gRPC 使用(python 版本)
  • 2025学年湖北省职业院校技能大赛 “信息安全管理与评估”赛项 样题卷(五)
  • Axure版TDesign 组件库-免费版
  • MQTT 和 HTTP 有什么本质区别?
  • 如何将 Memfault 固件 SDK 集成到使用 Nordic 的 nRF Connect SDK(NCS)的项目中