diff --git a/examples/ptf_V2_example.ipynb b/examples/ptf_V2_example.ipynb new file mode 100644 index 000000000..2e39108f3 --- /dev/null +++ b/examples/ptf_V2_example.ipynb @@ -0,0 +1,1022 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2630DaOEI4AJ", + "outputId": "b4cc7100-2a9c-41a3-e890-164b90c91c03" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting pytorch-forecasting\n", + " Downloading pytorch_forecasting-1.3.0-py3-none-any.whl.metadata (13 kB)\n", + "Requirement already satisfied: numpy<=3.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.0.2)\n", + "Requirement already satisfied: torch!=2.0.1,<3.0.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.6.0+cu124)\n", + "Collecting lightning<3.0.0,>=2.0.0 (from pytorch-forecasting)\n", + " Downloading lightning-2.5.1.post0-py3-none-any.whl.metadata (39 kB)\n", + "Requirement already satisfied: scipy<2.0,>=1.8 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (1.15.2)\n", + "Requirement already satisfied: pandas<3.0.0,>=1.3.0 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (2.2.2)\n", + "Requirement already satisfied: scikit-learn<2.0,>=1.2 in /usr/local/lib/python3.11/dist-packages (from pytorch-forecasting) (1.6.1)\n", + "Requirement already satisfied: PyYAML<8.0,>=5.4 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.0.2)\n", + "Requirement already satisfied: fsspec<2026.0,>=2022.5.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2025.3.2)\n", + "Collecting lightning-utilities<2.0,>=0.10.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)\n", + 
"Requirement already satisfied: packaging<25.0,>=20.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (24.2)\n", + "Collecting torchmetrics<3.0,>=0.7.0 (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)\n", + "Requirement already satisfied: tqdm<6.0,>=4.57.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.67.1)\n", + "Requirement already satisfied: typing-extensions<6.0,>=4.4.0 in /usr/local/lib/python3.11/dist-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.13.2)\n", + "Collecting pytorch-lightning (from lightning<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading pytorch_lightning-2.5.1.post0-py3-none-any.whl.metadata (20 kB)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2025.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0.0,>=1.3.0->pytorch-forecasting) (2025.2)\n", + "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn<2.0,>=1.2->pytorch-forecasting) (1.4.2)\n", + "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn<2.0,>=1.2->pytorch-forecasting) (3.6.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.18.0)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.4.2)\n", + "Requirement already satisfied: jinja2 in 
/usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.1.6)\n", + "Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-cublas-cu12==12.4.5.8 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cufft-cu12==11.2.1.3 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-curand-cu12==10.3.5.147 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", + "Requirement already 
satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (0.6.2)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (2.21.5)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (12.4.127)\n", + "Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting)\n", + " Downloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.2.0)\n", + "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (1.13.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (1.3.0)\n", + "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (3.11.15)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.11/dist-packages (from lightning-utilities<2.0,>=0.10.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (75.2.0)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas<3.0.0,>=1.3.0->pytorch-forecasting) (1.17.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch!=2.0.1,<3.0.0,>=2.0.0->pytorch-forecasting) (3.0.2)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in 
/usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2.6.1)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.3.2)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (25.3.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.6.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.4.3)\n", + "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (0.3.1)\n", + "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.20.0)\n", + "Requirement already satisfied: idna>=2.0 in /usr/local/lib/python3.11/dist-packages (from yarl<2.0,>=1.17.0->aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2026.0,>=2022.5.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (3.10)\n", + "Downloading pytorch_forecasting-1.3.0-py3-none-any.whl (197 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m197.7/197.7 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading 
lightning-2.5.1.post0-py3-none-any.whl (819 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m819.0/819.0 kB\u001b[0m \u001b[31m16.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m2.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m14.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m37.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m20.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n", + "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m12.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m7.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m64.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading lightning_utilities-0.14.3-py3-none-any.whl (28 kB)\n", + "Downloading torchmetrics-1.7.1-py3-none-any.whl (961 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m961.5/961.5 kB\u001b[0m \u001b[31m35.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading pytorch_lightning-2.5.1.post0-py3-none-any.whl (823 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.1/823.1 kB\u001b[0m \u001b[31m37.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, lightning-utilities, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12, torchmetrics, pytorch-lightning, lightning, pytorch-forecasting\n", + " Attempting uninstall: nvidia-nvjitlink-cu12\n", + " Found existing installation: 
nvidia-nvjitlink-cu12 12.5.82\n", + " Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n", + " Attempting uninstall: nvidia-curand-cu12\n", + " Found existing installation: nvidia-curand-cu12 10.3.6.82\n", + " Uninstalling nvidia-curand-cu12-10.3.6.82:\n", + " Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n", + " Attempting uninstall: nvidia-cufft-cu12\n", + " Found existing installation: nvidia-cufft-cu12 11.2.3.61\n", + " Uninstalling nvidia-cufft-cu12-11.2.3.61:\n", + " Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n", + " Attempting uninstall: nvidia-cuda-runtime-cu12\n", + " Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n", + " Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n", + " Attempting uninstall: nvidia-cuda-nvrtc-cu12\n", + " Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n", + " Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n", + " Attempting uninstall: nvidia-cuda-cupti-cu12\n", + " Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n", + " Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n", + " Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n", + " Attempting uninstall: nvidia-cublas-cu12\n", + " Found existing installation: nvidia-cublas-cu12 12.5.3.2\n", + " Uninstalling nvidia-cublas-cu12-12.5.3.2:\n", + " Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n", + " Attempting uninstall: nvidia-cusparse-cu12\n", + " Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n", + " Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n", + " Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n", + " Attempting uninstall: nvidia-cudnn-cu12\n", + " Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n", + " Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n", + " Successfully uninstalled 
nvidia-cudnn-cu12-9.3.0.75\n", + " Attempting uninstall: nvidia-cusolver-cu12\n", + " Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n", + " Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n", + " Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n", + "Successfully installed lightning-2.5.1.post0 lightning-utilities-0.14.3 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 pytorch-forecasting-1.3.0 pytorch-lightning-2.5.1.post0 torchmetrics-1.7.1\n" + ] + } + ], + "source": [ + "!pip install pytorch-forecasting" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "M7PQerTbI_tM" + }, + "outputs": [], + "source": [ + "from typing import Any, Dict, List, Optional, Tuple, Union\n", + "\n", + "from lightning.pytorch import Trainer\n", + "import numpy as np\n", + "import pandas as pd\n", + "from sklearn.preprocessing import RobustScaler, StandardScaler\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.optim import Optimizer\n", + "from torch.utils.data import Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DGTyf3vct-Jk" + }, + "outputs": [], + "source": [ + "from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule\n", + "from pytorch_forecasting.data.encoders import (\n", + " EncoderNormalizer,\n", + " NaNLabelEncoder,\n", + " TorchNormalizer,\n", + ")\n", + "from pytorch_forecasting.data.timeseries import TimeSeries\n", + "from pytorch_forecasting.metrics import MAE, SMAPE\n", + "from pytorch_forecasting.models.temporal_fusion_transformer.tft_version_two import TFT" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": 
"https://localhost:8080/", + "height": 206 + }, + "id": "WX-FRdusJSVN", + "outputId": "d162e241-3076-415c-db39-8c571bbaa282" + }, + "outputs": [ + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "summary": "{\n \"name\": \"data_df\",\n \"rows\": 4900,\n \"fields\": [\n {\n \"column\": \"series_id\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 28,\n \"min\": 0,\n \"max\": 99,\n \"num_unique_values\": 100,\n \"samples\": [\n 83,\n 53,\n 70\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"time_idx\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 14,\n \"min\": 0,\n \"max\": 48,\n \"num_unique_values\": 49,\n \"samples\": [\n 13,\n 45,\n 47\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"x\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6702424716860947,\n \"min\": -1.2572875930191487,\n \"max\": 1.347291996576924,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.40064266948811306,\n 0.688757012378203,\n -0.9278241195910876\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"y\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6742204890063661,\n \"min\": -1.2572875930191487,\n \"max\": 1.347291996576924,\n \"num_unique_values\": 4900,\n \"samples\": [\n 0.6039073395571968,\n 0.5832480743546181,\n -0.801772762118357\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"category\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 4,\n \"num_unique_values\": 5,\n \"samples\": [\n 1,\n 4,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"future_known_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.6741140972121411,\n \"min\": -0.9991351502732795,\n \"max\": 1.0,\n \"num_unique_values\": 49,\n \"samples\": [\n 0.26749882862458735,\n -0.2107957994307797,\n 
-0.01238866346289056\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.2912918986931303,\n \"min\": 0.007584244652032224,\n \"max\": 0.9959799570401108,\n \"num_unique_values\": 100,\n \"samples\": [\n 0.4307103381570838,\n 0.6664272198589233,\n 0.16731443141739688\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"static_feature_cat\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0,\n \"min\": 0,\n \"max\": 2,\n \"num_unique_values\": 3,\n \"samples\": [\n 0,\n 1,\n 2\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}", + "type": "dataframe", + "variable_name": "data_df" + }, + "text/html": [ + "\n", + "
\n", + " | series_id | \n", + "time_idx | \n", + "x | \n", + "y | \n", + "category | \n", + "future_known_feature | \n", + "static_feature | \n", + "static_feature_cat | \n", + "
---|---|---|---|---|---|---|---|---|
0 | \n", + "0 | \n", + "0 | \n", + "0.200968 | \n", + "0.232298 | \n", + "0 | \n", + "1.000000 | \n", + "0.68729 | \n", + "0 | \n", + "
1 | \n", + "0 | \n", + "1 | \n", + "0.232298 | \n", + "0.336669 | \n", + "0 | \n", + "0.995004 | \n", + "0.68729 | \n", + "0 | \n", + "
2 | \n", + "0 | \n", + "2 | \n", + "0.336669 | \n", + "0.636063 | \n", + "0 | \n", + "0.980067 | \n", + "0.68729 | \n", + "0 | \n", + "
3 | \n", + "0 | \n", + "3 | \n", + "0.636063 | \n", + "0.927710 | \n", + "0 | \n", + "0.955336 | \n", + "0.68729 | \n", + "0 | \n", + "
4 | \n", + "0 | \n", + "4 | \n", + "0.927710 | \n", + "1.008554 | \n", + "0 | \n", + "0.921061 | \n", + "0.68729 | \n", + "0 | \n", + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Test metric ┃ DataLoader 0 ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ test_MAE │ 0.45287469029426575 │\n", + "│ test_SMAPE │ 0.942494809627533 │\n", + "│ test_loss │ 0.01396977063268423 │\n", + "└───────────────────────────┴───────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test_MAE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.45287469029426575 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_SMAPE \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.942494809627533 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.01396977063268423 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Prediction shape: torch.Size([32, 1, 1])\n", + "First prediction values: [[0.11045122]]\n", + "First true values: [[-0.0491814]]\n", + "\n", + "TFT model test complete!\n" + ] + } + ], + "source": [ + "model = TFT(\n", + " loss=nn.MSELoss(),\n", + " logging_metrics=[MAE(), SMAPE()],\n", + " optimizer=\"adam\",\n", + " optimizer_params={\"lr\": 1e-3},\n", + " lr_scheduler=\"reduce_lr_on_plateau\",\n", + " lr_scheduler_params={\"mode\": \"min\", \"factor\": 0.1, \"patience\": 10},\n", + " hidden_size=64,\n", + " num_layers=2,\n", + " attention_head_size=4,\n", + " dropout=0.1,\n", + " metadata=data_module.metadata,\n", + ")\n", + "\n", + 
"print(\"\\nTraining model...\")\n", + "trainer = Trainer(\n", + " max_epochs=5,\n", + " accelerator=\"auto\",\n", + " devices=1,\n", + " enable_progress_bar=True,\n", + " log_every_n_steps=10,\n", + ")\n", + "\n", + "trainer.fit(model, data_module)\n", + "\n", + "print(\"\\nEvaluating model...\")\n", + "test_metrics = trainer.test(model, data_module)\n", + "\n", + "model.eval()\n", + "with torch.no_grad():\n", + " test_batch = next(iter(data_module.test_dataloader()))\n", + " x_test, y_test = test_batch\n", + " y_pred = model(x_test)\n", + "\n", + " print(\"\\nPrediction shape:\", y_pred[\"prediction\"].shape)\n", + " print(\"First prediction values:\", y_pred[\"prediction\"][0].cpu().numpy())\n", + " print(\"First true values:\", y_test[0].cpu().numpy())\n", + "print(\"\\nTFT model test complete!\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zVRwi2MvLGgc" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/pytorch_forecasting/models/base/_base_model_v2.py b/pytorch_forecasting/models/base/_base_model_v2.py new file mode 100644 index 000000000..ddefc29fb --- /dev/null +++ b/pytorch_forecasting/models/base/_base_model_v2.py @@ -0,0 +1,296 @@ +######################################################################################## +# Disclaimer: This baseclass is still work in progress and experimental, please +# use with care. This class is a basic skeleton of how the base classes may look like +# in the version-2. 
+######################################################################################## + + +from typing import Dict, List, Optional, Tuple, Union +from warnings import warn + +from lightning.pytorch import LightningModule +from lightning.pytorch.utilities.types import STEP_OUTPUT +import torch +import torch.nn as nn +from torch.optim import Optimizer + + +class BaseModel(LightningModule): + def __init__( + self, + loss: nn.Module, + logging_metrics: Optional[List[nn.Module]] = None, + optimizer: Optional[Union[Optimizer, str]] = "adam", + optimizer_params: Optional[Dict] = None, + lr_scheduler: Optional[str] = None, + lr_scheduler_params: Optional[Dict] = None, + ): + """ + Base model for time series forecasting. + + Parameters + ---------- + loss : nn.Module + Loss function to use for training. + logging_metrics : Optional[List[nn.Module]], optional + List of metrics to log during training, validation, and testing. + optimizer : Optional[Union[Optimizer, str]], optional + Optimizer to use for training. + Can be a string ("adam", "sgd") or an instance of `torch.optim.Optimizer`. + optimizer_params : Optional[Dict], optional + Parameters for the optimizer. + lr_scheduler : Optional[str], optional + Learning rate scheduler to use. + Supported values: "reduce_lr_on_plateau", "step_lr". + lr_scheduler_params : Optional[Dict], optional + Parameters for the learning rate scheduler. + """ + super().__init__() + self.loss = loss + self.logging_metrics = logging_metrics if logging_metrics is not None else [] + self.optimizer = optimizer + self.optimizer_params = optimizer_params if optimizer_params is not None else {} + self.lr_scheduler = lr_scheduler + self.lr_scheduler_params = ( + lr_scheduler_params if lr_scheduler_params is not None else {} + ) + self.model_name = self.__class__.__name__ + warn( + f"The Model '{self.model_name}' is part of an experimental rework" + "of the pytorch-forecasting model layer, scheduled for release with v2.0.0." 
+ " The API is not stable and may change without prior warning. " + "This class is intended for beta testing and as a basic skeleton, " + "but not for stable production use. " + "Feedback and suggestions are very welcome in " + "pytorch-forecasting issue 1736, " + "https://github.com/sktime/pytorch-forecasting/issues/1736", + UserWarning, + ) + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Forward pass of the model. + + Parameters + ---------- + x : Dict[str, torch.Tensor] + Dictionary containing input tensors + + Returns + ------- + Dict[str, torch.Tensor] + Dictionary containing output tensors + """ + raise NotImplementedError("Forward method must be implemented by subclass.") + + def training_step( + self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int + ) -> STEP_OUTPUT: + """ + Training step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input and target tensors. + batch_idx : int + Index of the batch. + + Returns + ------- + STEP_OUTPUT + Dictionary containing the loss and other metrics. + """ + x, y = batch + y_hat_dict = self(x) + y_hat = y_hat_dict["prediction"] + loss = self.loss(y_hat, y) + self.log( + "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True + ) + self.log_metrics(y_hat, y, prefix="train") + return {"loss": loss} + + def validation_step( + self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int + ) -> STEP_OUTPUT: + """ + Validation step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input and target tensors. + batch_idx : int + Index of the batch. + + Returns + ------- + STEP_OUTPUT + Dictionary containing the loss and other metrics. 
+ """ + x, y = batch + y_hat_dict = self(x) + y_hat = y_hat_dict["prediction"] + loss = self.loss(y_hat, y) + self.log( + "val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True + ) + self.log_metrics(y_hat, y, prefix="val") + return {"val_loss": loss} + + def test_step( + self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int + ) -> STEP_OUTPUT: + """ + Test step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input and target tensors. + batch_idx : int + Index of the batch. + + Returns + ------- + STEP_OUTPUT + Dictionary containing the loss and other metrics. + """ + x, y = batch + y_hat_dict = self(x) + y_hat = y_hat_dict["prediction"] + loss = self.loss(y_hat, y) + self.log( + "test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True + ) + self.log_metrics(y_hat, y, prefix="test") + return {"test_loss": loss} + + def predict_step( + self, + batch: Tuple[Dict[str, torch.Tensor]], + batch_idx: int, + dataloader_idx: int = 0, + ) -> torch.Tensor: + """ + Prediction step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input tensors. + batch_idx : int + Index of the batch. + dataloader_idx : int + Index of the dataloader. + + Returns + ------- + torch.Tensor + Predicted output tensor. + """ + x, _ = batch + y_hat = self(x) + return y_hat + + def configure_optimizers(self) -> Dict: + """ + Configure the optimizer and learning rate scheduler. + + Returns + ------- + Dict + Dictionary containing the optimizer and scheduler configuration. 
+ """ + optimizer = self._get_optimizer() + if self.lr_scheduler is not None: + scheduler = self._get_scheduler(optimizer) + if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val_loss", + }, + } + else: + return {"optimizer": optimizer, "lr_scheduler": scheduler} + return {"optimizer": optimizer} + + def _get_optimizer(self) -> Optimizer: + """ + Get the optimizer based on the specified optimizer name and parameters. + + Returns + ------- + Optimizer + The optimizer instance. + """ + if isinstance(self.optimizer, str): + if self.optimizer.lower() == "adam": + return torch.optim.Adam(self.parameters(), **self.optimizer_params) + elif self.optimizer.lower() == "sgd": + return torch.optim.SGD(self.parameters(), **self.optimizer_params) + else: + raise ValueError(f"Optimizer {self.optimizer} not supported.") + elif isinstance(self.optimizer, Optimizer): + return self.optimizer + else: + raise ValueError( + "Optimizer must be either a string or " + "an instance of torch.optim.Optimizer." + ) + + def _get_scheduler( + self, optimizer: Optimizer + ) -> torch.optim.lr_scheduler._LRScheduler: + """ + Get the lr scheduler based on the specified scheduler name and params. + + Parameters + ---------- + optimizer : Optimizer + The optimizer instance. + + Returns + ------- + torch.optim.lr_scheduler._LRScheduler + The learning rate scheduler instance. 
+ """ + if self.lr_scheduler.lower() == "reduce_lr_on_plateau": + return torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, **self.lr_scheduler_params + ) + elif self.lr_scheduler.lower() == "step_lr": + return torch.optim.lr_scheduler.StepLR( + optimizer, **self.lr_scheduler_params + ) + else: + raise ValueError(f"Scheduler {self.lr_scheduler} not supported.") + + def log_metrics( + self, y_hat: torch.Tensor, y: torch.Tensor, prefix: str = "val" + ) -> None: + """ + Log additional metrics during training, validation, or testing. + + Parameters + ---------- + y_hat : torch.Tensor + Predicted output tensor. + y : torch.Tensor + Target output tensor. + prefix : str + Prefix for the logged metrics (e.g., "train", "val", "test"). + """ + for metric in self.logging_metrics: + metric_value = metric(y_hat, y) + self.log( + f"{prefix}_{metric.__class__.__name__}", + metric_value, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py new file mode 100644 index 000000000..a0cf7d39e --- /dev/null +++ b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py @@ -0,0 +1,253 @@ +######################################################################################## +# Disclaimer: This implementation is based on the new version of data pipeline and is +# experimental, please use with care. 
########################################################################################
# Disclaimer: This implementation is based on the new version of data pipeline and is
# experimental, please use with care.
########################################################################################

from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer

from pytorch_forecasting.models.base._base_model_v2 import BaseModel


class TFT(BaseModel):
    """
    Simplified Temporal Fusion Transformer for the experimental v2 data pipeline.

    Uses sigmoid-gated variable selection on encoder/decoder inputs, an
    encoder/decoder LSTM pair, a static-context projection added to the LSTM
    outputs and the attention query, and multi-head self-attention over the
    full (encoder + decoder) sequence.
    """

    def __init__(
        self,
        loss: nn.Module,
        logging_metrics: Optional[List[nn.Module]] = None,
        optimizer: Optional[Union[Optimizer, str]] = "adam",
        optimizer_params: Optional[Dict] = None,
        lr_scheduler: Optional[str] = None,
        lr_scheduler_params: Optional[Dict] = None,
        hidden_size: int = 64,
        num_layers: int = 2,
        attention_head_size: int = 4,
        dropout: float = 0.1,
        metadata: Optional[Dict] = None,
        output_size: int = 1,
    ):
        """
        Parameters
        ----------
        loss : nn.Module
            Loss module used for training.
        logging_metrics : Optional[List[nn.Module]]
            Additional metrics to log (handled by the base class).
        optimizer : Optional[Union[Optimizer, str]]
            Optimizer instance or name ("adam" / "sgd").
        optimizer_params, lr_scheduler, lr_scheduler_params
            Optimization configuration forwarded to the base class.
        hidden_size : int
            Hidden width of the LSTMs, attention, and projections.
        num_layers : int
            Number of LSTM layers.
        attention_head_size : int
            Number of attention heads.
        dropout : float
            Dropout probability for LSTMs and attention.
        metadata : Optional[Dict]
            Dataset metadata (feature dimensions and sequence lengths),
            typically produced by the data module. Required.
        output_size : int
            Number of output targets per time step.

        Raises
        ------
        ValueError
            If ``metadata`` is not provided.
        """
        super().__init__(
            loss=loss,
            logging_metrics=logging_metrics,
            optimizer=optimizer,
            optimizer_params=optimizer_params,
            lr_scheduler=lr_scheduler,
            lr_scheduler_params=lr_scheduler_params,
        )
        self.save_hyperparameters(ignore=["loss", "logging_metrics", "metadata"])

        if metadata is None:
            # Fail fast with a clear message instead of a TypeError on
            # `None["max_encoder_length"]` below.
            raise ValueError(
                "TFT requires a `metadata` dict describing feature dimensions "
                "and sequence lengths (e.g. from the data module)."
            )

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.attention_head_size = attention_head_size
        self.dropout = dropout
        self.metadata = metadata
        self.output_size = output_size

        self.max_encoder_length = self.metadata["max_encoder_length"]
        self.max_prediction_length = self.metadata["max_prediction_length"]
        # Metadata stores feature *counts*; total input dim = continuous + categorical.
        self.encoder_cont = self.metadata["encoder_cont"]
        self.encoder_cat = self.metadata["encoder_cat"]
        self.encoder_input_dim = self.encoder_cont + self.encoder_cat
        self.decoder_cont = self.metadata["decoder_cont"]
        self.decoder_cat = self.metadata["decoder_cat"]
        self.decoder_input_dim = self.decoder_cont + self.decoder_cat
        self.static_cat_dim = self.metadata.get("static_categorical_features", 0)
        self.static_cont_dim = self.metadata.get("static_continuous_features", 0)
        self.static_input_dim = self.static_cat_dim + self.static_cont_dim

        # Sigmoid-gated variable selection (per-feature soft weights); only
        # built when there is at least one time-varying feature.
        self.encoder_var_selection = (
            self._make_var_selection(self.encoder_input_dim, hidden_size)
            if self.encoder_input_dim > 0
            else None
        )
        self.decoder_var_selection = (
            self._make_var_selection(self.decoder_input_dim, hidden_size)
            if self.decoder_input_dim > 0
            else None
        )

        if self.static_input_dim > 0:
            self.static_context_linear = nn.Linear(self.static_input_dim, hidden_size)
        else:
            self.static_context_linear = None

        # Inter-layer dropout is a no-op (and triggers a PyTorch warning) for a
        # single-layer LSTM, so disable it in that case.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        # LSTMs need input_size >= 1; forward() feeds zeros when there are no
        # actual features.
        self.lstm_encoder = nn.LSTM(
            input_size=max(1, self.encoder_input_dim),
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=lstm_dropout,
            batch_first=True,
        )
        self.lstm_decoder = nn.LSTM(
            input_size=max(1, self.decoder_input_dim),
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=lstm_dropout,
            batch_first=True,
        )

        self.self_attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=attention_head_size,
            dropout=dropout,
            batch_first=True,
        )

        self.pre_output = nn.Linear(hidden_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, self.output_size)

    @staticmethod
    def _make_var_selection(input_dim: int, hidden_size: int) -> nn.Sequential:
        """Build a sigmoid-gated variable selection network for `input_dim` features."""
        return nn.Sequential(
            nn.Linear(input_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, input_dim),
            nn.Sigmoid(),
        )

    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the TFT model.

        Parameters
        ----------
        x : Dict[str, torch.Tensor]
            Dictionary containing input tensors (missing keys default to
            zero-width tensors):
            - encoder_cat / encoder_cont: time-varying encoder features
            - decoder_cat / decoder_cont: time-varying decoder features
            - static_categorical_features / static_continuous_features:
              static features, shape (batch, 1, n_features)

        Returns
        -------
        Dict[str, torch.Tensor]
            Dictionary with key ``"prediction"`` of shape
            (batch_size, max_prediction_length, output_size).
        """
        # NOTE(review): "encoder_cont" must always be present — the batch size
        # is inferred from it even when other keys are defaulted.
        batch_size = x["encoder_cont"].shape[0]
        device = self.device

        def _get(key: str, length: int) -> torch.Tensor:
            # Missing feature groups become zero-width tensors so that
            # concatenation below is a no-op for them.
            return x.get(key, torch.zeros(batch_size, length, 0, device=device))

        encoder_cat = _get("encoder_cat", self.max_encoder_length)
        encoder_cont = _get("encoder_cont", self.max_encoder_length)
        decoder_cat = _get("decoder_cat", self.max_prediction_length)
        decoder_cont = _get("decoder_cont", self.max_prediction_length)

        encoder_input = torch.cat([encoder_cont, encoder_cat], dim=2)
        decoder_input = torch.cat([decoder_cont, decoder_cat], dim=2)

        static_context = None
        if self.static_context_linear is not None:
            static_cat = _get("static_categorical_features", 1)
            static_cont = _get("static_continuous_features", 1)
            # Concatenate whichever static groups are non-empty (continuous
            # first, matching the layer's expected input ordering).
            parts = [t for t in (static_cont, static_cat) if t.size(2) > 0]
            if parts:
                static_input = torch.cat(parts, dim=2).to(
                    dtype=self.static_context_linear.weight.dtype
                )
                static_context = self.static_context_linear(static_input).view(
                    batch_size, self.hidden_size
                )

        if self.encoder_var_selection is not None:
            encoder_input = encoder_input * self.encoder_var_selection(encoder_input)
        elif self.encoder_input_dim == 0:
            # The LSTM requires at least one input channel; feed zeros.
            encoder_input = torch.zeros(
                batch_size,
                self.max_encoder_length,
                1,
                device=device,
                dtype=encoder_input.dtype,
            )

        if self.decoder_var_selection is not None:
            decoder_input = decoder_input * self.decoder_var_selection(decoder_input)
        elif self.decoder_input_dim == 0:
            decoder_input = torch.zeros(
                batch_size,
                self.max_prediction_length,
                1,
                device=device,
                dtype=decoder_input.dtype,
            )

        encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input)
        decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n))
        if static_context is not None:
            # Broadcast the static context over every time step.
            encoder_output = encoder_output + static_context.unsqueeze(1).expand(
                -1, self.max_encoder_length, -1
            )
            decoder_output = decoder_output + static_context.unsqueeze(1).expand(
                -1, self.max_prediction_length, -1
            )

        sequence = torch.cat([encoder_output, decoder_output], dim=1)

        if static_context is not None:
            # Static context conditions the attention query only.
            query = sequence + static_context.unsqueeze(1).expand(
                -1, sequence.size(1), -1
            )
            attended_output, _ = self.self_attention(query, sequence, sequence)
        else:
            attended_output, _ = self.self_attention(sequence, sequence, sequence)

        # Only the decoder part of the attended sequence feeds the head.
        decoder_attended = attended_output[:, -self.max_prediction_length :, :]

        output = nn.functional.relu(self.pre_output(decoder_attended))
        prediction = self.output_layer(output)

        return {"prediction": prediction}
def create_tft_input_batch_for_test(metadata, batch_size=BATCH_SIZE_TEST, device="cpu"):
    """Build a synthetic input batch dictionary for exercising TFT forward passes."""
    enc_len = metadata["max_encoder_length"]
    dec_len = metadata["max_prediction_length"]

    def _dim(key):
        return metadata.get(key, 0)

    batch = {}
    # Time-varying features: (batch, time, n_features) random tensors.
    # Call order of torch.randn matches the layout used elsewhere in the suite.
    for key, length in (
        ("encoder_cont", enc_len),
        ("encoder_cat", enc_len),
        ("decoder_cont", dec_len),
        ("decoder_cat", dec_len),
    ):
        batch[key] = torch.randn(batch_size, length, _dim(key), device=device)
    # Static features use a single pseudo time step.
    for key in ("static_categorical_features", "static_continuous_features"):
        batch[key] = torch.randn(batch_size, 1, _dim(key), device=device)

    batch["encoder_lengths"] = torch.full(
        (batch_size,), enc_len, dtype=torch.long, device=device
    )
    batch["decoder_lengths"] = torch.full(
        (batch_size,), dec_len, dtype=torch.long, device=device
    )
    batch["groups"] = torch.arange(batch_size, device=device).unsqueeze(1)
    batch["encoder_time_idx"] = torch.stack(
        [torch.arange(enc_len, device=device)] * batch_size
    )
    batch["decoder_time_idx"] = torch.stack(
        [torch.arange(enc_len, enc_len + dec_len, device=device)] * batch_size
    )
    batch["target_scale"] = torch.ones((batch_size, 1), device=device)
    return batch


dummy_loss_for_test = nn.MSELoss()


@pytest.fixture(scope="module")
def tft_model_params_fixture_func():
    """Create a default set of model parameters for TFT."""
    return {
        "loss": dummy_loss_for_test,
        "hidden_size": HIDDEN_SIZE_TEST,
        "num_layers": NUM_LAYERS_TEST,
        "attention_head_size": ATTENTION_HEAD_SIZE_TEST,
        "dropout": DROPOUT_TEST,
        "output_size": OUTPUT_SIZE_TEST,
    }


def test_basic_initialization(tft_model_params_fixture_func):
    """Test basic initialization of the TFT model with default metadata.

    Verifies:
    - Model attributes match the provided metadata (e.g., hidden_size, num_layers).
    - Proper construction of key model components (LSTM, attention, etc.).
    - Correct dimensionality of input layers based on metadata.
    - Model retains metadata and hyperparameters as expected.
    """
    metadata = get_default_test_metadata(output_size=OUTPUT_SIZE_TEST)
    model = TFT(**tft_model_params_fixture_func, metadata=metadata)

    assert model.hidden_size == HIDDEN_SIZE_TEST
    assert model.num_layers == NUM_LAYERS_TEST
    assert hasattr(model, "metadata") and model.metadata == metadata

    expected_encoder_dim = metadata["encoder_cont"] + metadata["encoder_cat"]
    assert model.encoder_input_dim == expected_encoder_dim
    expected_static_dim = (
        metadata["static_categorical_features"]
        + metadata["static_continuous_features"]
    )
    assert model.static_input_dim == expected_static_dim

    assert isinstance(model.lstm_encoder, nn.LSTM)
    assert model.lstm_encoder.input_size == max(1, model.encoder_input_dim)
    assert isinstance(model.self_attention, nn.MultiheadAttention)
    if hasattr(model, "hparams") and model.hparams:
        assert model.hparams.get("hidden_size") == HIDDEN_SIZE_TEST
    assert model.output_size == OUTPUT_SIZE_TEST


def test_initialization_no_time_varying_features(tft_model_params_fixture_func):
    """Test TFT initialization with no time-varying (encoder/decoder) features.

    Verifies:
    - Model handles zero encoder/decoder input dimensions correctly.
    - Skips creation of encoder/decoder variable selection networks.
    - Defaults to input size 1 for LSTMs when no time-varying features exist.
    """
    metadata = get_default_test_metadata(
        enc_cont=0, enc_cat=0, dec_cont=0, dec_cat=0, output_size=OUTPUT_SIZE_TEST
    )
    model = TFT(**tft_model_params_fixture_func, metadata=metadata)

    # Encoder side collapses to the zero-feature fallback.
    assert model.encoder_input_dim == 0
    assert model.encoder_var_selection is None
    assert model.lstm_encoder.input_size == 1
    # Decoder side behaves identically.
    assert model.decoder_input_dim == 0
    assert model.decoder_var_selection is None
    assert model.lstm_decoder.input_size == 1


def test_initialization_no_static_features(tft_model_params_fixture_func):
    """Test TFT initialization with no static features.

    Verifies:
    - Model static input dim is 0.
    - Static context linear layer is not created.
    """
    metadata = get_default_test_metadata(
        static_cat=0, static_cont=0, output_size=OUTPUT_SIZE_TEST
    )
    model = TFT(**tft_model_params_fixture_func, metadata=metadata)
    assert model.static_input_dim == 0
    assert model.static_context_linear is None


@pytest.mark.parametrize(
    "enc_c, enc_k, dec_c, dec_k, stat_c, stat_k",
    [
        (2, 1, 1, 1, 1, 1),
        (2, 0, 1, 0, 0, 0),
        (0, 0, 0, 0, 1, 1),
        (0, 0, 0, 0, 0, 0),
        (1, 0, 1, 0, 1, 0),
        (1, 0, 1, 0, 0, 1),
    ],
)
def test_forward_pass_configs(
    tft_model_params_fixture_func, enc_c, enc_k, dec_c, dec_k, stat_c, stat_k
):
    """Test TFT forward pass across multiple feature configurations.

    Verifies:
    - Model can forward pass without errors for varying combinations of input types.
    - Output prediction tensor has expected shape.
    - Output contains no NaNs or infinities.
    """
    out_size = tft_model_params_fixture_func["output_size"]
    metadata = get_default_test_metadata(
        enc_cont=enc_c,
        enc_cat=enc_k,
        dec_cont=dec_c,
        dec_cat=dec_k,
        static_cat=stat_c,
        static_cont=stat_k,
        output_size=out_size,
    )
    params = tft_model_params_fixture_func.copy()
    params["output_size"] = out_size
    model = TFT(**params, metadata=metadata)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    x = create_tft_input_batch_for_test(
        metadata, batch_size=BATCH_SIZE_TEST, device=device
    )

    predictions = model(x)["prediction"]
    expected_shape = (BATCH_SIZE_TEST, MAX_PREDICTION_LENGTH_TEST, out_size)
    assert predictions.shape == expected_shape
    assert not torch.isnan(predictions).any(), "NaNs in prediction"
    assert not torch.isinf(predictions).any(), "Infs in prediction"


@pytest.fixture
def sample_pandas_data_for_test():
    """Create synthetic multivariate time series data as a pandas DataFrame."""
    series_len = MAX_ENCODER_LENGTH_TEST + MAX_PREDICTION_LENGTH_TEST + 5
    num_groups = 6
    frames = []

    for g in range(num_groups):
        # Column construction order preserved so the numpy RNG sequence is
        # identical across refactors.
        columns = {
            "time_idx": np.arange(series_len, dtype=np.int64),
            "group_id_str": np.repeat(f"g{g}", series_len),
            "target": np.random.rand(series_len).astype(np.float32) + g,
            "enc_cont1": np.random.rand(series_len).astype(np.float32),
            "enc_cat1_codes": np.random.randint(0, 3, series_len).astype(np.float32),
            "dec_known_cont": np.sin(np.arange(series_len) / 5.0).astype(np.float32),
            "dec_known_cat_codes": np.random.randint(0, 2, series_len).astype(
                np.float32
            ),
            "static_cat_feat_codes": np.full(
                series_len, np.float32(g % 2), dtype=np.float32
            ),
            "static_cont_feat": np.full(
                series_len, np.float32(g * 10.0), dtype=np.float32
            ),
        }
        frames.append(pd.DataFrame(columns))

    df = pd.concat(frames, ignore_index=True)
    df["group_id"] = df["group_id_str"].astype("category")
    df.drop(columns=["group_id_str"], inplace=True)
    return df


@pytest.fixture
def timeseries_obj_for_test(sample_pandas_data_for_test):
    """Convert sample DataFrame into a TimeSeries object."""
    return TimeSeries(
        data=sample_pandas_data_for_test,
        time="time_idx",
        target="target",
        group=["group_id"],
        num=[
            "enc_cont1",
            "enc_cat1_codes",
            "dec_known_cont",
            "dec_known_cat_codes",
            "static_cat_feat_codes",
            "static_cont_feat",
        ],
        cat=[],
        known=["dec_known_cont", "dec_known_cat_codes", "time_idx"],
        static=["static_cat_feat_codes", "static_cont_feat"],
    )


@pytest.fixture
def data_module_for_test(timeseries_obj_for_test):
    """Initialize and sets up an EncoderDecoderTimeSeriesDataModule."""
    dm = EncoderDecoderTimeSeriesDataModule(
        time_series_dataset=timeseries_obj_for_test,
        batch_size=BATCH_SIZE_TEST,
        max_encoder_length=MAX_ENCODER_LENGTH_TEST,
        max_prediction_length=MAX_PREDICTION_LENGTH_TEST,
        train_val_test_split=(0.5, 0.25, 0.25),
    )
    dm.setup("fit")
    dm.setup("test")
    return dm


def test_model_with_datamodule_integration(
    tft_model_params_fixture_func, data_module_for_test
):
    """Integration test to ensure TFT works correctly with data module.

    Verifies:
    - Metadata inferred from data module matches expected input dimensions.
    - Model processes real dataloader batches correctly.
    - Output and target tensors from model and data module align in shape.
    - No NaNs in predictions.
    """
    dm = data_module_for_test
    model_metadata_from_dm = dm.metadata

    assert (
        model_metadata_from_dm["encoder_cont"] == 6
    ), f"Actual encoder_cont: {model_metadata_from_dm['encoder_cont']}"
    assert (
        model_metadata_from_dm["encoder_cat"] == 0
    ), f"Actual encoder_cat: {model_metadata_from_dm['encoder_cat']}"
    assert (
        model_metadata_from_dm["decoder_cont"] == 2
    ), f"Actual decoder_cont: {model_metadata_from_dm['decoder_cont']}"
    assert (
        model_metadata_from_dm["decoder_cat"] == 0
    ), f"Actual decoder_cat: {model_metadata_from_dm['decoder_cat']}"
    assert (
        model_metadata_from_dm["static_categorical_features"] == 0
    ), f"Actual static_cat: {model_metadata_from_dm['static_categorical_features']}"
    assert (
        model_metadata_from_dm["static_continuous_features"] == 2
    ), f"Actual static_cont: {model_metadata_from_dm['static_continuous_features']}"
    assert model_metadata_from_dm["target"] == 1

    init_args = tft_model_params_fixture_func.copy()
    init_args["output_size"] = model_metadata_from_dm["target"]
    model = TFT(**init_args, metadata=model_metadata_from_dm)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    xb, yb = next(iter(dm.train_dataloader()))
    actual_batch_size = xb["encoder_cont"].shape[0]
    xb = {name: tensor.to(device) for name, tensor in xb.items()}
    yb = yb.to(device)

    # Batch feature widths must agree with the metadata the model was built from.
    for key in ("encoder_cont", "encoder_cat", "decoder_cont", "decoder_cat"):
        assert xb[key].shape[2] == model_metadata_from_dm[key]
    assert (
        xb["static_categorical_features"].shape[2]
        == model_metadata_from_dm["static_categorical_features"]
    )
    assert (
        xb["static_continuous_features"].shape[2]
        == model_metadata_from_dm["static_continuous_features"]
    )

    predictions = model(xb)["prediction"]
    expected_shape = (
        actual_batch_size,
        MAX_PREDICTION_LENGTH_TEST,
        model_metadata_from_dm["target"],
    )
    assert predictions.shape == expected_shape
    assert not torch.isnan(predictions).any()
    assert yb.shape == expected_shape