当前位置: 首页 > news >正文

用 Python 轻松实现时间序列预测:Darts N-BEATS

文中内容仅限技术学习与代码实践参考,市场存在不确定性,技术分析需谨慎验证,不构成任何投资建议。

Darts

Darts 是一个 Python 库,用于对时间序列进行用户友好型预测和异常检测。它包含多种模型,从 ARIMA 等经典模型到深度神经网络。所有预测模型都能以类似 scikit-learn 的方式使用 fit()predict() 函数。该库还可以轻松地对模型进行回溯测试,将多个模型的预测结果结合起来,并将外部数据考虑在内。Darts 支持单变量和多变量时间序列和模型。基于 ML 的模型可以在包含多个时间序列的潜在大型数据集上进行训练,其中一些模型还为概率预测提供了丰富的支持。

N-BEATS

在这个 notebook 中,我们展示了如何使用 N-BEATS 与 darts。如果你是 darts 的新手,我们建议你先跟随快速入门 notebook。

N-BEATS 是一个最先进的模型,展示了 纯深度学习架构 在时间序列预测背景下的潜力。它在 M3M4 竞赛中超越了已确立的统计方法。有关模型的更多细节,请参见:https://arxiv.org/pdf/1905.10437.pdf。

# 如果在本地工作,修复 python 路径
from utils import fix_pythonpath_if_working_locallyfix_pythonpath_if_working_locally()
%matplotlib inline
import warningsimport matplotlib.pyplot as plt
import numpy as np
import pandas as pdfrom darts import TimeSeries, concatenate
from darts.dataprocessing.transformers import MissingValuesFiller, Scaler
from darts.datasets import EnergyDataset
from darts.metrics import r2_score
from darts.models import NBEATSModel
from darts.utils.callbacks import TFMProgressBarwarnings.filterwarnings("ignore")
import logginglogging.disable(logging.CRITICAL)def generate_torch_kwargs():# 在 CPU 上运行 torch 模型,并在除训练外的所有模型阶段禁用进度条。return {"pl_trainer_kwargs": {"accelerator": "cpu","callbacks": [TFMProgressBar(enable_train_bar_only=True)],}}
def display_forecast(pred_series, ts_transformed, forecast_type, start_date=None):plt.figure(figsize=(8, 5))if start_date:ts_transformed = ts_transformed.drop_before(start_date)ts_transformed.univariate_component(0).plot(label="实际")pred_series.plot(label=("历史 " + forecast_type + " 预测"))plt.title(f"R2: {r2_score(ts_transformed.univariate_component(0), pred_series)}")plt.legend()

每日能源发电示例

我们在一个来自径流式水电站的每日能源发电数据集上测试 NBEATS,因为它展示了不同程度的季节性。

df = EnergyDataset().load().to_dataframe()
df["generation hydro run-of-river and poundage"].plot()
plt.title("Hourly generation hydro run-of-river and poundage")
Text(0.5, 1.0, 'Hourly generation hydro run-of-river and poundage')

img

为了简化,我们使用每日发电量,并通过使用 MissingValuesFiller 填充数据中存在的缺失值:

df_day_avg = df.groupby(df.index.astype(str).str.split(" ").str[0]).mean().reset_index()
filler = MissingValuesFiller()
scaler = Scaler()
series = filler.transform(TimeSeries.from_dataframe(df_day_avg, "time", ["generation hydro run-of-river and poundage"])
).astype(np.float32)train, val = series.split_after(pd.Timestamp("20170901"))train_scaled = scaler.fit_transform(train)
val_scaled = scaler.transform(val)
series_scaled = scaler.transform(series)train_scaled.plot(label="训练")
val_scaled.plot(label="验证")
plt.title("Daily generation hydro run-of-river and poundage")
Text(0.5, 1.0, 'Daily generation hydro run-of-river and poundage')

img

我们将数据分为训练集和验证集。通常我们需要使用额外的测试集来在未见数据上验证模型,但在这个例子中我们将跳过它。

通用架构

N-BEATS 是一个单变量模型架构,提供两种配置:通用可解释通用架构 尽可能少地使用先验知识,没有特征工程、没有缩放,也没有可能被视为时间序列特定的内部架构组件。

首先,我们使用 N-BEATS 的通用架构模型:

model_name = "nbeats_run"
model_nbeats = NBEATSModel(input_chunk_length=30,output_chunk_length=7,generic_architecture=True,num_stacks=10,num_blocks=1,num_layers=4,layer_widths=512,n_epochs=100,nr_epochs_val_period=1,batch_size=800,random_state=42,model_name=model_name,save_checkpoints=True,force_reset=True,**generate_torch_kwargs(),
)
model_nbeats.fit(train_scaled, val_series=val_scaled)
NBEATSModel(generic_architecture=True, num_stacks=10, num_blocks=1, num_layers=4, layer_widths=512, expansion_coefficient_dim=5, trend_polynomial_degree=2, dropout=0.0, activation=ReLU, input_chunk_length=30, output_chunk_length=7, n_epochs=100, nr_epochs_val_period=1, batch_size=800, random_state=42, model_name=nbeats_run, save_checkpoints=True, force_reset=True, pl_trainer_kwargs={'accelerator': 'cpu', 'callbacks': [<darts.utils.callbacks.TFMProgressBar object at 0x2b3d98fd0>]})

