当前位置: 首页 > news >正文

NeuralForecast 推理 - 数据集从文件dataset.pkl读

NeuralForecast 推理 - 数据集从文件dataset.pkl读

flyfish

from ray import tune
from neuralforecast.core import NeuralForecast
from neuralforecast.auto import AutoMLP
from neuralforecast.models import NBEATS, NHITS
import torch
import torch.nn as nn
import os
import pickle
import warnings
from copy import deepcopy
from itertools import chain
from typing import Any, Dict, List, Optional, Unionimport fsspec
import numpy as np
import pandas as pdfrom neuralforecast.models import (GRU,LSTM,RNN,TCN,DeepAR,DilatedRNN,MLP,NHITS,NBEATS,NBEATSx,DLinear,NLinear,TFT,VanillaTransformer,Informer,Autoformer,FEDformer,StemGNN,PatchTST,TimesNet,TimeLLM,TSMixer,
)
MODEL_FILENAME_DICT = {"autoformer": Autoformer,"autoautoformer": Autoformer,"deepar": DeepAR,"autodeepar": DeepAR,"dlinear": DLinear,"autodlinear": DLinear,"nlinear": NLinear,"autonlinear": NLinear,"dilatedrnn": DilatedRNN,"autodilatedrnn": DilatedRNN,"fedformer": FEDformer,"autofedformer": FEDformer,"gru": GRU,"autogru": GRU,"informer": Informer,"autoinformer": Informer,"lstm": LSTM,"autolstm": LSTM,"mlp": MLP,"automlp": MLP,"nbeats": NBEATS,"autonbeats": NBEATS,"nbeatsx": NBEATSx,"autonbeatsx": NBEATSx,"nhits": NHITS,"autonhits": NHITS,"patchtst": PatchTST,"autopatchtst": PatchTST,"rnn": RNN,"autornn": RNN,"stemgnn": StemGNN,"autostemgnn": StemGNN,"tcn": TCN,"autotcn": TCN,"tft": TFT,"autotft": TFT,"timesnet": TimesNet,"autotimesnet": TimesNet,"vanillatransformer": VanillaTransformer,"autovanillatransformer": VanillaTransformer,"timellm": TimeLLM,"tsmixer": TSMixer,"autotsmixer": TSMixer,
}
#model_path1 = "checkpoints\\test_run\\automlp_0.ckpt"
model_path = "checkpoints\\test_run"dataset_path = "checkpoints\\test_run\\dataset.pkl"def load(path, verbose=False, **kwargs):# Standarize path without '/'if path[-1] == "/":path = path[:-1]fs, _, paths = fsspec.get_fs_token_paths(path)files = [f.split("/")[-1] for f in fs.ls(path) if fs.isfile(f)]# Load modelsmodels_ckpt = [f for f in files if f.endswith(".ckpt")]if len(models_ckpt) == 0:raise Exception("No model found in directory.")if verbose:print(10 * "-" + " Loading models " + 10 * "-")models = []try:with fsspec.open(f"{path}/alias_to_model.pkl", "rb") as f:alias_to_model = pickle.load(f)except FileNotFoundError:alias_to_model = {}for model in models_ckpt:model_name = model.split("_")[0]model_class_name = alias_to_model.get(model_name, model_name)models.append(MODEL_FILENAME_DICT[model_class_name].load_from_checkpoint(f"{path}/{model}", **kwargs))if verbose:print(f"Model {model_name} loaded.")return modelsmodels = load(model_path,verbose=True)
print(models[0])
model = models[0]
model.eval()# Load dataset
def load_dataset(path, verbose=True):try:with fsspec.open(f"{path}/dataset.pkl", "rb") as f:dataset = pickle.load(f)print(dataset)if verbose:print("Dataset loaded.")except FileNotFoundError:dataset = Noneif verbose:print("No dataset found in directory.")return datasetdata=pd.read_pickle(dataset_path)
print("data:",data)trimmed_dataset = load_dataset(model_path)
print(trimmed_dataset)#TimeSeriesDataset(n_data=96, n_groups=1)step_size =1
model_fcsts = model.predict(trimmed_dataset, step_size=step_size)
print(model_fcsts)
http://www.lryc.cn/news/359051.html

相关文章:

  • TS-类型转换(显式)
  • protobufjs 配置踩坑记录
  • freeswitch官方仓库
  • element ui el-calendar日历组件完整代码
  • 初识java——javaSE(8)异常
  • C语言面试题11至20题
  • 视频汇聚EasyCVR综合安防平台对接GA/T1400公安视图库及应用方案
  • 在Github找自己想要的的项目
  • 第16篇:JTAG UART IP应用<三>
  • Python——Selenium快速上手+方法(一站式解决问题)
  • 2024最新群智能优化算法:大甘蔗鼠算法(Greater Cane Rat Algorithm,GCRA)求解23个函数,提供MATLAB代码
  • 苍穹外卖数据可视化
  • AWS需要实名吗?
  • Android下HWC以及drm_hwcomposer普法(下)
  • 【评价类模型】熵权法
  • PG 窗口函数
  • 冯喜运:5.31晚间黄金原油行情分析及尾盘操作策略
  • Vue 框选区域放大(纯JavaScript实现)
  • C#加密与java 互通
  • C#【进阶】特殊语法
  • c语言之向文件读写数据块
  • 6键编程智能照明:编程指南与深度解析
  • sql server 中的6种约束
  • 师彼长技以助己(2)产品思维
  • Redis学习笔记【基础篇】
  • 【文献阅读】基于模型设计的汽车软件质量属性
  • 撸广告赚金币小游戏app开发
  • 海外高清短视频:四川京之华锦信息技术公司
  • 16:00面试,16:08就出来了,问的问题有点变态。。。
  • Android MediaCodec 简明教程(九):使用 MediaCodec 解码到纹理,使用 OpenGL ES 进行处理,并编码为 MP4 文件