当前位置: 首页 > news >正文

译 | 介绍PyTabKit:一个试图超越 Scikit-Learn的新机器学习库

github地址:https://github.com/dholzmueller/pytabkit
译原文地址:Get with the Times: PyTabKit for better Tabular Machine Learning over Sk-Learn (CODE Included)


长期以来,Scikit-Learn 一直是处理表格数据机器学习的首选库,提供了丰富的算法、预处理工具和模型评估功能。它仍然很出色,但为什么还要开着你爷爷那辆老旧的 58 年款雪佛兰车呢?让它保持古董地位吧。现在介绍 PyTabKit —— 一个新框架,旨在取代 Scikit-Learn,用于表格数据的分类和回归,采用了最新技术如 RealMLP 和为梯度提升树(GBDT)优化的默认超参数。

完整文章链接: 2407.04491

PyTabKit 提供了类似 scikit-learn 接口的现代表格分类和回归方法,并在我们的论文中进行了基准测试。它还包含了用于基准测试的相关代码。

支持的模型

  • 神经网络:RealMLP(调优默认、HPO、集成)
  • 梯度提升树:XGBoost、LightGBM、CatBoost(默认、调优、HPO)
  • 其他模型:TabR、TabM、ResNet 等

后处理和校准
支持时序缩放等后处理技术,提升预测概率的准确性。示例:

from pytabkit import RealMLP_TD_Classifierclf = RealMLP_TD_Classifier(val_metric_name='ref-ll-ts',  # 采用对数损失calibration_method='ts-mix',  # 时序缩放use_ls=False
)

为什么要超越 Scikit-Learn?

Scikit-Learn 为模型开发提供了坚实基础,但缺乏高度优化的深度学习方法和高效的自动调参功能。最新研究表明:

RealMLP 可与 GBDTs 竞争

  • 传统上,表格数据的深度学习模型需要大量调参,导致训练慢且不够实用。
  • RealMLP 是一个经过优化的多层感知机,基于 118 个数据集的基准测试进行了微调,在中等到大型数据集(1K 到 50 万样本)上性能可与 GBDTs 相媲美。
  • RealMLP 的改进包括稳健的数值缩放、数值嵌入和优化的权重初始化,使其成为传统模型的强有力替代。

更好的默认超参数很重要

  • Scikit-Learn 的默认超参数表现通常不如调优后的模型。
  • PyTabKit 为 XGBoost、LightGBM 和 CatBoost 提供了元调优的默认参数,能在无需调参的情况下超越 Scikit-Learn 的基线实现。
  • 这些默认设置在元训练基准上优化,并在 90 个未见过的数据集上验证了效果。

效率与准确性兼顾

  • 超参数优化代价高昂,尤其是深度学习模型。
  • PyTabKit 的优化默认配置让用户在许多情况下可以跳过调参,开箱即用,得到强劲效果。
  • 这使其成为 AutoML 系统中速度与准确性权衡的更佳选择。

RealMLP:表格数据神经网络的变革者

虽然梯度提升是结构化数据的主流方法,但深度学习如果正确实施,能缩小差距。RealMLP 引入了多项架构改进:

预处理改进

  • 对数值特征使用稳健缩放和平滑裁剪。
  • 对低基数类别特征使用独热编码。

架构增强

  • 引入对角权重层,提升表示能力。
  • 采用新颖的数值嵌入,优于传统特征变换。
  • 更智能的初始化策略,加快收敛速度。

性能提升

  • 基准测试显示 RealMLP 在某些场景下能匹配甚至超越 GBDTs。
  • 将 RealMLP 与优化的 GBDT 默认参数结合,能实现无需昂贵调参的最先进结果

未来展望:PyTabKit 作为新标准

PyTabKit 不仅是另一个机器学习库,而是一场范式转变。结合更强的神经网络架构、更优的默认超参数和实用的高效性,它有潜力取代 Scikit-Learn,成为许多实际应用的首选。

对于处理中等到大型数据集的用户,PyTabKit 提供了更快的训练速度、竞争力的准确率和减少的调参工作量,是现代机器学习工作流的理想方案。

代码示例:用 PyTabKit 训练 RealMLP 和树模型

使用方式和 Sklearn 一样简单!

安装

pip install pytabkit
pip install openml

获取数据集

这里使用 OpenML 的 Covertype 数据集,为了演示限制为 15,000 个样本。

import openml
from sklearn.model_selection import train_test_split
import numpy as nptask = openml.tasks.get_task(361113)
dataset = openml.datasets.get_dataset(task.dataset_id, download_data=False)
X, y, categorical_indicator, attribute_names = dataset.get_data(dataset_format='dataframe',target=task.target_name
)index = np.random.choice(range(len(X)), 15000, replace=False)
X = X.iloc[index]
y = y.iloc[index]X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

使用 RealMLP 训练

