网站抓取优化,网站建设上海网站制作,开发公司是什么,app安装器前言#xff1a;
最近在大模型预测#xff0c;简单了解了lag-llama开源项目#xff0c;网上也有很多讲解原理的#xff0c;这里就将如何快速上手使用说一下#xff0c;只懂得一点点皮毛#xff0c;有错误的地方欢迎大佬指出。
简单介绍#xff1a;
Lag-Llama 是一个开…前言
最近在大模型预测简单了解了lag-llama开源项目网上也有很多讲解原理的这里就将如何快速上手使用说一下只懂得一点点皮毛有错误的地方欢迎大佬指出。
简单介绍
Lag-Llama 是一个开源的时间序列预测模型基于 Transformer 架构设计专注于利用 滞后特征Lagged Features 捕捉时间序列的长期依赖关系。其核心思想是将传统时间序列分析中的滞后算子Lags与现代深度学习结合实现对复杂时序模式的高效建模。
GitHup地址GitHub - time-series-foundation-models/lag-llama: Lag-Llama: Towards Foundation Models for Probabilistic Time Series Forecasting
相关技术原理...搜一下很多文章讲的都非常好 实现模型预测
1.下载模型文件
从 HuggingFace下载如果网络原因访问不了建议从魔搭社区下载lag-Llama · 模型库 2.准备数据集
参考文档pandas.DataFrame based dataset - GluonTS documentation 以我测试数据举例 3.完整代码需要替换模型文件地址和数据集地址
from itertools import islicefrom matplotlib import pyplot as plt
import matplotlib.dates as mdatesimport torch
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.repository.datasets import get_datasetfrom gluonts.dataset.pandas import PandasDataset
import pandas as pdfrom lag_llama.gluon.estimator import LagLlamaEstimatordef get_lag_llama_predictions(dataset, prediction_length, device, num_samples, context_length32, use_rope_scalingFalse):# 模型文件地址ckpt torch.load(/models/lag-Llama/lag-llama.ckpt, map_locationdevice, weights_onlyFalse) # Uses GPU since in this Colab we use a GPU.estimator_args ckpt[hyper_parameters][model_kwargs]rope_scaling_arguments {type: linear,factor: max(1.0, (context_length prediction_length) / estimator_args[context_length]),}estimator LagLlamaEstimator(# 模型文件地址ckpt_path/models/lag-Llama/lag-llama.ckpt,prediction_lengthprediction_length,context_lengthcontext_length,# Lag-Llama was trained with a context length of 32, but can work with any context length# estimator argsinput_sizeestimator_args[input_size],n_layerestimator_args[n_layer],n_embd_per_headestimator_args[n_embd_per_head],n_headestimator_args[n_head],scalingestimator_args[scaling],time_featestimator_args[time_feat],rope_scalingrope_scaling_arguments if use_rope_scaling else None,batch_size1,num_parallel_samples100,devicedevice,)lightning_module estimator.create_lightning_module()transformation estimator.create_transformation()predictor estimator.create_predictor(transformation, lightning_module)forecast_it, ts_it make_evaluation_predictions(datasetdataset,predictorpredictor,num_samplesnum_samples)forecasts list(forecast_it)tss list(ts_it)return forecasts, tssimport pandas as pd
from gluonts.dataset.pandas import PandasDataseturl (/lag-llama/history.csv
)
df pd.read_csv(url, index_col0, parse_datesTrue)# Set numerical columns as float32
for col in df.columns:# Check if column is not of string typeif df[col].dtype ! object and pd.api.types.is_string_dtype(df[col]) False:df[col] df[col].astype(float32)# Create the Pandas
dataset PandasDataset.from_long_dataframe(df, targettarget, item_iditem_id)backtest_dataset dataset
# 预测长度
prediction_length 24 # Define your prediction length. We use 24 here since the data is of hourly frequency
# 样本数
num_samples 1 # number of samples sampled from the probability distribution for each timestep
device torch.device(cuda:1) # You can switch this to CPU or other GPUs if youd like, depending on your environmentforecasts, tss get_lag_llama_predictions(backtest_dataset, prediction_length, device, num_samples)# 提取第一个时间序列的预测结果
forecast forecasts[0]
print()
# 概率预测的完整样本形状: [num_samples, prediction_length]
samples forecast.samples
print(samples)
关键参数说明 参数 说明 prediction_length 预测的未来时间步长 context_length 模型输入的历史时间步长需 季节性周期 num_samples 概率预测的采样次数值越大概率区间越准 checkpoint_path 预训练模型权重路径需提前下载 freq 时间序列频率如 H 小时、D 天
结果 这里只是给出了简单的代码实现想要更好的效果还需深入研究