当前位置: 首页 > article >正文

使用AutoKeras2.0的AutoModel进行结构化数据回归预测

1、First of All: Read The Fucking Source Code

import autokeras as ak
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error# 生成数据集
np.random.seed(42)
x = np.random.rand(1000, 10)  # 生成1000个样本,每个样本有10个特征
y = x.dot([0.5, -1.5, 2.0, -0.8, 1.2, -0.3, 0.7, -1.1, 0.4, -0.6]) + 3.0  # 生成目标变量# 划分训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=42)input_node = ak.Input()
output_node = ak.DenseBlock()(input_node)  # 数值特征处理层
output_node = ak.RegressionHead()(output_node)# 初始化AutoModel,指定为回归任务
regressor = ak.AutoModel(inputs=input_node,  # 指定输入节点outputs=output_node,  # 指定输出节点overwrite=True,  # 覆盖之前的搜索结果max_trials=2  # 最大尝试次数
)# 训练模型
regressor.fit(x_train, y_train, epochs=10)# 评估模型
predictions = regressor.predict(x_test)
mse = mean_squared_error(y_test, predictions)
print('Mean Squared Error: {:.2f}'.format(mse))# 显示最佳模型信息
model = regressor.export_model()
model.summary()

 注:以上基于AutoKeras 2.0版本回归测试的例子,特别使用了通用的AutoModel的写法,其1.1版本不支持。

2、代码结构简介

AutoKeras是一个基于Keras的自动机器学习库,它能够自动搜索最优的神经网络架构。上述为一个简单的回归任务示例代码,主要包含以下部分:

  1. 生成模拟数据集并进行训练/测试集划分
  2. 定义AutoKeras的输入输出节点
  3. ‌配置并训练AutoModel
  4. ‌预测并计算均方误差
  5. ‌查看最佳模型结构

3、AutoModel输入输出说明

input_node = ak.Input()
output_node = ak.DenseBlock()(input_node)  # 数值特征处理层
output_node = ak.RegressionHead()(output_node)

输入部分:

  • input_node = ak.Input() 创建输入节点
  • 无需指定输入形状,AutoKeras会自动推断

处理层部分:

  • DenseBlock() 处理数值特征
  • 通过函数式API连接:output_node = ak.DenseBlock()(input_node)

输出部分:

  • RegressionHead() 指定回归任务
  • 同样使用函数式API连接

AutoModel配置:

  1. inputs参数接收输入节点
  2. outputs参数接收输出节点

 

http://www.lryc.cn/news/2386775.html

相关文章:

  • 好用但不常用的Git配置
  • ULVAC VWR-400M/ERH 真空蒸发器 Compact Vacuum Evaporator DEPOX (VWR-400M/ERH)
  • P1068 [NOIP 2009 普及组] 分数线划定
  • PPT连同备注页(演讲者模式)一块转为PDF
  • 第三十二天打卡
  • 项目三 - 任务8:实现词频统计功能
  • MongoDB 快速整合 SpringBoot 示例
  • 2025.05.22-得物春招机考真题解析-第二题
  • ollama list模型列表获取 接口代码
  • OPC Client第5讲(wxwidgets):初始界面的事件处理;按照配置文件初始化界面的内容
  • 什么是BFC,如何触发BFC,BFC有什么特性?
  • python做题日记(9)
  • Leetcode 3557. Find Maximum Number of Non Intersecting Substrings
  • 【C++进阶篇】初识哈希
  • Spring Boot——自动配置
  • 免费轻量便携截图 录屏 OCR 翻译四合一!提升办公效率
  • 使用 Vuex 实现用户注册与登录功能
  • 进程通信(管道,共享内存实现)
  • 电池预测 | 第28讲 基于CNN-GRU的锂电池剩余寿命预测
  • 快速上手SHELL脚本常用命令
  • 【无标题】前端如何实现分页?
  • 【自然语言处理与大模型】大模型Agent四大的组件
  • 小巧高效的目录索引生成软件
  • 云原生架构设计相关原则
  • android实现使用RecyclerView详细
  • 华为云Flexus+DeepSeek征文 | Flexus X实例助力 Dify-LLM 一键部署:性能跃升与成本优化的革新实践
  • 曼昆经济学原理第九版目录
  • 数据库blog7_MySql的下载与配置准备
  • YOLOv11助力地铁机场安检!!!一键识别刀具
  • RFID工业读写器的场景化应用选型指南