用 Python 轻松实现时间序列预测:Darts N-BEATS
文中内容仅限技术学习与代码实践参考,市场存在不确定性,技术分析需谨慎验证,不构成任何投资建议。
Darts 是一个 Python 库,用于对时间序列进行用户友好型预测和异常检测。它包含多种模型,从 ARIMA 等经典模型到深度神经网络。所有预测模型都能以类似 scikit-learn 的方式使用 fit()
和 predict()
函数。该库还可以轻松地对模型进行回溯测试,将多个模型的预测结果结合起来,并将外部数据考虑在内。Darts 支持单变量和多变量时间序列和模型。基于 ML 的模型可以在包含多个时间序列的潜在大型数据集上进行训练,其中一些模型还为概率预测提供了丰富的支持。
N-BEATS
在这个 notebook 中,我们展示了如何使用 N-BEATS 与 darts。如果你是 darts 的新手,我们建议你先跟随快速入门 notebook。
N-BEATS 是一个最先进的模型,展示了 纯深度学习架构 在时间序列预测背景下的潜力。它在 M3 和 M4 竞赛中超越了已确立的统计方法。有关模型的更多细节,请参见:https://arxiv.org/pdf/1905.10437.pdf。
# 如果在本地工作,修复 python 路径
from utils import fix_pythonpath_if_working_locallyfix_pythonpath_if_working_locally()
%matplotlib inline
import warningsimport matplotlib.pyplot as plt
import numpy as np
import pandas as pdfrom darts import TimeSeries, concatenate
from darts.dataprocessing.transformers import MissingValuesFiller, Scaler
from darts.datasets import EnergyDataset
from darts.metrics import r2_score
from darts.models import NBEATSModel
from darts.utils.callbacks import TFMProgressBarwarnings.filterwarnings("ignore")
import logginglogging.disable(logging.CRITICAL)def generate_torch_kwargs():# 在 CPU 上运行 torch 模型,并在除训练外的所有模型阶段禁用进度条。return {"pl_trainer_kwargs": {"accelerator": "cpu","callbacks": [TFMProgressBar(enable_train_bar_only=True)],}}
def display_forecast(pred_series, ts_transformed, forecast_type, start_date=None):plt.figure(figsize=(8, 5))if start_date:ts_transformed = ts_transformed.drop_before(start_date)ts_transformed.univariate_component(0).plot(label="实际")pred_series.plot(label=("历史 " + forecast_type + " 预测"))plt.title(f"R2: {r2_score(ts_transformed.univariate_component(0), pred_series)}")plt.legend()
每日能源发电示例
我们在一个来自径流式水电站的每日能源发电数据集上测试 NBEATS,因为它展示了不同程度的季节性。
df = EnergyDataset().load().to_dataframe()
df["generation hydro run-of-river and poundage"].plot()
plt.title("Hourly generation hydro run-of-river and poundage")
Text(0.5, 1.0, 'Hourly generation hydro run-of-river and poundage')
为了简化,我们使用每日发电量,并通过使用 MissingValuesFiller
填充数据中存在的缺失值:
df_day_avg = df.groupby(df.index.astype(str).str.split(" ").str[0]).mean().reset_index()
filler = MissingValuesFiller()
scaler = Scaler()
series = filler.transform(TimeSeries.from_dataframe(df_day_avg, "time", ["generation hydro run-of-river and poundage"])
).astype(np.float32)train, val = series.split_after(pd.Timestamp("20170901"))train_scaled = scaler.fit_transform(train)
val_scaled = scaler.transform(val)
series_scaled = scaler.transform(series)train_scaled.plot(label="训练")
val_scaled.plot(label="验证")
plt.title("Daily generation hydro run-of-river and poundage")
Text(0.5, 1.0, 'Daily generation hydro run-of-river and poundage')
我们将数据分为训练集和验证集。通常我们需要使用额外的测试集来在未见数据上验证模型,但在这个例子中我们将跳过它。
通用架构
N-BEATS 是一个单变量模型架构,提供两种配置:通用 和 可解释。通用架构 尽可能少地使用先验知识,没有特征工程、没有缩放,也没有可能被视为时间序列特定的内部架构组件。
首先,我们使用 N-BEATS 的通用架构模型:
model_name = "nbeats_run"
model_nbeats = NBEATSModel(input_chunk_length=30,output_chunk_length=7,generic_architecture=True,num_stacks=10,num_blocks=1,num_layers=4,layer_widths=512,n_epochs=100,nr_epochs_val_period=1,batch_size=800,random_state=42,model_name=model_name,save_checkpoints=True,force_reset=True,**generate_torch_kwargs(),
)
model_nbeats.fit(train_scaled, val_series=val_scaled)
NBEATSModel(generic_architecture=True, num_stacks=10, num_blocks=1, num_layers=4, layer_widths=512, expansion_coefficient_dim=5, trend_polynomial_degree=2, dropout=0.0, activation=ReLU, input_chunk_length=30, output_chunk_length=7, n_epochs=100, nr_epochs_val_period=1, batch_size=800, random_state=42, model_name=nbeats_run, save_checkpoints=True, force_reset=True, pl_trainer_kwargs={'accelerator': 'cpu', 'callbacks': [<darts.utils.callbacks.TFMProgressBar object at 0x2b3d98fd0>]})
让我们从在验证集上表现最佳的检查点加载模型。
model_nbeats = NBEATSModel.load_from_checkpoint(model_name=model_name, best=True)
让我们看看模型在扩展训练窗口和 7 天预测视界下会产生怎样的历史预测:
pred_series = model_nbeats.historical_forecasts(series_scaled,start=val.start_time(),forecast_horizon=7,stride=7,last_points_only=False,retrain=False,verbose=True,
)
pred_series = concatenate(pred_series)
display_forecast(pred_series,series_scaled,"7 天",start_date=val.start_time(),
)
可解释模型
N-BEATS 提供了一个 可解释架构,由两个栈组成:一个 趋势 栈和一个 季节性 栈。该架构设计为:
-
在将输入馈送到季节性栈之前,移除趋势成分
-
趋势和季节性的部分预测作为单独的可解释输出可用
model_name = "nbeats_interpretable_run"
model_nbeats = NBEATSModel(input_chunk_length=30,output_chunk_length=7,generic_architecture=False,num_blocks=3,num_layers=4,layer_widths=512,n_epochs=100,nr_epochs_val_period=1,batch_size=800,random_state=42,model_name=model_name,save_checkpoints=True,force_reset=True,**generate_torch_kwargs(),
)
model_nbeats.fit(series=train_scaled, val_series=val_scaled)
NBEATSModel(generic_architecture=False, num_stacks=30, num_blocks=3, num_layers=4, layer_widths=512, expansion_coefficient_dim=5, trend_polynomial_degree=2, dropout=0.0, activation=ReLU, input_chunk_length=30, output_chunk_length=7, n_epochs=100, nr_epochs_val_period=1, batch_size=800, random_state=42, model_name=nbeats_interpretable_run, save_checkpoints=True, force_reset=True, pl_trainer_kwargs={'accelerator': 'cpu', 'callbacks': [<darts.utils.callbacks.TFMProgressBar object at 0x2b3fc0790>]})
model_nbeats = NBEATSModel.load_from_checkpoint(model_name=model_name, best=True)
让我们看看模型在扩展训练窗口和 7 天预测视界下会产生怎样的历史预测:
pred_series = model_nbeats.historical_forecasts(series_scaled,start=val_scaled.start_time(),forecast_horizon=7,stride=7,last_points_only=False,retrain=False,verbose=True,
)
pred_series = concatenate(pred_series)
display_forecast(pred_series, series_scaled, "7 day", start_date=val_scaled.start_time()
)
风险提示与免责声明
本文内容基于公开信息研究整理,不构成任何形式的投资建议。历史表现不应作为未来收益保证,市场存在不可预见的波动风险。投资者需结合自身财务状况及风险承受能力独立决策,并自行承担交易结果。作者及发布方不对任何依据本文操作导致的损失承担法律责任。市场有风险,投资须谨慎。