当前位置: 首页 > news >正文

Paddlets时间序列集成模型回测实战:MLPRegressor、NHiTSModel与RNNBlockRegressor

好的,我们继续深入理解代码的每个部分。以下是每个主要模块的详细解释:

1. 导入模块和库

import json
import os
import glob
import pandas as pd
from tqdm import tqdm
from paddlets.datasets import TSDataset
from paddlets.transform import StandardScaler
from paddlets.models.forecasting import MLPRegressor, NHiTSModel, RNNBlockRegressor
from paddlets.ensemble import WeightingEnsembleForecaster
import ray
  • json: 用于处理JSON文件。
  • os: 用于处理文件和目录操作。
  • glob: 用于获取匹配特定模式的文件路径。
  • pandas: 用于数据处理和分析,尤其是表格数据。
  • tqdm: 用于显示进度条,帮助跟踪循环的进度。
  • paddlets: 时间序列预测相关的库。
  • ray: 用于并行计算的库。

2. 定义和创建目录

dirs = ["forecasting_all_result_center","pic_forecasting_center","model_forecasting_center_2048_a_b_5_100","best_forecasting_param_center"
]for dir_name in dirs:os.makedirs(dir_name, exist_ok=True)
  • dirs: 定义了多个用于存储不同类型结果的目录。
  • os.makedirs: 创建目录,如果目录已存在,则不报错。

3. 加载股票映射

with open("./stock_mapping.json", "r") as f:stock_mapping = json.load(f)
  • stock_mapping.json文件中加载股票的映射关系,以便后续使用。

4. 加载CSV数据

csv_paths = glob.glob(os.path.join("./tu_share_data_day", "*.csv"))
sum_dam_data = []for csv_path in tqdm(csv_paths):new_data = pd.read_csv(csv_path)if len(new_data) < 2048 or new_data.iloc[0, 2] < 5 or new_data.iloc[0, 2] > 100:continuenew_data = new_data[::-1].iloc[:2048]new_data['index_new'] = range(1, len(new_data) + 1)sum_dam_data.append(new_data)
  • 使用glob获取所有CSV文件路径,并遍历每个文件。
  • 读取数据并进行过滤,确保符合条件(如数据长度、价格区间)。
  • 将数据反转并取最后2048条,添加索引列。

5. 构建时间序列数据集

dam_data = pd.concat(sum_dam_data)dataset = TSDataset.load_from_dataframe(dam_data,group_id='ts_code',time_col="index_new",target_cols=['high', 'low']
)
  • 将所有符合条件的数据合并成一个DataFrame。
  • 使用TSDataset将数据转换为时间序列格式,指定分组、时间列和目标列。

6. 初始化标准化器

scaler = StandardScaler().fit(dataset)
dataset = scaler.transform(dataset)
  • 使用StandardScaler对数据进行标准化处理,使模型训练更加稳定。

7. 初始化Ray进行并行计算

ray.init()
  • 初始化Ray,使得后续的计算能够并行执行。

8. 定义并行处理函数

@ray.remote
def process_csv_file(csv_path, scaler):...
  • 使用@ray.remote装饰器定义一个可以被Ray并行化的函数,处理每个CSV文件的逻辑。

9. 设置模型参数和加载模型

nhits_params = {'sampling_stride': 24, 'eval_metrics': ["mse", "mae"], 'batch_size': 32, 'max_epochs': 100, 'patience': 10}
rnn_params = nhits_params.copy()
mlp_params = nhits_params.copy()
mlp_params['use_bn'] =
http://www.lryc.cn/news/448214.html

相关文章:

  • 【anki】显示 “连接超时,请更换网络后重试” 怎么办
  • 第一批学习大模型的程序员,已经碾压同事了,薪资差距都甩出一条街了...
  • Unity NetCode 客户端连接不上服务器,局域网模式 Failed to connect to server.
  • C++远端开发环境安装(centos7)
  • LaTeX 编辑器-TeXstudio
  • [深度学习]循环神经网络
  • 景联文科技精准数据标注:优化智能标注平台,打造智能未来
  • 商场促销——策略模式
  • 万字长文,AIGC算法工程师的面试秘籍,推荐收藏!
  • 一些超好用的 GitHub 插件和技巧
  • 记Flink SQL 将数据写入 MySQL时的一个优化策略
  • QT-自定义信号和槽对象树图形化开发计算器
  • C# 字符串(String)的应用说明一
  • Redis缓存淘汰算法详解
  • Sklearn 与 TensorFlow 机器学习实用指南
  • RabbitMQ 界面管理说明
  • 设备管理与点巡检系统
  • 计算机网络的整体认识---网络协议,网络传输过程
  • Battery management system (BMS)
  • 和GPT讨论ZNS的问题(无修改)
  • 6.8方框滤波
  • 携手SelectDB,观测云实现性能与成本的双重飞跃
  • Redis 五大基本数据类型及其应用场景进阶(缓存预热、雪崩 、穿透 、击穿)
  • 如何在ChatGPT的帮助下,使用“逻辑回归”技巧完成论文写作?
  • MySQL 临时表
  • 个人文章汇总(算法原理算法题)
  • 基于Hive和Hadoop的图书分析系统
  • 阿里rtc云端录制TypeScript版NODE运行
  • Web后端开发原理!!!什么是自动配置???什么是起动依赖???
  • 2-105 基于matlab的GA-WNN预测算法