{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "04e8bf3c-845f-49bb-9e9c-992d6b8948f0",
   "metadata": {},
   "source": [
    "# Lag-Llama zero-shot forecasting\n",
    "\n",
    "Adapted from the Lag-Llama Colab demo: https://colab.research.google.com/drive/1XxrLW9VGPlZDw3efTvUi0hQimgJOwQG6?usp=sharing#scrollTo=gyH5Xq9eSvzq\n",
    "\n",
    "Steps: clone the repository, install its requirements, download the pretrained checkpoint, and define a helper that produces probabilistic forecasts for a GluonTS dataset. The demo at the end is stored in raw cells; convert them to code cells to run it (it downloads a sample dataset and runs the model)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "setup-header",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37f96736-8654-4852-a144-fd75df22aaf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Skip the clone when the repository is already present so the cell can be re-run.\n",
    "![ -d lag-llama ] || git clone https://github.com/time-series-foundation-models/lag-llama/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5fac8fa-5ac8-4330-97e0-8a2f4237ba0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd ./lag-llama"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "968625c9-00fd-4037-b97c-33dfc4758491",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (not !pip) installs into the environment of the running kernel. This can take a while.\n",
    "%pip install -q -r requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f10c802-4ffa-40f7-bd62-14ff13fae03c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -q --upgrade requests\n",
    "%pip install -q matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a64aa15-1477-44bc-b772-a9342a5640c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "!huggingface-cli download time-series-foundation-models/Lag-Llama lag-llama.ckpt --local-dir ./"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a328c513-558f-45ca-b900-b669c4ef33ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import islice\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "import matplotlib.dates as mdates\n",
    "\n",
    "import torch\n",
    "from gluonts.evaluation import make_evaluation_predictions, Evaluator\n",
    "from gluonts.dataset.repository.datasets import get_dataset\n",
    "\n",
    "from gluonts.dataset.pandas import PandasDataset\n",
    "import pandas as pd\n",
    "\n",
    "from lag_llama.gluon.estimator import LagLlamaEstimator"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "helper-header",
   "metadata": {},
   "source": [
    "## Forecasting helper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f098efb9-490c-46b7-9ea3-bea1f2871fa5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_lag_llama_predictions(dataset, prediction_length, num_samples=100, ckpt_path=\"lag-llama.ckpt\"):\n",
    "    \"\"\"Produce probabilistic Lag-Llama forecasts for every series in a GluonTS dataset.\n",
    "\n",
    "    Args:\n",
    "        dataset: GluonTS dataset, passed to `make_evaluation_predictions`.\n",
    "        prediction_length: number of future time steps to forecast.\n",
    "        num_samples: number of sample paths requested per series.\n",
    "        ckpt_path: path to the pretrained Lag-Llama checkpoint.\n",
    "\n",
    "    Returns:\n",
    "        (forecasts, tss): the forecasts and the matching time series, both as lists,\n",
    "        in the order produced by `make_evaluation_predictions`.\n",
    "    \"\"\"\n",
    "    # Only the model hyperparameters are read from this copy of the checkpoint, so load it\n",
    "    # on CPU: this works on machines without a GPU and does not occupy GPU memory.\n",
    "    # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.\n",
    "    estimator_args = torch.load(ckpt_path, map_location=\"cpu\")[\"hyper_parameters\"][\"model_kwargs\"]\n",
    "\n",
    "    estimator = LagLlamaEstimator(\n",
    "        ckpt_path=ckpt_path,\n",
    "        prediction_length=prediction_length,\n",
    "        context_length=32,  # Should not be changed; this is what the released Lag-Llama model was trained with\n",
    "\n",
    "        # Architecture hyperparameters taken from the checkpoint so the model matches the saved weights\n",
    "        input_size=estimator_args[\"input_size\"],\n",
    "        n_layer=estimator_args[\"n_layer\"],\n",
    "        n_embd_per_head=estimator_args[\"n_embd_per_head\"],\n",
    "        n_head=estimator_args[\"n_head\"],\n",
    "        scaling=estimator_args[\"scaling\"],\n",
    "        time_feat=estimator_args[\"time_feat\"],\n",
    "\n",
    "        batch_size=1,\n",
    "        num_parallel_samples=100,\n",
    "    )\n",
    "\n",
    "    lightning_module = estimator.create_lightning_module()\n",
    "    transformation = estimator.create_transformation()\n",
    "    predictor = estimator.create_predictor(transformation, lightning_module)\n",
    "\n",
    "    forecast_it, ts_it = make_evaluation_predictions(\n",
    "        dataset=dataset,\n",
    "        predictor=predictor,\n",
    "        num_samples=num_samples,\n",
    "    )\n",
    "    # Materialise both iterators so the results can be indexed and iterated repeatedly.\n",
    "    forecasts = list(forecast_it)\n",
    "    tss = list(ts_it)\n",
    "\n",
    "    return forecasts, tss"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "demo-header",
   "metadata": {},
   "source": [
    "## Demo (raw cells)\n",
    "\n",
    "Convert the two cells below to code cells to run the demo: the first loads a sample hourly long-format dataset, the second forecasts it and plots the first 9 series."
   ]
  },
  {
   "cell_type": "raw",
   "id": "e7e6dd60-7c0c-483f-86d4-b2ba7c4104d3",
   "metadata": {},
   "source": [
    "url = (\n",
    "    \"https://gist.githubusercontent.com/rsnirwan/a8b424085c9f44ef2598da74ce43e7a3/raw/b6fdef21fe1f654787fa0493846c546b7f9c4df2/ts_long.csv\"\n",
    ")\n",
    "df = pd.read_csv(url, index_col=0, parse_dates=True)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "raw",
   "id": "74dc9a03-435e-40a5-bbda-4ddac9f6cfb9",
   "metadata": {},
   "source": [
    "# Cast numeric columns to float32; object/string columns are left as-is.\n",
    "for col in df.columns:\n",
    "    if df[col].dtype != 'object' and not pd.api.types.is_string_dtype(df[col]):\n",
    "        df[col] = df[col].astype('float32')\n",
    "\n",
    "# Build a GluonTS dataset from the long-format frame.\n",
    "dataset = PandasDataset.from_long_dataframe(df, target=\"target\", item_id=\"item_id\")\n",
    "\n",
    "backtest_dataset = dataset\n",
    "prediction_length = 24  # Define your prediction length. We use 24 here since the data is of hourly frequency\n",
    "num_samples = 100  # number of samples sampled from the probability distribution for each timestep\n",
    "forecasts, tss = get_lag_llama_predictions(backtest_dataset, prediction_length, num_samples)\n",
    "print(f\"{len(forecasts)} forecasts; sample array shape: {forecasts[0].samples.shape}\")\n",
    "\n",
    "plt.figure(figsize=(20, 15))\n",
    "date_formatter = mdates.DateFormatter('%b, %d')\n",
    "plt.rcParams.update({'font.size': 15})\n",
    "\n",
    "# Iterate through the first 9 series, and plot the recent history with the predicted samples\n",
    "for idx, (forecast, ts) in islice(enumerate(zip(forecasts, tss)), 9):\n",
    "    ax = plt.subplot(3, 3, idx + 1)\n",
    "\n",
    "    plt.plot(ts[-4 * prediction_length:].to_timestamp(), label=\"target\")\n",
    "    forecast.plot(color='g')\n",
    "    plt.xticks(rotation=60)\n",
    "    ax.xaxis.set_major_formatter(date_formatter)\n",
    "    ax.set_title(forecast.item_id)\n",
    "\n",
    "plt.gcf().tight_layout()\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}