译 | 介绍PyTabKit:一个试图超越 Scikit-Learn的新机器学习库
github地址:https://github.com/dholzmueller/pytabkit
译原文地址:Get with the Times: PyTabKit for better Tabular Machine Learning over Sk-Learn (CODE Included)
长期以来,Scikit-Learn 一直是处理表格数据机器学习的首选库,提供了丰富的算法、预处理工具和模型评估功能。它仍然很出色,但为什么还要开着你爷爷那辆老旧的 58 年款雪佛兰车呢?让它保持古董地位吧。现在介绍 PyTabKit —— 一个新框架,旨在取代 Scikit-Learn,用于表格数据的分类和回归,采用了最新技术如 RealMLP 和为梯度提升树(GBDT)优化的默认超参数。
完整文章链接: 2407.04491
PyTabKit 提供了类似 scikit-learn 接口的现代表格分类和回归方法,并在我们的论文中进行了基准测试。它还包含了用于基准测试的相关代码。
支持的模型
- 神经网络:RealMLP(调优默认、HPO、集成)
- 梯度提升树:XGBoost、LightGBM、CatBoost(默认、调优、HPO)
- 其他模型:TabR、TabM、ResNet 等
后处理和校准
支持时序缩放等后处理技术,提升预测概率的准确性。示例:
from pytabkit import RealMLP_TD_Classifierclf = RealMLP_TD_Classifier(val_metric_name='ref-ll-ts', # 采用对数损失calibration_method='ts-mix', # 时序缩放use_ls=False
)
为什么要超越 Scikit-Learn?
Scikit-Learn 为模型开发提供了坚实基础,但缺乏高度优化的深度学习方法和高效的自动调参功能。最新研究表明:
RealMLP 可与 GBDTs 竞争
- 传统上,表格数据的深度学习模型需要大量调参,导致训练慢且不够实用。
- RealMLP 是一个经过优化的多层感知机,基于 118 个数据集的基准测试进行了微调,在中等到大型数据集(1K 到 50 万样本)上性能可与 GBDTs 相媲美。
- RealMLP 的改进包括稳健的数值缩放、数值嵌入和优化的权重初始化,使其成为传统模型的强有力替代。
更好的默认超参数很重要
- Scikit-Learn 的默认超参数表现通常不如调优后的模型。
- PyTabKit 为 XGBoost、LightGBM 和 CatBoost 提供了元调优的默认参数,能在无需调参的情况下超越 Scikit-Learn 的基线实现。
- 这些默认设置在元训练基准上优化,并在 90 个未见过的数据集上验证了效果。
效率与准确性兼顾
- 超参数优化代价高昂,尤其是深度学习模型。
- PyTabKit 的优化默认配置让用户在许多情况下可以跳过调参,开箱即用,得到强劲效果。
- 这使其成为 AutoML 系统中速度与准确性权衡的更佳选择。
RealMLP:表格数据神经网络的变革者
虽然梯度提升是结构化数据的主流方法,但深度学习如果正确实施,能缩小差距。RealMLP 引入了多项架构改进:
预处理改进
- 对数值特征使用稳健缩放和平滑裁剪。
- 对低基数类别特征使用独热编码。
架构增强
- 引入对角权重层,提升表示能力。
- 采用新颖的数值嵌入,优于传统特征变换。
- 更智能的初始化策略,加快收敛速度。
性能提升
- 基准测试显示 RealMLP 在某些场景下能匹配甚至超越 GBDTs。
- 将 RealMLP 与优化的 GBDT 默认参数结合,能实现无需昂贵调参的最先进结果。
未来展望:PyTabKit 作为新标准
PyTabKit 不仅是另一个机器学习库,而是一场范式转变。结合更强的神经网络架构、更优的默认超参数和实用的高效性,它有潜力取代 Scikit-Learn,成为许多实际应用的首选。
对于处理中等到大型数据集的用户,PyTabKit 提供了更快的训练速度、竞争力的准确率和减少的调参工作量,是现代机器学习工作流的理想方案。
代码示例:用 PyTabKit 训练 RealMLP 和树模型
使用方式和 Sklearn 一样简单!
安装
pip install pytabkit
pip install openml
获取数据集
这里使用 OpenML 的 Covertype 数据集,为了演示限制为 15,000 个样本。
import openml
from sklearn.model_selection import train_test_split
import numpy as nptask = openml.tasks.get_task(361113)
dataset = openml.datasets.get_dataset(task.dataset_id, download_data=False)
X, y, categorical_indicator, attribute_names = dataset.get_data(dataset_format='dataframe',target=task.target_name
)index = np.random.choice(range(len(X)), 15000, replace=False)
X = X.iloc[index]
y = y.iloc[index]X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
使用 RealMLP 训练
from pytabkit import RealMLP_TD_Classifier
from sklearn.metrics import accuracy_scoremodel = RealMLP_TD_Classifier()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print(f"Accuracy of RealMLP: {acc}")
预期输出:
Accuracy of RealMLP: 0.8770666666666667
使用 Bagging(交叉验证集成)
RealMLP 支持通过设置 n_cv=5
进行 5 折交叉验证集成,训练仍高效。
model = RealMLP_TD_Classifier(n_cv=5)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print(f"Accuracy of RealMLP with bagging: {acc}")
预期输出:
Accuracy of RealMLP with bagging: 0.8930666666666667
超参数优化
使用 RealMLP_HPO_Classifier
进行超参调优,调优步数可调。
from pytabkit import RealMLP_HPO_Classifiern_hyperopt_steps = 3
model = RealMLP_HPO_Classifier(n_hyperopt_steps=n_hyperopt_steps)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print(f"Accuracy of RealMLP with {n_hyperopt_steps} steps HPO: {acc}")
预期输出:
Accuracy of RealMLP with 3 steps HPO: 0.8605333333333334
使用优化默认参数的树模型
调优默认(TD)模型使用优化后的超参数,默认(D)模型使用库默认参数。
from pytabkit import (CatBoost_TD_Classifier, CatBoost_D_Classifier,LGBM_TD_Classifier, LGBM_D_Classifier,XGB_TD_Classifier, XGB_D_Classifier
)for model in [CatBoost_TD_Classifier(), CatBoost_D_Classifier(),LGBM_TD_Classifier(), LGBM_D_Classifier(),XGB_TD_Classifier(), XGB_D_Classifier()]:model.fit(X_train, y_train)y_pred = model.predict(X_test)acc = accuracy_score(y_test, y_pred)print(f"Accuracy of {model.__class__.__name__}: {acc}")
预期输出:
Accuracy of CatBoost_TD_Classifier: 0.8685333333333334
Accuracy of CatBoost_D_Classifier: 0.8464
Accuracy of LGBM_TD_Classifier: 0.8602666666666666
Accuracy of LGBM_D_Classifier: 0.8344
Accuracy of XGB_TD_Classifier: 0.8544
Accuracy of XGB_D_Classifier: 0.8472
集成优化默认参数的树模型和 RealMLP
通过集成多个模型可以建立强基线。
from pytabkit import Ensemble_TD_Classifiermodel = Ensemble_TD_Classifier()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print(f"Accuracy of Ensemble_TD_Classifier: {acc}")
以上内容展示了 PyTabKit 在表格数据机器学习中的强大能力,尤其是 RealMLP 的表现和优化默认参数的树模型,为实际应用提供了更高效、准确的解决方案。