当前位置: 首页 > news >正文

TabPFN - 表格数据基础模型

文章目录

    • 一、关于 TabPFN
      • 🌐TabPFN生态系统
    • 二、快速入门🏁
      • 1、安装
      • 2、基本用法
    • 三、使用技巧💡
    • 四、开发🛠️
      • 1、设置环境
      • 2、在提交之前
      • 3、运行测试


一、关于 TabPFN

TabPFN是表格数据的基础模型,它优于传统方法,同时速度显着加快。该存储库包含具有CUDA优化的核心PyTorch实现。

  • github : https://github.com/PriorLabs/TabPFN
  • 官方文档:https://priorlabs.ai/
  • Discord
  • 交互式Colab教程 使用示例和最佳实践
  • 开发者: Prior Labs

🌐TabPFN生态系统

根据您的需求选择正确的TabPFN实现:

  • TabPFN客户端:易于使用的API客户端,用于基于云的推理
  • TabPFN扩展:社区扩展和集成
  • TabPFN(此存储库):本地部署和研究的核心实现

试试我们的交互式Colab教程,快速入门。


二、快速入门🏁

⚠️ **主要更新:2.0版:**通过新的架构和功能完成代码库大修。以前的版本在v1.0.0和pip install tabpfn<2


1、安装

# Simple installation
pip install tabpfn# Local development installation
git clone https://github.com/PriorLabs/TabPFN.git
pip install -e "tabpfn[dev]"

2、基本用法

from sklearn.datasets import load_breast_cancer
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_splitfrom tabpfn import TabPFNClassifier# Load data
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)# Initialize a classifier
clf = TabPFNClassifier()
clf.fit(X_train, y_train)# Predict probabilities
prediction_probabilities = clf.predict_proba(X_test)
print("ROC AUC:", roc_auc_score(y_test, prediction_probabilities[:, 1]))# Predict labels
predictions = clf.predict(X_test)
print("Accuracy", accuracy_score(y_test, predictions))

三、使用技巧💡

TabPFN旨在以最少的预处理开箱即用:

  • 无需预处理:TabPFN在内部处理规范化
  • 类别变量:使用数字编码(浮点数表示有序,普通编码器表示无序)
  • 自动集成:控制与n_estimators
  • 独立预测:测试样本可以单独或批量预测
  • 可微:核心模型是可微的(预处理除外)
  • GPU支持:使用device='cuda'进行GPU加速

四、开发🛠️


1、设置环境

python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
git clone https://github.com/PriorLabs/TabPFN.git
cd tabpfn
pip install -e ".[dev]"
pre-commit install

2、在提交之前

pre-commit run --all-files

3、运行测试

pytest tests/

2025-01-06(五)

http://www.lryc.cn/news/522060.html

相关文章:

  • AOF日志:宕机了Redis如何避免数据丢失?
  • MAC上安装Octave
  • C 语言中二维数组的退化
  • Notion 推出捏脸应用 | Deving Weekly #15
  • C# Linq 查询
  • ES7【2016】、ES8【2017】新增特性
  • 64细分步进电机驱动器TMC2209
  • C# 获取PDF文档中的字体信息(字体名、大小、颜色、样式等
  • linux 安装PrometheusAlert配置钉钉告警
  • 【华为路由/交换机的ssh远程设置】
  • 性能测试 - Locust WebSocket client
  • html中鼠标位置信息
  • kubernetes v1.29.XX版本HPA、KPA、VPA并压力测试
  • flutter 常用UI组件
  • HarmonyOS NEXT应用开发边学边玩系列:从零实现一影视APP (五、电影详情页的设计实现)
  • hive表修改字段类型没有级连导致历史分区报错
  • 云上贵州多彩宝荣获仓颉社区先锋应用奖 | 助力数字政务新突破
  • JS宏进阶:JS宏中的文件系统FileSystem
  • XML序列化和反序列化的学习
  • npm ERR! code CERT_HAS_EXPIRED
  • 30分钟内搭建一个全能轻量级springboot 3.4 + 脚手架 <5> 5分钟集成好caffeine并使用注解操作缓存
  • 【设计模式-结构型】装饰器模式
  • 分布式数据存储基础与HDFS操作实践(副本)
  • Linux 进程前篇(冯诺依曼体系结构和操作系统)
  • Springboot Redisson 分布式锁、缓存、消息队列、布隆过滤器
  • 【C语言】_字符串拷贝函数strcpy
  • 基于 Vue 的拖拽缩放卡片组件:实现思路、方法及使用指南
  • nginx 实现 正向代理、反向代理 、SSL(证书配置)、负载均衡 、虚拟域名 ,使用其他中间件监控
  • Kafka客户端-“远程主机强迫关闭了一个现有的连接”故障排查及解决
  • Node.js - Express框架