
【GridSearch】A simple implementation, with run results logged

This post records a simple framework for grid search implemented with nested for loops.
A DataFrame, df_search, records the result of every hyperparameter combination.
lgb_model.best_score_ returns the model's best evaluation score.
lgb_model.best_iteration_ returns the model's best iteration, i.e. the effective n_estimators.
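Note that with the scikit-learn interface, best_score_ is not a single number but a nested mapping of eval-set name → metric → value. A minimal sketch of pulling out the scalar you usually want to log (the literal dict below is an illustrative assumption for one validation set with the l1/MAE metric, not output from a real run):

```python
# Illustrative shape of lgb_model.best_score_ when fitting with
# eval_set=[(valid_feats, valid_target)] and objective "mae":
# eval-set name ("valid_0") -> metric name -> best value.
best_score = {"valid_0": {"l1": 0.0123}}

# Extract the single scalar for logging:
eval_name = next(iter(best_score))                         # "valid_0"
metric, value = next(iter(best_score[eval_name].items()))  # "l1", 0.0123
print(eval_name, metric, value)
```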

import numpy as np
import pandas as pd
import lightgbm as lgb

df = pd.read_csv("this_is_train.csv")

df_search_columns = ['learning_rate', 'num_leaves', 'max_depth', 'subsample',
                     'colsample_bytree', 'best_iteration', 'best_score']
df_search = pd.DataFrame(columns=df_search_columns)

# colsample_bytree : 0.9, learning_rate : 0.001
lgb_params = {
    "objective": "mae",
    "n_estimators": 6000,
    "num_leaves": 256,           # 256
    "subsample": 0.6,
    "colsample_bytree": 0.8,
    "learning_rate": 0.00571,    # 0.00871
    "max_depth": 11,             # 11
    "n_jobs": 4,
    "device": "gpu",
    "verbosity": -1,
    "importance_type": "gain",
}

for learning_rate in [0.001, 0.005, 0.01, 0.015, 0.05]:
    for num_leaves in [300, 256, 200, 150]:
        for max_depth in [15, 13, 11, 9, 7]:
            for subsample in [0.8, 0.6, 0.5]:
                for colsample_bytree in [0.9, 0.8, 0.7]:
                    print(f"learning_rate : {learning_rate}, num_leaves : {num_leaves}, "
                          f"max_depth : {max_depth}, subsample : {subsample}, "
                          f"colsample_bytree : {colsample_bytree}")
                    lgb_params['learning_rate'] = learning_rate
                    lgb_params['num_leaves'] = num_leaves
                    lgb_params['max_depth'] = max_depth
                    lgb_params['subsample'] = subsample
                    lgb_params['colsample_bytree'] = colsample_bytree
                    # Train a LightGBM model for the current combination
                    lgb_model = lgb.LGBMRegressor(**lgb_params)
                    lgb_model.fit(
                        train_feats,
                        train_target,
                        eval_set=[(valid_feats, valid_target)],
                        callbacks=[
                            lgb.callback.early_stopping(stopping_rounds=100),
                            lgb.callback.log_evaluation(period=100),
                        ],
                    )
                    best_iteration = lgb_model.best_iteration_
                    best_score = lgb_model.best_score_
                    cache = pd.DataFrame(
                        [[learning_rate, num_leaves, max_depth, subsample,
                          colsample_bytree, best_iteration, best_score]],
                        columns=df_search_columns,
                    )
                    df_search = pd.concat([df_search, cache], ignore_index=True, axis=0)

df_search.to_csv('grid_search.csv', index=False)
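The five nested loops can also be flattened with itertools.product, which makes adding or removing a hyperparameter a one-line change. A minimal sketch using the same grid values as above (the grid dict and combos name are illustrative, not part of the original code):

```python
from itertools import product

grid = {
    "learning_rate": [0.001, 0.005, 0.01, 0.015, 0.05],
    "num_leaves": [300, 256, 200, 150],
    "max_depth": [15, 13, 11, 9, 7],
    "subsample": [0.8, 0.6, 0.5],
    "colsample_bytree": [0.9, 0.8, 0.7],
}

# One dict per combination, ready to merge into lgb_params
# with lgb_params.update(combo) inside the training loop.
combos = [dict(zip(grid, values)) for values in product(*grid.values())]
print(len(combos))  # 5 * 4 * 5 * 3 * 3 = 900 combinations
```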

To use this framework, adjust the training-data (df) section, the candidate values for the grid, and the LightGBM hyperparameters.
Each run's results are recorded by the following code:

cache = pd.DataFrame([[learning_rate, num_leaves, max_depth, subsample,
                       colsample_bytree, best_iteration, best_score]],
                     columns=df_search_columns)
df_search = pd.concat([df_search, cache], ignore_index=True, axis=0)
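Calling pd.concat once per run rebuilds the whole DataFrame each time, which gets slow over many hundreds of combinations. An alternative sketch (assuming the same column names; the example values stand in for one run's results) collects plain dicts and builds df_search once at the end:

```python
import pandas as pd

rows = []
# Inside the loop, instead of pd.concat per run:
rows.append({
    "learning_rate": 0.005, "num_leaves": 256, "max_depth": 11,
    "subsample": 0.6, "colsample_bytree": 0.8,
    "best_iteration": 1234, "best_score": 0.0123,
})  # example values for a single run

# After all runs, one DataFrame construction:
df_search = pd.DataFrame(rows)
print(df_search.shape)  # one row per run, 7 columns
```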
