diff --git a/examples/ptf_V2_example.ipynb b/examples/ptf_V2_example.ipynb
new file mode 100644
index 000000000..2e39108f3
--- /dev/null
+++ b/examples/ptf_V2_example.ipynb
@@ -0,0 +1,1022 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "2630DaOEI4AJ",
+    "outputId": "b4cc7100-2a9c-41a3-e890-164b90c91c03"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Successfully installed lightning-2.5.1.post0 lightning-utilities-0.14.3 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 pytorch-forecasting-1.3.0 pytorch-lightning-2.5.1.post0 torchmetrics-1.7.1\n"
+     ]
+    }
+   ],
+   "source": [
+    "!pip install pytorch-forecasting"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "M7PQerTbI_tM"
+   },
+   "outputs": [],
+   "source": [
+    "from typing import Any, Dict, List, Optional, Tuple, Union\n",
+    "\n",
+    "from lightning.pytorch import Trainer\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "from sklearn.preprocessing import RobustScaler, StandardScaler\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "from torch.optim import Optimizer\n",
+    "from torch.utils.data import Dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "DGTyf3vct-Jk"
+   },
+   "outputs": [],
+   "source": [
+    "from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule\n",
+    "from pytorch_forecasting.data.encoders import (\n",
+    "    EncoderNormalizer,\n",
+    "    NaNLabelEncoder,\n",
+    "    TorchNormalizer,\n",
+    ")\n",
+    "from pytorch_forecasting.data.timeseries import TimeSeries\n",
+    "from pytorch_forecasting.metrics import MAE, SMAPE\n",
+    "from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT"
+   ]
+  },
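+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Build a synthetic panel to exercise the experimental v2 data pipeline: 100 noisy sine-wave series, where each row carries the current value `x` as an unknown covariate, the next value `y` as the prediction target, a known-future cosine feature, a per-series category, and static features (one continuous, one categorical)."
+   ]
+  },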
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 206
+    },
+    "id": "WX-FRdusJSVN",
+    "outputId": "d162e241-3076-415c-db39-8c571bbaa282"
+   },
+   "outputs": [
+    {
+     "data": {
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
series_idtime_idxxycategoryfuture_known_featurestatic_featurestatic_feature_cat
0000.2009680.23229801.0000000.687290
1010.2322980.33666900.9950040.687290
2020.3366690.63606300.9800670.687290
3030.6360630.92771000.9553360.687290
4040.9277101.00855400.9210610.687290
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "text/plain": [ + " series_id time_idx x y category future_known_feature \\\n", + "0 0 0 0.200968 0.232298 0 1.000000 \n", + "1 0 1 0.232298 0.336669 0 0.995004 \n", + "2 0 2 0.336669 0.636063 0 0.980067 \n", + "3 0 3 0.636063 0.927710 0 0.955336 \n", + "4 0 4 0.927710 1.008554 0 0.921061 \n", + "\n", + " static_feature static_feature_cat \n", + "0 0.68729 0 \n", + "1 0.68729 0 \n", + "2 0.68729 0 \n", + "3 0.68729 0 \n", + "4 0.68729 0 " + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "\n", + "num_series = 100\n", + "seq_length = 50\n", + "data_list = []\n", + "for i in range(num_series):\n", + " x = np.arange(seq_length)\n", + " y = np.sin(x / 5.0) + np.random.normal(scale=0.1, size=seq_length)\n", + " category = i % 5\n", + " static_value = np.random.rand()\n", + " for t in range(seq_length - 1):\n", + " data_list.append(\n", + " {\n", + " \"series_id\": i,\n", + " \"time_idx\": t,\n", + " \"x\": y[t],\n", + " \"y\": y[t + 1],\n", + " \"category\": category,\n", + " \"future_known_feature\": np.cos(t / 10),\n", + " \"static_feature\": static_value,\n", + " \"static_feature_cat\": i % 3,\n", + " }\n", + " )\n", + "data_df = pd.DataFrame(data_list)\n", + "data_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "AxxPHK6AKSD2", + "outputId": "dd95173d-73c2-451b-8b67-c9cc7298cf9d" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":106: UserWarning: TimeSeries is part of an experimental rework of the pytorch-forecasting data layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. For beta testing, but not for stable production use. 
Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n",
+      "  warnings.warn(\n"
+     ]
+    }
+   ],
+   "source": [
+    "dataset = TimeSeries(\n",
+    "    data=data_df,\n",
+    "    time=\"time_idx\",\n",
+    "    target=\"y\",\n",
+    "    group=[\"series_id\"],\n",
+    "    num=[\"x\", \"future_known_feature\", \"static_feature\"],\n",
+    "    cat=[\"category\", \"static_feature_cat\"],\n",
+    "    known=[\"future_known_feature\"],\n",
+    "    unknown=[\"x\", \"category\"],\n",
+    "    static=[\"static_feature\", \"static_feature_cat\"],\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {
+    "id": "5U5Lr_ZFKX0s"
+   },
+   "outputs": [],
+   "source": [
+    "data_module = EncoderDecoderTimeSeriesDataModule(\n",
+    "    time_series_dataset=dataset,\n",
+    "    max_encoder_length=30,\n",
+    "    max_prediction_length=1,\n",
+    "    batch_size=32,\n",
+    "    categorical_encoders={\n",
+    "        \"category\": NaNLabelEncoder(add_nan=True),\n",
+    "        \"static_feature_cat\": NaNLabelEncoder(add_nan=True),\n",
+    "    },\n",
+    "    scalers={\n",
+    "        \"x\": StandardScaler(),\n",
+    "        \"future_known_feature\": StandardScaler(),\n",
+    "        \"static_feature\": StandardScaler(),\n",
+    "    },\n",
+    "    target_normalizer=TorchNormalizer(),\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 37,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 1000
+    },
+    "id": "Si7bbZIULBZz",
+    "outputId": "0b2f26b3-e37c-4ab6-c234-90693745a8cd"
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO: Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.\n",
+      "INFO: GPU available: False, used: False\n",
+      "INFO: TPU available: False, using: 0 TPU cores\n",
+      "INFO: HPU available: False, using: 0 HPUs\n",
+      "INFO: \n",
+      "  | Name                  | Type               | Params | Mode \n",
+      "---------------------------------------------------------------------\n",
+      "0 | loss                  | MSELoss            | 0      | train\n",
+      "1 | encoder_var_selection | Sequential         | 709    | train\n",
+      "2 | decoder_var_selection | Sequential         | 193    | train\n",
+      "3 | static_context_linear | Linear             | 192    | train\n",
+      "4 | lstm_encoder          | LSTM               | 51.5 K | train\n",
+      "5 | lstm_decoder          | LSTM               | 50.4 K | train\n",
+      "6 | self_attention        | MultiheadAttention | 16.6 K | train\n",
+      "7 | pre_output            | Linear             | 4.2 K  | train\n",
+      "8 | output_layer          | Linear             | 65     | train\n",
+      "---------------------------------------------------------------------\n",
+      "123 K     Trainable params\n",
+      "0         Non-trainable params\n",
+      "123 K     Total params\n",
+      "0.495     Total estimated model params size (MB)\n",
+      "18        Modules in train mode\n",
+      "0         Modules in eval mode\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Training model...\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃        Test metric        ┃       DataLoader 0        ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│         test_MAE          │    0.45287469029426575    │\n",
+       "│        test_SMAPE         │     0.942494809627533     │\n",
+       "│         test_loss         │    0.01396977063268423    │\n",
+       "└───────────────────────────┴───────────────────────────┘\n"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Prediction shape: torch.Size([32, 1, 1])\n",
+      "First prediction values: [[0.11045122]]\n",
+      "First true values: [[-0.0491814]]\n",
+      "\n",
+      "TFT model test complete!\n"
+     ]
+    }
+   ],
+   "source": [
+    "model = TFT(\n",
+    "    loss=nn.MSELoss(),\n",
+    "    logging_metrics=[MAE(), SMAPE()],\n",
+    "    optimizer=\"adam\",\n",
+    "    optimizer_params={\"lr\": 1e-3},\n",
+    "    lr_scheduler=\"reduce_lr_on_plateau\",\n",
+    "    lr_scheduler_params={\"mode\": \"min\", \"factor\": 0.1, \"patience\": 10},\n",
+    "    hidden_size=64,\n",
+    "    num_layers=2,\n",
+    "    attention_head_size=4,\n",
+    "    dropout=0.1,\n",
+    "    metadata=data_module.metadata,\n",
+    ")\n",
+    "\n",
+    "print(\"\\nTraining model...\")\n",
+    "trainer = Trainer(\n",
+    "    max_epochs=5,\n",
+    "    accelerator=\"auto\",\n",
+    "    devices=1,\n",
+    "    enable_progress_bar=True,\n",
+    "    log_every_n_steps=10,\n",
+    ")\n",
+    "\n",
+    "trainer.fit(model, data_module)\n",
+    "\n",
+    "print(\"\\nEvaluating model...\")\n",
+    "test_metrics = trainer.test(model, data_module)\n",
+    "\n",
+    "model.eval()\n",
+    "with torch.no_grad():\n",
+    "    test_batch = next(iter(data_module.test_dataloader()))\n",
+    "    x_test, y_test = test_batch\n",
+    "    y_pred = model(x_test)\n",
+    "\n",
+    "    print(\"\\nPrediction shape:\", y_pred[\"prediction\"].shape)\n",
+    "    print(\"First prediction values:\", y_pred[\"prediction\"][0].cpu().numpy())\n",
+    "    print(\"First true values:\", y_test[0].cpu().numpy())\n",
+    "print(\"\\nTFT model 
test complete!\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zVRwi2MvLGgc" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/pytorch_forecasting/models/base/_base_model_v2.py b/pytorch_forecasting/models/base/_base_model_v2.py new file mode 100644 index 000000000..ddefc29fb --- /dev/null +++ b/pytorch_forecasting/models/base/_base_model_v2.py @@ -0,0 +1,296 @@ +######################################################################################## +# Disclaimer: This baseclass is still work in progress and experimental, please +# use with care. This class is a basic skeleton of how the base classes may look like +# in the version-2. +######################################################################################## + + +from typing import Dict, List, Optional, Tuple, Union +from warnings import warn + +from lightning.pytorch import LightningModule +from lightning.pytorch.utilities.types import STEP_OUTPUT +import torch +import torch.nn as nn +from torch.optim import Optimizer + + +class BaseModel(LightningModule): + def __init__( + self, + loss: nn.Module, + logging_metrics: Optional[List[nn.Module]] = None, + optimizer: Optional[Union[Optimizer, str]] = "adam", + optimizer_params: Optional[Dict] = None, + lr_scheduler: Optional[str] = None, + lr_scheduler_params: Optional[Dict] = None, + ): + """ + Base model for time series forecasting. + + Parameters + ---------- + loss : nn.Module + Loss function to use for training. + logging_metrics : Optional[List[nn.Module]], optional + List of metrics to log during training, validation, and testing. + optimizer : Optional[Union[Optimizer, str]], optional + Optimizer to use for training. + Can be a string ("adam", "sgd") or an instance of `torch.optim.Optimizer`. + optimizer_params : Optional[Dict], optional + Parameters for the optimizer. + lr_scheduler : Optional[str], optional + Learning rate scheduler to use. + Supported values: "reduce_lr_on_plateau", "step_lr". + lr_scheduler_params : Optional[Dict], optional + Parameters for the learning rate scheduler. + """ + super().__init__() + self.loss = loss + self.logging_metrics = logging_metrics if logging_metrics is not None else [] + self.optimizer = optimizer + self.optimizer_params = optimizer_params if optimizer_params is not None else {} + self.lr_scheduler = lr_scheduler + self.lr_scheduler_params = ( + lr_scheduler_params if lr_scheduler_params is not None else {} + ) + self.model_name = self.__class__.__name__ + warn( + f"The Model '{self.model_name}' is part of an experimental rework" + "of the pytorch-forecasting model layer, scheduled for release with v2.0.0." + " The API is not stable and may change without prior warning. " + "This class is intended for beta testing and as a basic skeleton, " + "but not for stable production use. " + "Feedback and suggestions are very welcome in " + "pytorch-forecasting issue 1736, " + "https://github.com/sktime/pytorch-forecasting/issues/1736", + UserWarning, + ) + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Forward pass of the model. 
+
+        Parameters
+        ----------
+        x : Dict[str, torch.Tensor]
+            Dictionary containing input tensors.
+
+        Returns
+        -------
+        Dict[str, torch.Tensor]
+            Dictionary containing output tensors.
+        """
+        raise NotImplementedError("Forward method must be implemented by subclass.")
+
+    def training_step(
+        self, batch: Tuple[Dict[str, torch.Tensor], torch.Tensor], batch_idx: int
+    ) -> STEP_OUTPUT:
+        """
+        Training step for the model.
+
+        Parameters
+        ----------
+        batch : Tuple[Dict[str, torch.Tensor], torch.Tensor]
+            Batch of data containing input and target tensors.
+        batch_idx : int
+            Index of the batch.
+
+        Returns
+        -------
+        STEP_OUTPUT
+            Dictionary containing the loss and other metrics.
+        """
+        x, y = batch
+        y_hat_dict = self(x)
+        y_hat = y_hat_dict["prediction"]
+        loss = self.loss(y_hat, y)
+        self.log(
+            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
+        )
+        self.log_metrics(y_hat, y, prefix="train")
+        return {"loss": loss}
+
+    def validation_step(
+        self, batch: Tuple[Dict[str, torch.Tensor], torch.Tensor], batch_idx: int
+    ) -> STEP_OUTPUT:
+        """
+        Validation step for the model.
+
+        Parameters
+        ----------
+        batch : Tuple[Dict[str, torch.Tensor], torch.Tensor]
+            Batch of data containing input and target tensors.
+        batch_idx : int
+            Index of the batch.
+
+        Returns
+        -------
+        STEP_OUTPUT
+            Dictionary containing the loss and other metrics.
+        """
+        x, y = batch
+        y_hat_dict = self(x)
+        y_hat = y_hat_dict["prediction"]
+        loss = self.loss(y_hat, y)
+        self.log(
+            "val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
+        )
+        self.log_metrics(y_hat, y, prefix="val")
+        return {"val_loss": loss}
+
+    def test_step(
+        self, batch: Tuple[Dict[str, torch.Tensor], torch.Tensor], batch_idx: int
+    ) -> STEP_OUTPUT:
+        """
+        Test step for the model.
+
+        Parameters
+        ----------
+        batch : Tuple[Dict[str, torch.Tensor], torch.Tensor]
+            Batch of data containing input and target tensors.
+        batch_idx : int
+            Index of the batch.
+
+        Returns
+        -------
+        STEP_OUTPUT
+            Dictionary containing the loss and other metrics.
+        """
+        x, y = batch
+        y_hat_dict = self(x)
+        y_hat = y_hat_dict["prediction"]
+        loss = self.loss(y_hat, y)
+        self.log(
+            "test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
+        )
+        self.log_metrics(y_hat, y, prefix="test")
+        return {"test_loss": loss}
+
+    def predict_step(
+        self,
+        batch: Tuple[Dict[str, torch.Tensor], torch.Tensor],
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Prediction step for the model.
+
+        Parameters
+        ----------
+        batch : Tuple[Dict[str, torch.Tensor], torch.Tensor]
+            Batch of data containing input tensors.
+        batch_idx : int
+            Index of the batch.
+        dataloader_idx : int
+            Index of the dataloader.
+
+        Returns
+        -------
+        Dict[str, torch.Tensor]
+            Model output dictionary containing the "prediction" tensor.
+        """
+        x, _ = batch
+        y_hat = self(x)
+        return y_hat
+
+    def configure_optimizers(self) -> Dict:
+        """
+        Configure the optimizer and learning rate scheduler.
+
+        Returns
+        -------
+        Dict
+            Dictionary containing the optimizer and scheduler configuration.
+        """
+        optimizer = self._get_optimizer()
+        if self.lr_scheduler is not None:
+            scheduler = self._get_scheduler(optimizer)
+            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+                return {
+                    "optimizer": optimizer,
+                    "lr_scheduler": {
+                        "scheduler": scheduler,
+                        "monitor": "val_loss",
+                    },
+                }
+            else:
+                return {"optimizer": optimizer, "lr_scheduler": scheduler}
+        return {"optimizer": optimizer}
+
+    def _get_optimizer(self) -> Optimizer:
+        """
+        Get the optimizer based on the specified optimizer name and parameters.
+
+        Returns
+        -------
+        Optimizer
+            The optimizer instance.
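+
+        Examples
+        --------
+        A minimal sketch of the string-based path, assuming ``model`` is an
+        instantiated subclass (values are illustrative, not defaults):
+
+        >>> model.optimizer = "adam"
+        >>> model.optimizer_params = {"lr": 1e-3}
+        >>> optimizer = model._get_optimizer()  # returns a torch.optim.Adam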
+ """ + if isinstance(self.optimizer, str): + if self.optimizer.lower() == "adam": + return torch.optim.Adam(self.parameters(), **self.optimizer_params) + elif self.optimizer.lower() == "sgd": + return torch.optim.SGD(self.parameters(), **self.optimizer_params) + else: + raise ValueError(f"Optimizer {self.optimizer} not supported.") + elif isinstance(self.optimizer, Optimizer): + return self.optimizer + else: + raise ValueError( + "Optimizer must be either a string or " + "an instance of torch.optim.Optimizer." + ) + + def _get_scheduler( + self, optimizer: Optimizer + ) -> torch.optim.lr_scheduler._LRScheduler: + """ + Get the lr scheduler based on the specified scheduler name and params. + + Parameters + ---------- + optimizer : Optimizer + The optimizer instance. + + Returns + ------- + torch.optim.lr_scheduler._LRScheduler + The learning rate scheduler instance. + """ + if self.lr_scheduler.lower() == "reduce_lr_on_plateau": + return torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, **self.lr_scheduler_params + ) + elif self.lr_scheduler.lower() == "step_lr": + return torch.optim.lr_scheduler.StepLR( + optimizer, **self.lr_scheduler_params + ) + else: + raise ValueError(f"Scheduler {self.lr_scheduler} not supported.") + + def log_metrics( + self, y_hat: torch.Tensor, y: torch.Tensor, prefix: str = "val" + ) -> None: + """ + Log additional metrics during training, validation, or testing. + + Parameters + ---------- + y_hat : torch.Tensor + Predicted output tensor. + y : torch.Tensor + Target output tensor. + prefix : str + Prefix for the logged metrics (e.g., "train", "val", "test"). + """ + for metric in self.logging_metrics: + metric_value = metric(y_hat, y) + self.log( + f"{prefix}_{metric.__class__.__name__}", + metric_value, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py new file mode 100644 index 000000000..a0cf7d39e --- /dev/null +++ b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py @@ -0,0 +1,253 @@ +######################################################################################## +# Disclaimer: This implementation is based on the new version of data pipeline and is +# experimental, please use with care. 
+######################################################################################## + +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +from torch.optim import Optimizer + +from pytorch_forecasting.models.base._base_model_v2 import BaseModel + + +class TFT(BaseModel): + def __init__( + self, + loss: nn.Module, + logging_metrics: Optional[List[nn.Module]] = None, + optimizer: Optional[Union[Optimizer, str]] = "adam", + optimizer_params: Optional[Dict] = None, + lr_scheduler: Optional[str] = None, + lr_scheduler_params: Optional[Dict] = None, + hidden_size: int = 64, + num_layers: int = 2, + attention_head_size: int = 4, + dropout: float = 0.1, + metadata: Optional[Dict] = None, + output_size: int = 1, + ): + super().__init__( + loss=loss, + logging_metrics=logging_metrics, + optimizer=optimizer, + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + lr_scheduler_params=lr_scheduler_params, + ) + self.save_hyperparameters(ignore=["loss", "logging_metrics", "metadata"]) + + self.hidden_size = hidden_size + self.num_layers = num_layers + self.attention_head_size = attention_head_size + self.dropout = dropout + self.metadata = metadata + self.output_size = output_size + + self.max_encoder_length = self.metadata["max_encoder_length"] + self.max_prediction_length = self.metadata["max_prediction_length"] + self.encoder_cont = self.metadata["encoder_cont"] + self.encoder_cat = self.metadata["encoder_cat"] + self.encoder_input_dim = self.encoder_cont + self.encoder_cat + self.decoder_cont = self.metadata["decoder_cont"] + self.decoder_cat = self.metadata["decoder_cat"] + self.decoder_input_dim = self.decoder_cont + self.decoder_cat + self.static_cat_dim = self.metadata.get("static_categorical_features", 0) + self.static_cont_dim = self.metadata.get("static_continuous_features", 0) + self.static_input_dim = self.static_cat_dim + self.static_cont_dim + + if self.encoder_input_dim > 0: + self.encoder_var_selection = nn.Sequential( + nn.Linear(self.encoder_input_dim, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, self.encoder_input_dim), + nn.Sigmoid(), + ) + else: + self.encoder_var_selection = None + + if self.decoder_input_dim > 0: + self.decoder_var_selection = nn.Sequential( + nn.Linear(self.decoder_input_dim, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, self.decoder_input_dim), + nn.Sigmoid(), + ) + else: + self.decoder_var_selection = None + + if self.static_input_dim > 0: + self.static_context_linear = nn.Linear(self.static_input_dim, hidden_size) + else: + self.static_context_linear = None + + _lstm_encoder_input_actual_dim = self.encoder_input_dim + self.lstm_encoder = nn.LSTM( + input_size=max(1, _lstm_encoder_input_actual_dim), + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + batch_first=True, + ) + + _lstm_decoder_input_actual_dim = self.decoder_input_dim + self.lstm_decoder = nn.LSTM( + input_size=max(1, _lstm_decoder_input_actual_dim), + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + batch_first=True, + ) + + self.self_attention = nn.MultiheadAttention( + embed_dim=hidden_size, + num_heads=attention_head_size, + dropout=dropout, + batch_first=True, + ) + + self.pre_output = nn.Linear(hidden_size, hidden_size) + self.output_layer = nn.Linear(hidden_size, self.output_size) + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Forward pass of the TFT model. 
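+
+        Time-varying inputs are gated by the (optional) variable selection
+        networks, run through the LSTM encoder-decoder (with the static
+        context added to the LSTM outputs and to the attention query), and
+        refined by multi-head self-attention before the final projection.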
+ + Parameters + ---------- + x : Dict[str, torch.Tensor] + Dictionary containing input tensors: + - encoder_cat: Categorical encoder features + - encoder_cont: Continuous encoder features + - decoder_cat: Categorical decoder features + - decoder_cont: Continuous decoder features + - static_categorical_features: Static categorical features + - static_continuous_features: Static continuous features + + Returns + ------- + Dict[str, torch.Tensor] + Dictionary containing output tensors: + - prediction: Prediction output (batch_size, prediction_length, output_size) + """ + batch_size = x["encoder_cont"].shape[0] + + encoder_cat = x.get( + "encoder_cat", + torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device), + ) + encoder_cont = x.get( + "encoder_cont", + torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device), + ) + decoder_cat = x.get( + "decoder_cat", + torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device), + ) + decoder_cont = x.get( + "decoder_cont", + torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device), + ) + + encoder_input = torch.cat([encoder_cont, encoder_cat], dim=2) + decoder_input = torch.cat([decoder_cont, decoder_cat], dim=2) + + static_context = None + if self.static_context_linear is not None: + static_cat = x.get( + "static_categorical_features", + torch.zeros(batch_size, 1, 0, device=self.device), + ) + static_cont = x.get( + "static_continuous_features", + torch.zeros(batch_size, 1, 0, device=self.device), + ) + + if static_cat.size(2) == 0 and static_cont.size(2) == 0: + static_context = None + elif static_cat.size(2) == 0: + static_input = static_cont.to( + dtype=self.static_context_linear.weight.dtype + ) + static_context = self.static_context_linear(static_input) + static_context = static_context.view(batch_size, self.hidden_size) + elif static_cont.size(2) == 0: + static_input = static_cat.to( + dtype=self.static_context_linear.weight.dtype + ) + static_context = self.static_context_linear(static_input) + static_context = static_context.view(batch_size, self.hidden_size) + else: + + static_input = torch.cat([static_cont, static_cat], dim=2).to( + dtype=self.static_context_linear.weight.dtype + ) + static_context = self.static_context_linear(static_input) + static_context = static_context.view(batch_size, self.hidden_size) + + if self.encoder_var_selection is not None: + encoder_weights = self.encoder_var_selection(encoder_input) + encoder_input = encoder_input * encoder_weights + else: + if self.encoder_input_dim == 0: + encoder_input = torch.zeros( + batch_size, + self.max_encoder_length, + 1, + device=self.device, + dtype=encoder_input.dtype, + ) + else: + encoder_input = encoder_input + + if self.decoder_var_selection is not None: + decoder_weights = self.decoder_var_selection(decoder_input) + decoder_input = decoder_input * decoder_weights + else: + if self.decoder_input_dim == 0: + decoder_input = torch.zeros( + batch_size, + self.max_prediction_length, + 1, + device=self.device, + dtype=decoder_input.dtype, + ) + else: + decoder_input = decoder_input + + if static_context is not None: + encoder_static_context = static_context.unsqueeze(1).expand( + -1, self.max_encoder_length, -1 + ) + decoder_static_context = static_context.unsqueeze(1).expand( + -1, self.max_prediction_length, -1 + ) + + encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input) + encoder_output = encoder_output + encoder_static_context + decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n)) + 
decoder_output = decoder_output + decoder_static_context + else: + encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input) + decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n)) + + sequence = torch.cat([encoder_output, decoder_output], dim=1) + + if static_context is not None: + expanded_static_context = static_context.unsqueeze(1).expand( + -1, sequence.size(1), -1 + ) + + attended_output, _ = self.self_attention( + sequence + expanded_static_context, sequence, sequence + ) + else: + attended_output, _ = self.self_attention(sequence, sequence, sequence) + + decoder_attended = attended_output[:, -self.max_prediction_length :, :] + + output = nn.functional.relu(self.pre_output(decoder_attended)) + prediction = self.output_layer(output) + + return {"prediction": prediction} diff --git a/tests/test_models/_test_tft_v2.py b/tests/test_models/_test_tft_v2.py new file mode 100644 index 000000000..13d92d5db --- /dev/null +++ b/tests/test_models/_test_tft_v2.py @@ -0,0 +1,398 @@ +import numpy as np +import pandas as pd +import pytest +import torch +import torch.nn as nn + +from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule +from pytorch_forecasting.data.timeseries import TimeSeries +from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT + +BATCH_SIZE_TEST = 2 +MAX_ENCODER_LENGTH_TEST = 10 +MAX_PREDICTION_LENGTH_TEST = 5 +HIDDEN_SIZE_TEST = 8 +OUTPUT_SIZE_TEST = 1 +ATTENTION_HEAD_SIZE_TEST = 2 +NUM_LAYERS_TEST = 1 +DROPOUT_TEST = 0.1 + + +def get_default_test_metadata( + enc_cont=2, + enc_cat=1, + dec_cont=1, + dec_cat=1, + static_cat=1, + static_cont=1, + output_size=OUTPUT_SIZE_TEST, +): + """Return a dict representing default metadata for TFT model initialization.""" + return { + "max_encoder_length": MAX_ENCODER_LENGTH_TEST, + "max_prediction_length": MAX_PREDICTION_LENGTH_TEST, + "encoder_cont": enc_cont, + "encoder_cat": enc_cat, + "decoder_cont": dec_cont, + "decoder_cat": dec_cat, + "static_categorical_features": static_cat, + "static_continuous_features": static_cont, + "target": output_size, + } + + +def create_tft_input_batch_for_test(metadata, batch_size=BATCH_SIZE_TEST, device="cpu"): + """Create a synthetic input batch dictionary for testing TFT forward passes.""" + + def _get_dim_val(key): + return metadata.get(key, 0) + + x = { + "encoder_cont": torch.randn( + batch_size, + metadata["max_encoder_length"], + _get_dim_val("encoder_cont"), + device=device, + ), + "encoder_cat": torch.randn( + batch_size, + metadata["max_encoder_length"], + _get_dim_val("encoder_cat"), + device=device, + ), + "decoder_cont": torch.randn( + batch_size, + metadata["max_prediction_length"], + _get_dim_val("decoder_cont"), + device=device, + ), + "decoder_cat": torch.randn( + batch_size, + metadata["max_prediction_length"], + _get_dim_val("decoder_cat"), + device=device, + ), + "static_categorical_features": torch.randn( + batch_size, 1, _get_dim_val("static_categorical_features"), device=device + ), + "static_continuous_features": torch.randn( + batch_size, 1, _get_dim_val("static_continuous_features"), device=device + ), + "encoder_lengths": torch.full( + (batch_size,), + metadata["max_encoder_length"], + dtype=torch.long, + device=device, + ), + "decoder_lengths": torch.full( + (batch_size,), + metadata["max_prediction_length"], + dtype=torch.long, + device=device, + ), + "groups": torch.arange(batch_size, device=device).unsqueeze(1), + "encoder_time_idx": torch.stack( + [torch.arange(metadata["max_encoder_length"], 
device=device)] * batch_size + ), + "decoder_time_idx": torch.stack( + [ + torch.arange( + metadata["max_encoder_length"], + metadata["max_encoder_length"] + metadata["max_prediction_length"], + device=device, + ) + ] + * batch_size + ), + "target_scale": torch.ones((batch_size, 1), device=device), + } + return x + + +dummy_loss_for_test = nn.MSELoss() + + +@pytest.fixture(scope="module") +def tft_model_params_fixture_func(): + """Create a default set of model parameters for TFT.""" + return { + "loss": dummy_loss_for_test, + "hidden_size": HIDDEN_SIZE_TEST, + "num_layers": NUM_LAYERS_TEST, + "attention_head_size": ATTENTION_HEAD_SIZE_TEST, + "dropout": DROPOUT_TEST, + "output_size": OUTPUT_SIZE_TEST, + } + + +def test_basic_initialization(tft_model_params_fixture_func): + """Test basic initialization of the TFT model with default metadata. + + Verifies: + - Model attributes match the provided metadata (e.g., hidden_size, num_layers). + - Proper construction of key model components (LSTM, attention, etc.). + - Correct dimensionality of input layers based on metadata. + - Model retains metadata and hyperparameters as expected. + """ + metadata = get_default_test_metadata(output_size=OUTPUT_SIZE_TEST) + model = TFT(**tft_model_params_fixture_func, metadata=metadata) + assert model.hidden_size == HIDDEN_SIZE_TEST + assert model.num_layers == NUM_LAYERS_TEST + assert hasattr(model, "metadata") and model.metadata == metadata + assert model.encoder_input_dim == metadata["encoder_cont"] + metadata["encoder_cat"] + assert ( + model.static_input_dim + == metadata["static_categorical_features"] + + metadata["static_continuous_features"] + ) + assert isinstance(model.lstm_encoder, nn.LSTM) + assert model.lstm_encoder.input_size == max(1, model.encoder_input_dim) + assert isinstance(model.self_attention, nn.MultiheadAttention) + if hasattr(model, "hparams") and model.hparams: + assert model.hparams.get("hidden_size") == HIDDEN_SIZE_TEST + assert model.output_size == OUTPUT_SIZE_TEST + + +def test_initialization_no_time_varying_features(tft_model_params_fixture_func): + """Test TFT initialization with no time-varying (encoder/decoder) features. + + Verifies: + - Model handles zero encoder/decoder input dimensions correctly. + - Skips creation of encoder/decoder variable selection networks. + - Defaults to input size 1 for LSTMs when no time-varying features exist. + """ + metadata = get_default_test_metadata( + enc_cont=0, enc_cat=0, dec_cont=0, dec_cat=0, output_size=OUTPUT_SIZE_TEST + ) + model = TFT(**tft_model_params_fixture_func, metadata=metadata) + assert model.encoder_input_dim == 0 + assert model.encoder_var_selection is None + assert model.lstm_encoder.input_size == 1 + assert model.decoder_input_dim == 0 + assert model.decoder_var_selection is None + assert model.lstm_decoder.input_size == 1 + + +def test_initialization_no_static_features(tft_model_params_fixture_func): + """Test TFT initialization with no static features. + + Verifies: + - Model static input dim is 0. + - Static context linear layer is not created. 
+ """ + metadata = get_default_test_metadata( + static_cat=0, static_cont=0, output_size=OUTPUT_SIZE_TEST + ) + model = TFT(**tft_model_params_fixture_func, metadata=metadata) + assert model.static_input_dim == 0 + assert model.static_context_linear is None + + +@pytest.mark.parametrize( + "enc_c, enc_k, dec_c, dec_k, stat_c, stat_k", + [ + (2, 1, 1, 1, 1, 1), + (2, 0, 1, 0, 0, 0), + (0, 0, 0, 0, 1, 1), + (0, 0, 0, 0, 0, 0), + (1, 0, 1, 0, 1, 0), + (1, 0, 1, 0, 0, 1), + ], +) +def test_forward_pass_configs( + tft_model_params_fixture_func, enc_c, enc_k, dec_c, dec_k, stat_c, stat_k +): + """Test TFT forward pass across multiple feature configurations. + + Verifies: + - Model can forward pass without errors for varying combinations of input types. + - Output prediction tensor has expected shape. + - Output contains no NaNs or infinities. + """ + current_tft_actual_output_size = tft_model_params_fixture_func["output_size"] + metadata = get_default_test_metadata( + enc_cont=enc_c, + enc_cat=enc_k, + dec_cont=dec_c, + dec_cat=dec_k, + static_cat=stat_c, + static_cont=stat_k, + output_size=current_tft_actual_output_size, + ) + model_params = tft_model_params_fixture_func.copy() + model_params["output_size"] = current_tft_actual_output_size + model = TFT(**model_params, metadata=metadata) + model.eval() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + x = create_tft_input_batch_for_test( + metadata, batch_size=BATCH_SIZE_TEST, device=device + ) + output_dict = model(x) + predictions = output_dict["prediction"] + assert predictions.shape == ( + BATCH_SIZE_TEST, + MAX_PREDICTION_LENGTH_TEST, + current_tft_actual_output_size, + ) + assert not torch.isnan(predictions).any(), "NaNs in prediction" + assert not torch.isinf(predictions).any(), "Infs in prediction" + + +@pytest.fixture +def sample_pandas_data_for_test(): + """Create synthetic multivariate time series data as a pandas DataFrame.""" + series_len = MAX_ENCODER_LENGTH_TEST + MAX_PREDICTION_LENGTH_TEST + 5 + num_groups = 6 + data = [] + + for i in range(num_groups): + static_cont_val = np.float32(i * 10.0) + static_cat_code = np.float32(i % 2) + + df_group = pd.DataFrame( + { + "time_idx": np.arange(series_len, dtype=np.int64), + "group_id_str": np.repeat(f"g{i}", series_len), + "target": np.random.rand(series_len).astype(np.float32) + i, + "enc_cont1": np.random.rand(series_len).astype(np.float32), + "enc_cat1_codes": np.random.randint(0, 3, series_len).astype( + np.float32 + ), + "dec_known_cont": np.sin(np.arange(series_len) / 5.0).astype( + np.float32 + ), + "dec_known_cat_codes": np.random.randint(0, 2, series_len).astype( + np.float32 + ), + "static_cat_feat_codes": np.full( + series_len, static_cat_code, dtype=np.float32 + ), + "static_cont_feat": np.full( + series_len, static_cont_val, dtype=np.float32 + ), + } + ) + data.append(df_group) + + df = pd.concat(data, ignore_index=True) + + df["group_id"] = df["group_id_str"].astype("category") + df.drop(columns=["group_id_str"], inplace=True) + + return df + + +@pytest.fixture +def timeseries_obj_for_test(sample_pandas_data_for_test): + """Convert sample DataFrame into a TimeSeries object.""" + df = sample_pandas_data_for_test + + return TimeSeries( + data=df, + time="time_idx", + target="target", + group=["group_id"], + num=[ + "enc_cont1", + "enc_cat1_codes", + "dec_known_cont", + "dec_known_cat_codes", + "static_cat_feat_codes", + "static_cont_feat", + ], + cat=[], + known=["dec_known_cont", "dec_known_cat_codes", "time_idx"], + 
static=["static_cat_feat_codes", "static_cont_feat"], + ) + + +@pytest.fixture +def data_module_for_test(timeseries_obj_for_test): + """Initialize and sets up an EncoderDecoderTimeSeriesDataModule.""" + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=timeseries_obj_for_test, + batch_size=BATCH_SIZE_TEST, + max_encoder_length=MAX_ENCODER_LENGTH_TEST, + max_prediction_length=MAX_PREDICTION_LENGTH_TEST, + train_val_test_split=(0.5, 0.25, 0.25), + ) + dm.setup("fit") + dm.setup("test") + return dm + + +def test_model_with_datamodule_integration( + tft_model_params_fixture_func, data_module_for_test +): + """Integration test to ensure TFT works correctly with data module. + + Verifies: + - Metadata inferred from data module matches expected input dimensions. + - Model processes real dataloader batches correctly. + - Output and target tensors from model and data module align in shape. + - No NaNs in predictions. + """ + dm = data_module_for_test + model_metadata_from_dm = dm.metadata + + assert ( + model_metadata_from_dm["encoder_cont"] == 6 + ), f"Actual encoder_cont: {model_metadata_from_dm['encoder_cont']}" + assert ( + model_metadata_from_dm["encoder_cat"] == 0 + ), f"Actual encoder_cat: {model_metadata_from_dm['encoder_cat']}" + assert ( + model_metadata_from_dm["decoder_cont"] == 2 + ), f"Actual decoder_cont: {model_metadata_from_dm['decoder_cont']}" + assert ( + model_metadata_from_dm["decoder_cat"] == 0 + ), f"Actual decoder_cat: {model_metadata_from_dm['decoder_cat']}" + assert ( + model_metadata_from_dm["static_categorical_features"] == 0 + ), f"Actual static_cat: {model_metadata_from_dm['static_categorical_features']}" + assert ( + model_metadata_from_dm["static_continuous_features"] == 2 + ), f"Actual static_cont: {model_metadata_from_dm['static_continuous_features']}" + assert model_metadata_from_dm["target"] == 1 + + tft_init_args = tft_model_params_fixture_func.copy() + tft_init_args["output_size"] = model_metadata_from_dm["target"] + model = TFT(**tft_init_args, metadata=model_metadata_from_dm) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + model.eval() + + train_loader = dm.train_dataloader() + batch_x, batch_y = next(iter(train_loader)) + + actual_batch_size = batch_x["encoder_cont"].shape[0] + batch_x = {k: v.to(device) for k, v in batch_x.items()} + batch_y = batch_y.to(device) + + assert batch_x["encoder_cont"].shape[2] == model_metadata_from_dm["encoder_cont"] + assert batch_x["encoder_cat"].shape[2] == model_metadata_from_dm["encoder_cat"] + assert batch_x["decoder_cont"].shape[2] == model_metadata_from_dm["decoder_cont"] + assert batch_x["decoder_cat"].shape[2] == model_metadata_from_dm["decoder_cat"] + assert ( + batch_x["static_categorical_features"].shape[2] + == model_metadata_from_dm["static_categorical_features"] + ) + assert ( + batch_x["static_continuous_features"].shape[2] + == model_metadata_from_dm["static_continuous_features"] + ) + + output_dict = model(batch_x) + predictions = output_dict["prediction"] + assert predictions.shape == ( + actual_batch_size, + MAX_PREDICTION_LENGTH_TEST, + model_metadata_from_dm["target"], + ) + assert not torch.isnan(predictions).any() + assert batch_y.shape == ( + actual_batch_size, + MAX_PREDICTION_LENGTH_TEST, + model_metadata_from_dm["target"], + )