from pytabkit import RealMLP_TD_Classifier
from sklearn.metrics import accuracy_scoremodel = RealMLP_TD_Classifier()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print(f"Accuracy of RealMLP: {acc}")

预期输出:

Accuracy of RealMLP: 0.8770666666666667

使用 Bagging(交叉验证集成)

RealMLP 支持通过设置 n_cv=5 进行 5 折交叉验证集成,训练仍高效。

model = RealMLP_TD_Classifier(n_cv=5)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print(f"Accuracy of RealMLP with bagging: {acc}")

预期输出:

Accuracy of RealMLP with bagging: 0.8930666666666667

超参数优化

使用 RealMLP_HPO_Classifier 进行超参调优,调优步数可调。

from pytabkit import RealMLP_HPO_Classifiern_hyperopt_steps = 3
model = RealMLP_HPO_Classifier(n_hyperopt_steps=n_hyperopt_steps)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print(f"Accuracy of RealMLP with {n_hyperopt_steps} steps HPO: {acc}")

预期输出:

Accuracy of RealMLP with 3 steps HPO: 0.8605333333333334

使用优化默认参数的树模型

调优默认(TD)模型使用优化后的超参数,默认(D)模型使用库默认参数。

from pytabkit import (CatBoost_TD_Classifier, CatBoost_D_Classifier,LGBM_TD_Classifier, LGBM_D_Classifier,XGB_TD_Classifier, XGB_D_Classifier
)for model in [CatBoost_TD_Classifier(), CatBoost_D_Classifier(),LGBM_TD_Classifier(), LGBM_D_Classifier(),XGB_TD_Classifier(), XGB_D_Classifier()]:model.fit(X_train, y_train)y_pred = model.predict(X_test)acc = accuracy_score(y_test, y_pred)print(f"Accuracy of {model.__class__.__name__}: {acc}")

预期输出:

Accuracy of CatBoost_TD_Classifier: 0.8685333333333334
Accuracy of CatBoost_D_Classifier: 0.8464
Accuracy of LGBM_TD_Classifier: 0.8602666666666666
Accuracy of LGBM_D_Classifier: 0.8344
Accuracy of XGB_TD_Classifier: 0.8544
Accuracy of XGB_D_Classifier: 0.8472

集成优化默认参数的树模型和 RealMLP

通过集成多个模型可以建立强基线。

from pytabkit import Ensemble_TD_Classifiermodel = Ensemble_TD_Classifier()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print(f"Accuracy of Ensemble_TD_Classifier: {acc}")

以上内容展示了 PyTabKit 在表格数据机器学习中的强大能力,尤其是 RealMLP 的表现和优化默认参数的树模型,为实际应用提供了更高效、准确的解决方案。

http://www.lryc.cn/news/603374.html

相关文章:

  • 如何查询并访问路由器的默认网关(IP地址)?
  • 主应用严格模式下,子应用组件el-date-picker点击无效
  • 【Dify】-进阶14- 用 Dify 搭建法律文档解析助手
  • Vue.js 指令系统完全指南:深入理解 v- 指令
  • 智能图书馆管理系统开发实战系列(一):项目架构设计与技术选型
  • Ubuntu上开通Samba网络共享
  • Ambari 3.0.0 全网首发支持 Ubuntu 22!
  • Kafka——消费者组重平衡全流程解析
  • cpolar 内网穿透 ubuntu 使用石
  • Spark SQL 数组函数合集:array_agg、array_contains、array_sort…详解
  • 【MySQL】从连接数据库开始:JDBC 编程入门指南
  • Vim与VS Code
  • 【CodeTop】每日练习 2025.7.29
  • LibTorch使用-基础版
  • Jetpack - Room(Room 引入、Room 优化)
  • Spring Boot 自动配置:从 2.x 到 3.x 的进化之路
  • 牛顿拉夫逊法PQ分解法计算潮流MATLAB程序计算模型。
  • 微信小程序私密消息
  • GaussDB 数据库架构师修炼(十) 性能诊断常用视图
  • 原生html+js+jq+less 实现时间区间下拉弹窗选择器
  • 鸿蒙网络编程系列59-仓颉版TLS回声服务器示例
  • 42、鸿蒙HarmonyOS Next开发:应用上下文Context
  • Apache Ignite 的分布式原子类型(Atomic Types)
  • 专业Python爬虫实战教程:逆向加密接口与验证码突破完整案例
  • 【NLP舆情分析】基于python微博舆情分析可视化系统(flask+pandas+echarts) 视频教程 - 微博文章数据可视化分析-文章评论量分析实现
  • Apache Ignite Cluster Groups的介绍
  • U3D中的package
  • 【PHP】Swoole:CentOS安装Composer+Hyperf
  • vue2 使用liveplayer加载视频
  • .NET Core 3.1 升级到 .NET 8