让我们从在验证集上表现最佳的检查点加载模型。

model_nbeats = NBEATSModel.load_from_checkpoint(model_name=model_name, best=True)

让我们看看模型在扩展训练窗口和 7 天预测视界下会产生怎样的历史预测:

pred_series = model_nbeats.historical_forecasts(series_scaled,start=val.start_time(),forecast_horizon=7,stride=7,last_points_only=False,retrain=False,verbose=True,
)
pred_series = concatenate(pred_series)
display_forecast(pred_series,series_scaled,"7 天",start_date=val.start_time(),
)

img

可解释模型

N-BEATS 提供了一个 可解释架构,由两个栈组成:一个 趋势 栈和一个 季节性 栈。该架构设计为:

  • 在将输入馈送到季节性栈之前,移除趋势成分

  • 趋势和季节性的部分预测作为单独的可解释输出可用

model_name = "nbeats_interpretable_run"
model_nbeats = NBEATSModel(input_chunk_length=30,output_chunk_length=7,generic_architecture=False,num_blocks=3,num_layers=4,layer_widths=512,n_epochs=100,nr_epochs_val_period=1,batch_size=800,random_state=42,model_name=model_name,save_checkpoints=True,force_reset=True,**generate_torch_kwargs(),
)
model_nbeats.fit(series=train_scaled, val_series=val_scaled)
NBEATSModel(generic_architecture=False, num_stacks=30, num_blocks=3, num_layers=4, layer_widths=512, expansion_coefficient_dim=5, trend_polynomial_degree=2, dropout=0.0, activation=ReLU, input_chunk_length=30, output_chunk_length=7, n_epochs=100, nr_epochs_val_period=1, batch_size=800, random_state=42, model_name=nbeats_interpretable_run, save_checkpoints=True, force_reset=True, pl_trainer_kwargs={'accelerator': 'cpu', 'callbacks': [<darts.utils.callbacks.TFMProgressBar object at 0x2b3fc0790>]})
model_nbeats = NBEATSModel.load_from_checkpoint(model_name=model_name, best=True)

让我们看看模型在扩展训练窗口和 7 天预测视界下会产生怎样的历史预测:

pred_series = model_nbeats.historical_forecasts(series_scaled,start=val_scaled.start_time(),forecast_horizon=7,stride=7,last_points_only=False,retrain=False,verbose=True,
)
pred_series = concatenate(pred_series)
display_forecast(pred_series, series_scaled, "7 day", start_date=val_scaled.start_time()
)

img

风险提示与免责声明
本文内容基于公开信息研究整理,不构成任何形式的投资建议。历史表现不应作为未来收益保证,市场存在不可预见的波动风险。投资者需结合自身财务状况及风险承受能力独立决策,并自行承担交易结果。作者及发布方不对任何依据本文操作导致的损失承担法律责任。市场有风险,投资须谨慎。

http://www.lryc.cn/news/601482.html

相关文章:

  • 安卓怎么做一个像QQ一样的开关切换控件
  • 墨者:通过手工解决SQL手工注入漏洞测试(MongoDB数据库)
  • 机器学习特征选择 explanation and illustration of ANOVA
  • net8.0一键创建支持(Redis)
  • 【机器学习】第七章 特征工程
  • 基于大模型的预训练、量化、微调等完整流程解析
  • CLAP文本-音频基础模型: LEARNING AUDIO CONCEPTS FROM NATURAL LANGUAGE SUPERVISION
  • PDF文件被加密限制怎么办?专业级解除方案分享
  • 51核和ARM核单片机OTA实战解析(一)
  • 一分钟部署一个导航网站
  • MCU 通用AT指令处理框架
  • PDF转图片实用指南:如何批量高效转换?
  • 创建的springboot工程java文件夹下还是文件夹而不是包
  • 内网服务器实现从公网穿透
  • 单片机ADC采集机理层面详细分析(二)
  • 零基础学习性能测试第五章:JVM性能分析与调优-多线程检测与瓶颈分析
  • 【C语言网络编程基础】TCP 服务器详解
  • Rust与Java DynamoDB、MySQL CRM、tokio-pg、SVM、Custors实战指南
  • 墨者:通过手动解决SQL手工注入漏洞测试(MySQL数据库)
  • Wireshark TS | 发送数据超出接收窗口
  • 双面15.6寸智能访客机硬件规格书及对接第三方接口说明
  • 力扣 hot100 Day57
  • 数据江湖的“三国演义”:数据仓库、数据湖与湖仓一体的全景对比
  • 区块链:工作量证明与联邦学习
  • 神经网络知识讨论
  • 【旧文】Adobe Express使用教程
  • 7月27日星期日今日早报简报微语报早读
  • 数据赋能(340)——技术平台——共享平台
  • Spring之【Bean的生命周期】
  • 视频转GIF工具,一键批量制作高清动图