diff --git a/README.md b/README.md index 103df83..9189f4d 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,8 @@ The implementation is tested to be compatatible with Tensorflow, Jax and Torch. It is the original implementation of the [paper](https://arxiv.org/abs/2405.07344) The KAN part implementation has been inspired from [efficient_kan](https://github.com/Blealtan/efficient-kan), and is available [here](https://github.com/remigenet/keras_efficient_kan) and works similarly to it, thus not exactly like the [original implementation](https://github.com/KindXiaoming/pykan). +In case of performance consideration, the best setup tested used [jax docker image](https://hub.docker.com/r/bitnami/jax/) followed by installing jax using ```pip install "jax[cuda12]"```, this is what is used in the example section where you can compare the TKAN vs LSTM vs GRU time and performance. +I also discourage using as is the example for torch, it seems that currently when running test using torch backend with keras is much slower than torch directly, even for GRU or LSTM. ![TKAN representation](image/TKAN.drawio.png) diff --git a/examples/simple_example_tkan.ipynb b/examples/simple_example_tkan.ipynb index b75576d..7ccb161 100644 --- a/examples/simple_example_tkan.ipynb +++ b/examples/simple_example_tkan.ipynb @@ -5,113 +5,88 @@ "id": "bb186818-bd1d-46ed-a018-27efa013b206", "metadata": {}, "source": [ - "# " + "# TKAN example and comparison with benchmarks\n", + "\n", + "All test have been run on a RTX 4070 with an Core™ i7-6700K on vast.ai using this [jax docker image](https://hub.docker.com/r/bitnami/jax/)\n", + "\n", + "tkan version: 0.4.1" ] }, { "cell_type": "code", "execution_count": 1, "id": "bc3f1ac2-1785-4e08-89a5-0bc5e31ce2b7", - "metadata": {}, + "metadata": { + "scrolled": true + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (2.2.2)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (1.26.4)\n", - "Requirement already satisfied: matplotlib in /usr/local/lib/python3.11/dist-packages (3.9.1)\n", - "Requirement already satisfied: pyarrow in /usr/local/lib/python3.11/dist-packages (17.0.0)\n", - "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.11/dist-packages (1.5.1)\n", - "Requirement already satisfied: tkan in /usr/local/lib/python3.11/dist-packages (0.4.0)\n", - "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas) (2.9.0.post0)\n", - "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas) (2024.1)\n", - "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas) (2024.1)\n", - "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib) (1.2.1)\n", - "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib) (0.12.1)\n", - "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib) (4.53.1)\n", - "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib) (1.4.5)\n", - "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib) (24.1)\n", - "Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib) (10.4.0)\n", - "Requirement already satisfied: pyparsing>=2.3.1 in /usr/lib/python3/dist-packages (from matplotlib) (2.4.7)\n", - "Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (1.14.0)\n", - "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (1.4.2)\n", - "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (3.5.0)\n", - "Requirement already satisfied: keras<4.0,>=3.0.0 in /usr/local/lib/python3.11/dist-packages (from tkan) (3.4.1)\n", - "Requirement already satisfied: keras_efficient_kan<0.2.0,>=0.1.4 in /usr/local/lib/python3.11/dist-packages (from tkan) (0.1.4)\n", - "Requirement already satisfied: absl-py in /usr/local/lib/python3.11/dist-packages (from keras<4.0,>=3.0.0->tkan) (2.1.0)\n", - "Requirement already satisfied: rich in /usr/local/lib/python3.11/dist-packages (from keras<4.0,>=3.0.0->tkan) (13.7.1)\n", - "Requirement already satisfied: namex in /usr/local/lib/python3.11/dist-packages (from keras<4.0,>=3.0.0->tkan) (0.0.8)\n", - "Requirement already satisfied: h5py in /usr/local/lib/python3.11/dist-packages (from keras<4.0,>=3.0.0->tkan) (3.11.0)\n", - "Requirement already satisfied: optree in /usr/local/lib/python3.11/dist-packages (from keras<4.0,>=3.0.0->tkan) (0.12.1)\n", - "Requirement already satisfied: ml-dtypes in /usr/local/lib/python3.11/dist-packages (from keras<4.0,>=3.0.0->tkan) (0.4.0)\n", - "Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n", - "Requirement already satisfied: typing-extensions>=4.5.0 in /usr/local/lib/python3.11/dist-packages (from optree->keras<4.0,>=3.0.0->tkan) (4.12.2)\n", - "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.11/dist-packages (from rich->keras<4.0,>=3.0.0->tkan) (3.0.0)\n", - "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.11/dist-packages (from rich->keras<4.0,>=3.0.0->tkan) (2.18.0)\n", - "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.11/dist-packages (from markdown-it-py>=2.2.0->rich->keras<4.0,>=3.0.0->tkan) (0.1.2)\n", + "\u001b[33mDEPRECATION: Loading egg at /opt/bitnami/python/lib/python3.11/site-packages/pip-23.3.2-py3.11.egg is deprecated. pip 24.3 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: pandas in /opt/bitnami/python/lib/python3.11/site-packages (2.2.2)\n", + "Requirement already satisfied: numpy in /opt/bitnami/python/lib/python3.11/site-packages (1.26.4)\n", + "Requirement already satisfied: matplotlib in /opt/bitnami/python/lib/python3.11/site-packages (3.9.1)\n", + "Requirement already satisfied: pyarrow in /opt/bitnami/python/lib/python3.11/site-packages (17.0.0)\n", + "Requirement already satisfied: scikit-learn in /opt/bitnami/python/lib/python3.11/site-packages (1.5.1)\n", + "Requirement already satisfied: tkan in /opt/bitnami/python/lib/python3.11/site-packages (0.4.1)\n", + "Requirement already satisfied: jax[cuda12] in /opt/bitnami/python/lib/python3.11/site-packages (0.4.30.dev20240618+f4158ac)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /opt/bitnami/python/lib/python3.11/site-packages (from pandas) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /opt/bitnami/python/lib/python3.11/site-packages (from pandas) (2024.1)\n", + "Requirement already satisfied: tzdata>=2022.7 in /opt/bitnami/python/lib/python3.11/site-packages (from pandas) (2024.1)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /opt/bitnami/python/lib/python3.11/site-packages (from matplotlib) (1.2.1)\n", + "Requirement already satisfied: cycler>=0.10 in /opt/bitnami/python/lib/python3.11/site-packages (from matplotlib) (0.12.1)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /opt/bitnami/python/lib/python3.11/site-packages (from matplotlib) (4.53.1)\n", + "Requirement already satisfied: kiwisolver>=1.3.1 in /opt/bitnami/python/lib/python3.11/site-packages (from matplotlib) (1.4.5)\n", + "Requirement already satisfied: packaging>=20.0 in /opt/bitnami/python/lib/python3.11/site-packages (from matplotlib) (24.1)\n", + "Requirement already satisfied: pillow>=8 in /opt/bitnami/python/lib/python3.11/site-packages (from matplotlib) (10.4.0)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in /opt/bitnami/python/lib/python3.11/site-packages (from matplotlib) (3.1.2)\n", + "Requirement already satisfied: scipy>=1.6.0 in /opt/bitnami/python/lib/python3.11/site-packages (from scikit-learn) (1.14.0)\n", + "Requirement already satisfied: joblib>=1.2.0 in /opt/bitnami/python/lib/python3.11/site-packages (from scikit-learn) (1.4.2)\n", + "Requirement already satisfied: threadpoolctl>=3.1.0 in /opt/bitnami/python/lib/python3.11/site-packages (from scikit-learn) (3.5.0)\n", + "Requirement already satisfied: keras<4.0,>=3.0.0 in /opt/bitnami/python/lib/python3.11/site-packages (from tkan) (3.4.1)\n", + "Requirement already satisfied: keras_efficient_kan<0.2.0,>=0.1.4 in /opt/bitnami/python/lib/python3.11/site-packages (from tkan) (0.1.4)\n", + "Requirement already satisfied: jaxlib<=0.4.30,>=0.4.27 in /opt/bitnami/python/lib/python3.11/site-packages (from jax[cuda12]) (0.4.30)\n", + "Requirement already satisfied: ml-dtypes>=0.2.0 in /opt/bitnami/python/lib/python3.11/site-packages (from jax[cuda12]) (0.4.0)\n", + "Requirement already satisfied: opt-einsum in /opt/bitnami/python/lib/python3.11/site-packages (from jax[cuda12]) (3.3.0)\n", + "Requirement already satisfied: jax-cuda12-plugin<=0.4.30,>=0.4.30 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (0.4.30)\n", + "Requirement already satisfied: jax-cuda12-pjrt==0.4.30 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin<=0.4.30,>=0.4.30->jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (0.4.30)\n", + "Requirement already satisfied: nvidia-cublas-cu12>=12.1.3.1 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.3.2)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12>=12.1.105 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.82)\n", + "Requirement already satisfied: nvidia-cuda-nvcc-cu12>=12.1.105 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.82)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12>=12.1.105 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.82)\n", + "Requirement already satisfied: nvidia-cudnn-cu12<10.0,>=9.0 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (9.2.1.18)\n", + "Requirement already satisfied: nvidia-cufft-cu12>=11.0.2.54 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (11.2.3.61)\n", + "Requirement already satisfied: nvidia-cusolver-cu12>=11.4.5.107 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (11.6.3.83)\n", + "Requirement already satisfied: nvidia-cusparse-cu12>=12.1.0.106 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.1.3)\n", + "Requirement already satisfied: nvidia-nccl-cu12>=2.18.1 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (2.22.3)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12>=12.1.105 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.82)\n", + "Requirement already satisfied: absl-py in /opt/bitnami/python/lib/python3.11/site-packages (from keras<4.0,>=3.0.0->tkan) (2.1.0)\n", + "Requirement already satisfied: rich in /opt/bitnami/python/lib/python3.11/site-packages (from keras<4.0,>=3.0.0->tkan) (13.7.1)\n", + "Requirement already satisfied: namex in /opt/bitnami/python/lib/python3.11/site-packages (from keras<4.0,>=3.0.0->tkan) (0.0.8)\n", + "Requirement already satisfied: h5py in /opt/bitnami/python/lib/python3.11/site-packages (from keras<4.0,>=3.0.0->tkan) (3.11.0)\n", + "Requirement already satisfied: optree in /opt/bitnami/python/lib/python3.11/site-packages (from keras<4.0,>=3.0.0->tkan) (0.12.1)\n", + "Requirement already satisfied: six>=1.5 in /opt/bitnami/python/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n", + "Requirement already satisfied: typing-extensions>=4.5.0 in /opt/bitnami/python/lib/python3.11/site-packages (from optree->keras<4.0,>=3.0.0->tkan) (4.12.2)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /opt/bitnami/python/lib/python3.11/site-packages (from rich->keras<4.0,>=3.0.0->tkan) (3.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /opt/bitnami/python/lib/python3.11/site-packages (from rich->keras<4.0,>=3.0.0->tkan) (2.18.0)\n", + "Requirement already satisfied: mdurl~=0.1 in /opt/bitnami/python/lib/python3.11/site-packages (from markdown-it-py>=2.2.0->rich->keras<4.0,>=3.0.0->tkan) (0.1.2)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ - "!pip install pandas numpy matplotlib pyarrow scikit-learn tkan " + "!pip install pandas numpy matplotlib pyarrow scikit-learn tkan \"jax[cuda12]\"" ] }, { "cell_type": "code", "execution_count": 2, - "id": "4bf2cb3c-56e6-49a6-91b8-47e616b850dc", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: jax[cuda12] in /usr/local/lib/python3.11/dist-packages (0.4.30)\n", - "Requirement already satisfied: jaxlib<=0.4.30,>=0.4.27 in /usr/local/lib/python3.11/dist-packages (from jax[cuda12]) (0.4.30)\n", - "Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from jax[cuda12]) (0.4.0)\n", - "Requirement already satisfied: numpy>=1.22 in /usr/local/lib/python3.11/dist-packages (from jax[cuda12]) (1.26.4)\n", - "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.11/dist-packages (from jax[cuda12]) (3.3.0)\n", - "Requirement already satisfied: scipy>=1.9 in /usr/local/lib/python3.11/dist-packages (from jax[cuda12]) (1.14.0)\n", - "Requirement already satisfied: jax-cuda12-plugin<=0.4.30,>=0.4.30 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (0.4.30)\n", - "Requirement already satisfied: jax-cuda12-pjrt==0.4.30 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin<=0.4.30,>=0.4.30->jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (0.4.30)\n", - "Requirement already satisfied: nvidia-cublas-cu12>=12.1.3.1 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.3.2)\n", - "Requirement already satisfied: nvidia-cuda-cupti-cu12>=12.1.105 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.82)\n", - "Requirement already satisfied: nvidia-cuda-nvcc-cu12>=12.1.105 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.82)\n", - "Requirement already satisfied: nvidia-cuda-runtime-cu12>=12.1.105 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.82)\n", - "Requirement already satisfied: nvidia-cudnn-cu12<10.0,>=9.0 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (9.2.1.18)\n", - "Requirement already satisfied: nvidia-cufft-cu12>=11.0.2.54 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (11.2.3.61)\n", - "Requirement already satisfied: nvidia-cusolver-cu12>=11.4.5.107 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (11.6.3.83)\n", - "Requirement already satisfied: nvidia-cusparse-cu12>=12.1.0.106 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.1.3)\n", - "Requirement already satisfied: nvidia-nccl-cu12>=2.18.1 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (2.22.3)\n", - "Requirement already satisfied: nvidia-nvjitlink-cu12>=12.1.105 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.82)\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - } - ], - "source": [ - "!pip install -U \"jax[cuda12]\"" - ] - }, - { - "cell_type": "code", - "execution_count": 3, "id": "34213122-fabb-4d55-a918-a337be21b974", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-07-23 18:15:27.842138: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2024-07-23 18:15:27.861965: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2024-07-23 18:15:27.867851: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" - ] - } - ], + "outputs": [], "source": [ "import os\n", "BACKEND = 'jax' # You can use any backend here \n", @@ -166,7 +141,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "735c5c88-924a-426e-b10f-a05e7b0556f2", "metadata": {}, "outputs": [ @@ -571,7 +546,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "34c9f03b-11bf-4c88-abcb-5fc8e43c6008", "metadata": {}, "outputs": [], @@ -694,17 +669,10 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "27b4f64a-08da-4cfc-97c9-1d5145887c23", "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-07-23 18:15:35.115393: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n" - ] - }, { "data": { "text/html": [ @@ -724,11 +692,11 @@ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
        "┃ Layer (type)                     Output Shape                  Param # ┃\n",
        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ tkan (TKAN)                     │ (None, 75, 100)        │        41,750 │\n",
+       "│ tkan (TKAN)                     │ (None, 45, 100)        │        41,316 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
        "│ tkan_1 (TKAN)                   │ (None, 100)            │        67,670 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense (Dense)                   │ (None, 15)             │         1,515 │\n",
+       "│ dense (Dense)                   │ (None, 1)              │           101 │\n",
        "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
        "
\n" ], @@ -736,11 +704,11 @@ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ tkan (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m75\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,750\u001b[0m │\n", + "│ tkan (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,316\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ tkan_1 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m15\u001b[0m) │ \u001b[38;5;34m1,515\u001b[0m │\n", + "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m101\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, @@ -750,11 +718,11 @@ { "data": { "text/html": [ - "
 Total params: 110,935 (433.34 KB)\n",
+       "
 Total params: 109,087 (426.12 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m110,935\u001b[0m (433.34 KB)\n" + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m109,087\u001b[0m (426.12 KB)\n" ] }, "metadata": {}, @@ -763,11 +731,11 @@ { "data": { "text/html": [ - "
 Trainable params: 110,915 (433.26 KB)\n",
+       "
 Trainable params: 109,067 (426.04 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m110,915\u001b[0m (433.26 KB)\n" + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m109,067\u001b[0m (426.04 KB)\n" ] }, "metadata": {}, @@ -790,11 +758,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "124.26042032241821 0.09158093535460268\n", - "162.43516397476196 0.10293650922567563\n", - "165.25300359725952 0.09519336968632934\n", - "146.79839873313904 0.08672428978473426\n", - "145.06835103034973 0.09001762052527772\n" + "87.7924542427063 0.2883694212812108\n", + "58.81150460243225 0.2996397661396135\n", + "93.38526654243469 0.3341084410571994\n", + "69.46869540214539 0.31051983221133617\n", + "68.03620338439941 0.2921106524335578\n", + "69.86754989624023 0.31632189994316484\n", + "83.32379245758057 0.3110331089568017\n", + "87.04948496818542 0.32794331666897736\n", + "84.89363670349121 0.31305216100227273\n", + "65.98418259620667 0.30768323851050017\n" ] }, { @@ -816,11 +789,11 @@ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
        "┃ Layer (type)                     Output Shape                  Param # ┃\n",
        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ gru (GRU)                       │ (None, 75, 100)        │        36,300 │\n",
+       "│ gru (GRU)                       │ (None, 45, 100)        │        36,300 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
        "│ gru_1 (GRU)                     │ (None, 100)            │        60,600 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_5 (Dense)                 │ (None, 15)             │         1,515 │\n",
+       "│ dense_10 (Dense)                │ (None, 1)              │           101 │\n",
        "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
        "
\n" ], @@ -828,11 +801,11 @@ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ gru (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m75\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n", + "│ gru (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ gru_1 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_5 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m15\u001b[0m) │ \u001b[38;5;34m1,515\u001b[0m │\n", + "│ dense_10 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m101\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, @@ -842,11 +815,11 @@ { "data": { "text/html": [ - "
 Total params: 98,415 (384.43 KB)\n",
+       "
 Total params: 97,001 (378.91 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m98,415\u001b[0m (384.43 KB)\n" + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m97,001\u001b[0m (378.91 KB)\n" ] }, "metadata": {}, @@ -855,11 +828,11 @@ { "data": { "text/html": [ - "
 Trainable params: 98,415 (384.43 KB)\n",
+       "
 Trainable params: 97,001 (378.91 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m98,415\u001b[0m (384.43 KB)\n" + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m97,001\u001b[0m (378.91 KB)\n" ] }, "metadata": {}, @@ -882,11 +855,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "40.458553075790405 0.09533817072522596\n", - "41.53169059753418 0.06848331293892884\n", - "41.19035625457764 0.08615624519190934\n", - "47.83834481239319 0.05216959504533443\n", - "41.37253522872925 0.09000230869662042\n" + "25.800300121307373 0.39845659417931767\n", + "33.090811014175415 0.3765754299721896\n", + "29.157642126083374 0.3996904390926539\n", + "20.830080032348633 0.39383111789434266\n", + "24.594220638275146 0.39854787088819454\n", + "25.829734086990356 0.3980608345819734\n", + "20.8519549369812 0.39818872394114213\n", + "29.55962371826172 0.39876597794912183\n", + "28.695900917053223 0.4014753210019817\n", + "31.473294496536255 0.385033271890177\n" ] }, { @@ -908,11 +886,11 @@ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
        "┃ Layer (type)                     Output Shape                  Param # ┃\n",
        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ lstm (LSTM)                     │ (None, 75, 100)        │        48,000 │\n",
+       "│ lstm (LSTM)                     │ (None, 45, 100)        │        48,000 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
        "│ lstm_1 (LSTM)                   │ (None, 100)            │        80,400 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_10 (Dense)                │ (None, 15)             │         1,515 │\n",
+       "│ dense_20 (Dense)                │ (None, 1)              │           101 │\n",
        "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
        "
\n" ], @@ -920,11 +898,11 @@ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ lstm (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m75\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n", + "│ lstm (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ lstm_1 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_10 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m15\u001b[0m) │ \u001b[38;5;34m1,515\u001b[0m │\n", + "│ dense_20 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m101\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, @@ -934,11 +912,11 @@ { "data": { "text/html": [ - "
 Total params: 129,915 (507.48 KB)\n",
+       "
 Total params: 128,501 (501.96 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m129,915\u001b[0m (507.48 KB)\n" + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m128,501\u001b[0m (501.96 KB)\n" ] }, "metadata": {}, @@ -947,11 +925,11 @@ { "data": { "text/html": [ - "
 Trainable params: 129,915 (507.48 KB)\n",
+       "
 Trainable params: 128,501 (501.96 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m129,915\u001b[0m (507.48 KB)\n" + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m128,501\u001b[0m (501.96 KB)\n" ] }, "metadata": {}, @@ -974,583 +952,220 @@ "name": "stdout", "output_type": "stream", "text": [ - "34.312363147735596 -0.19760337708902084\n", - "31.26186227798462 -0.17731950497927035\n", - "34.37222385406494 -0.37332071192906224\n", - "35.09816241264343 -0.08747675367134147\n", - "33.26098704338074 -0.3229705947886879\n", - "R2 scores\n", - "Means:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", - " return _methods._mean(a, axis=axis, dtype=dtype,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" + "17.663942098617554 0.3868841900601451\n", + "18.93087387084961 0.36429593525523807\n", + "17.406195163726807 0.3900695412878702\n", + "20.44808530807495 0.3786097961170115\n", + "16.755755186080933 0.392573888221295\n", + "20.213451862335205 0.3840845450430137\n", + "22.768882513046265 0.36630601832303555\n", + "18.173134088516235 0.3832138605836025\n", + "17.591335773468018 0.3918248620808171\n", + "19.007630586624146 0.37993474576853103\n" ] }, { "data": { "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
150.0932910.07843-0.231738
12NaNNaNNaN
9NaNNaNNaN
6NaNNaNNaN
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" + "
Model: \"TKAN\"\n",
+       "
\n" ], "text/plain": [ - " TKAN GRU LSTM\n", - "15 0.093291 0.07843 -0.231738\n", - "12 NaN NaN NaN\n", - "9 NaN NaN NaN\n", - "6 NaN NaN NaN\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" + "\u001b[1mModel: \"TKAN\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", - " return _methods._mean(a, axis=axis, dtype=dtype,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] - }, { "data": { "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
150.0677050.0681940.078686
12NaNNaNNaN
9NaNNaNNaN
6NaNNaNNaN
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
+       "│ tkan_20 (TKAN)                  │ (None, 45, 100)        │        41,316 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ tkan_21 (TKAN)                  │ (None, 100)            │        67,670 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ dense_30 (Dense)                │ (None, 3)              │           303 │\n",
+       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
+       "
\n" ], "text/plain": [ - " TKAN GRU LSTM\n", - "15 0.067705 0.068194 0.078686\n", - "12 NaN NaN NaN\n", - "9 NaN NaN NaN\n", - "6 NaN NaN NaN\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", + "│ tkan_20 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,316\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ tkan_21 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ dense_30 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m303\u001b[0m │\n", + "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Std:\n" - ] + "data": { + "text/html": [ + "
 Total params: 109,289 (426.91 KB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m109,289\u001b[0m (426.91 KB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n", - " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n", - " arrmean = um.true_divide(arrmean, div, out=arrmean,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] + "data": { + "text/html": [ + "
 Trainable params: 109,269 (426.83 KB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m109,269\u001b[0m (426.83 KB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { "data": { "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
150.0055390.0159250.103254
12NaNNaNNaN
9NaNNaNNaN
6NaNNaNNaN
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" + "
 Non-trainable params: 20 (80.00 B)\n",
+       "
\n" ], "text/plain": [ - " TKAN GRU LSTM\n", - "15 0.005539 0.015925 0.103254\n", - "12 NaN NaN NaN\n", - "9 NaN NaN NaN\n", - "6 NaN NaN NaN\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m20\u001b[0m (80.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n", - " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n", - " arrmean = um.true_divide(arrmean, div, out=arrmean,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" + "96.62368559837341 0.2039716435474497\n", + "74.19592547416687 0.1932276162314687\n", + "79.30283999443054 0.186078923268888\n", + "67.78045701980591 0.1888747778655048\n", + "90.85375475883484 0.20038166436812602\n", + "83.50374150276184 0.198828800729096\n", + "79.80795526504517 0.1848691783440326\n", + "59.527796268463135 0.18538958325425756\n", + "73.05465388298035 0.17813436933288326\n", + "77.17678189277649 0.185736715040992\n" ] }, { "data": { "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
150.0002150.0005780.003284
12NaNNaNNaN
9NaNNaNNaN
6NaNNaNNaN
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" + "
Model: \"GRU\"\n",
+       "
\n" ], "text/plain": [ - " TKAN GRU LSTM\n", - "15 0.000215 0.000578 0.003284\n", - "12 NaN NaN NaN\n", - "9 NaN NaN NaN\n", - "6 NaN NaN NaN\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" + "\u001b[1mModel: \"GRU\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Training Times\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", - " return _methods._mean(a, axis=axis, dtype=dtype,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
+       "│ gru_20 (GRU)                    │ (None, 45, 100)        │        36,300 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ gru_21 (GRU)                    │ (None, 100)            │        60,600 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ dense_40 (Dense)                │ (None, 3)              │           303 │\n",
+       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", + "│ gru_20 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ gru_21 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ dense_40 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m303\u001b[0m │\n", + "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { "data": { "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
15148.76306842.47829633.66112
12NaNNaNNaN
9NaNNaNNaN
6NaNNaNNaN
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" + "
 Total params: 97,203 (379.70 KB)\n",
+       "
\n" ], "text/plain": [ - " TKAN GRU LSTM\n", - "15 148.763068 42.478296 33.66112\n", - "12 NaN NaN NaN\n", - "9 NaN NaN NaN\n", - "6 NaN NaN NaN\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m97,203\u001b[0m (379.70 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n", - " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n", - " arrmean = um.true_divide(arrmean, div, out=arrmean,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] + "data": { + "text/html": [ + "
 Trainable params: 97,203 (379.70 KB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m97,203\u001b[0m (379.70 KB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { "data": { "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
1514.6747052.7050711.335022
12NaNNaNNaN
9NaNNaNNaN
6NaNNaNNaN
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" + "
 Non-trainable params: 0 (0.00 B)\n",
+       "
\n" ], "text/plain": [ - " TKAN GRU LSTM\n", - "15 14.674705 2.705071 1.335022\n", - "12 NaN NaN NaN\n", - "9 NaN NaN NaN\n", - "6 NaN NaN NaN\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "32.23410701751709 0.24445524304420677\n", + "34.76943516731262 0.2201955661466759\n", + "25.166553735733032 0.23882125234574236\n", + "28.084814310073853 0.2344194906558783\n", + "27.79819416999817 0.235279604982632\n", + "31.362367868423462 0.23122380522590277\n", + "24.73177719116211 0.2426694875126042\n", + "26.578373432159424 0.23963082071966313\n", + "27.062561988830566 0.24032075216240234\n", + "28.228745222091675 0.24178121308010916\n" + ] + }, { "data": { "text/html": [ - "
Model: \"TKAN\"\n",
+       "
Model: \"LSTM\"\n",
        "
\n" ], "text/plain": [ - "\u001b[1mModel: \"TKAN\"\u001b[0m\n" + "\u001b[1mModel: \"LSTM\"\u001b[0m\n" ] }, "metadata": {}, @@ -1562,11 +1177,11 @@ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
        "┃ Layer (type)                     Output Shape                  Param # ┃\n",
        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ tkan_10 (TKAN)                  │ (None, 60, 100)        │        41,750 │\n",
+       "│ lstm_20 (LSTM)                  │ (None, 45, 100)        │        48,000 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ tkan_11 (TKAN)                  │ (None, 100)            │        67,670 │\n",
+       "│ lstm_21 (LSTM)                  │ (None, 100)            │        80,400 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_15 (Dense)                │ (None, 12)             │         1,212 │\n",
+       "│ dense_50 (Dense)                │ (None, 3)              │           303 │\n",
        "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
        "
\n" ], @@ -1574,11 +1189,11 @@ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ tkan_10 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m60\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,750\u001b[0m │\n", + "│ lstm_20 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ tkan_11 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n", + "│ lstm_21 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_15 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m1,212\u001b[0m │\n", + "│ dense_50 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m303\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, @@ -1588,11 +1203,11 @@ { "data": { "text/html": [ - "
 Total params: 110,632 (432.16 KB)\n",
+       "
 Total params: 128,703 (502.75 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m110,632\u001b[0m (432.16 KB)\n" + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m128,703\u001b[0m (502.75 KB)\n" ] }, "metadata": {}, @@ -1601,11 +1216,11 @@ { "data": { "text/html": [ - "
 Trainable params: 110,612 (432.08 KB)\n",
+       "
 Trainable params: 128,703 (502.75 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m110,612\u001b[0m (432.08 KB)\n" + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m128,703\u001b[0m (502.75 KB)\n" ] }, "metadata": {}, @@ -1614,11 +1229,11 @@ { "data": { "text/html": [ - "
 Non-trainable params: 20 (80.00 B)\n",
+       "
 Non-trainable params: 0 (0.00 B)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m20\u001b[0m (80.00 B)\n" + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, @@ -1628,21 +1243,26 @@ "name": "stdout", "output_type": "stream", "text": [ - "131.22357177734375 0.09087923067578507\n", - "141.15374612808228 0.10076559080540619\n", - "121.34675312042236 0.09357512681452108\n", - "129.00918769836426 0.10084381545148026\n", - "135.5328950881958 0.09396223859149976\n" + "20.623400688171387 0.14482767971630928\n", + "21.055927991867065 0.07395821838488396\n", + "24.374985218048096 0.022805938949401933\n", + "21.713353157043457 0.0749404329796987\n", + "21.184755563735962 0.1522887184364525\n", + "18.015154600143433 0.16326070883053923\n", + "18.17797017097473 0.1804372350265144\n", + "17.993833541870117 0.18266789728362973\n", + "18.728025197982788 0.14342017896777437\n", + "20.97398853302002 -0.002893617498768264\n" ] }, { "data": { "text/html": [ - "
Model: \"GRU\"\n",
+       "
Model: \"TKAN\"\n",
        "
\n" ], "text/plain": [ - "\u001b[1mModel: \"GRU\"\u001b[0m\n" + "\u001b[1mModel: \"TKAN\"\u001b[0m\n" ] }, "metadata": {}, @@ -1654,11 +1274,11 @@ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
        "┃ Layer (type)                     Output Shape                  Param # ┃\n",
        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ gru_10 (GRU)                    │ (None, 60, 100)        │        36,300 │\n",
+       "│ tkan_40 (TKAN)                  │ (None, 45, 100)        │        41,316 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ gru_11 (GRU)                    │ (None, 100)            │        60,600 │\n",
+       "│ tkan_41 (TKAN)                  │ (None, 100)            │        67,670 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_20 (Dense)                │ (None, 12)             │         1,212 │\n",
+       "│ dense_60 (Dense)                │ (None, 6)              │           606 │\n",
        "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
        "
\n" ], @@ -1666,11 +1286,11 @@ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ gru_10 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m60\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n", + "│ tkan_40 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,316\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ gru_11 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n", + "│ tkan_41 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_20 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m1,212\u001b[0m │\n", + "│ dense_60 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m6\u001b[0m) │ \u001b[38;5;34m606\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, @@ -1680,11 +1300,11 @@ { "data": { "text/html": [ - "
 Total params: 98,112 (383.25 KB)\n",
+       "
 Total params: 109,592 (428.09 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m98,112\u001b[0m (383.25 KB)\n" + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m109,592\u001b[0m (428.09 KB)\n" ] }, "metadata": {}, @@ -1693,11 +1313,11 @@ { "data": { "text/html": [ - "
 Trainable params: 98,112 (383.25 KB)\n",
+       "
 Trainable params: 109,572 (428.02 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m98,112\u001b[0m (383.25 KB)\n" + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m109,572\u001b[0m (428.02 KB)\n" ] }, "metadata": {}, @@ -1706,11 +1326,11 @@ { "data": { "text/html": [ - "
 Non-trainable params: 0 (0.00 B)\n",
+       "
 Non-trainable params: 20 (80.00 B)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m20\u001b[0m (80.00 B)\n" ] }, "metadata": {}, @@ -1720,21 +1340,26 @@ "name": "stdout", "output_type": "stream", "text": [ - "38.70786952972412 0.08947185062429679\n", - "35.21390247344971 0.10739726688893424\n", - "48.478371143341064 -0.05339822582392689\n", - "40.725359201431274 0.06710397964579173\n", - "43.54443335533142 0.04774073080153538\n" + "65.0717043876648 0.12467616278770983\n", + "90.68930077552795 0.13929252527937125\n", + "76.80289697647095 0.12103943728777532\n", + "84.18350672721863 0.13730972361175564\n", + "59.469563484191895 0.10800808332018592\n", + "80.92852973937988 0.13700481460829453\n", + "85.58175778388977 0.1324956475675861\n", + "78.75041937828064 0.13073665118121205\n", + "80.78611373901367 0.12963646594644282\n", + "73.53622269630432 0.12723135400013533\n" ] }, { "data": { "text/html": [ - "
Model: \"LSTM\"\n",
+       "
Model: \"GRU\"\n",
        "
\n" ], "text/plain": [ - "\u001b[1mModel: \"LSTM\"\u001b[0m\n" + "\u001b[1mModel: \"GRU\"\u001b[0m\n" ] }, "metadata": {}, @@ -1746,11 +1371,11 @@ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
        "┃ Layer (type)                     Output Shape                  Param # ┃\n",
        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ lstm_10 (LSTM)                  │ (None, 60, 100)        │        48,000 │\n",
+       "│ gru_40 (GRU)                    │ (None, 45, 100)        │        36,300 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ lstm_11 (LSTM)                  │ (None, 100)            │        80,400 │\n",
+       "│ gru_41 (GRU)                    │ (None, 100)            │        60,600 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_25 (Dense)                │ (None, 12)             │         1,212 │\n",
+       "│ dense_70 (Dense)                │ (None, 6)              │           606 │\n",
        "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
        "
\n" ], @@ -1758,11 +1383,11 @@ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ lstm_10 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m60\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n", + "│ gru_40 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ lstm_11 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n", + "│ gru_41 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_25 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m1,212\u001b[0m │\n", + "│ dense_70 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m6\u001b[0m) │ \u001b[38;5;34m606\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, @@ -1772,11 +1397,11 @@ { "data": { "text/html": [ - "
 Total params: 129,612 (506.30 KB)\n",
+       "
 Total params: 97,506 (380.88 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m129,612\u001b[0m (506.30 KB)\n" + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m97,506\u001b[0m (380.88 KB)\n" ] }, "metadata": {}, @@ -1785,11 +1410,11 @@ { "data": { "text/html": [ - "
 Trainable params: 129,612 (506.30 KB)\n",
+       "
 Trainable params: 97,506 (380.88 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m129,612\u001b[0m (506.30 KB)\n" + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m97,506\u001b[0m (380.88 KB)\n" ] }, "metadata": {}, @@ -1812,1595 +1437,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "28.626214027404785 -0.1006527298304624\n", - "35.30322313308716 -0.24341912522819487\n", - "27.737242698669434 -0.5102580838684908\n", - "32.53866934776306 -0.7130992597534002\n", - "34.27140927314758 -0.24619645358326023\n", - "R2 scores\n", - "Means:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", - " return _methods._mean(a, axis=axis, dtype=dtype,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
150.0932910.078430-0.231738
120.0960050.051663-0.362725
9NaNNaNNaN
6NaNNaNNaN
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" - ], - "text/plain": [ - " TKAN GRU LSTM\n", - "15 0.093291 0.078430 -0.231738\n", - "12 0.096005 0.051663 -0.362725\n", - "9 NaN NaN NaN\n", - "6 NaN NaN NaN\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", - " return _methods._mean(a, axis=axis, dtype=dtype,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
150.0677050.0681940.078686
120.0673210.0688310.082159
9NaNNaNNaN
6NaNNaNNaN
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" - ], - "text/plain": [ - " TKAN GRU LSTM\n", - "15 0.067705 0.068194 0.078686\n", - "12 0.067321 0.068831 0.082159\n", - "9 NaN NaN NaN\n", - "6 NaN NaN NaN\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Std:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n", - " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n", - " arrmean = um.true_divide(arrmean, div, out=arrmean,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
150.0055390.0159250.103254
120.0040600.0562630.219555
9NaNNaNNaN
6NaNNaNNaN
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" - ], - "text/plain": [ - " TKAN GRU LSTM\n", - "15 0.005539 0.015925 0.103254\n", - "12 0.004060 0.056263 0.219555\n", - "9 NaN NaN NaN\n", - "6 NaN NaN NaN\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n", - " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n", - " arrmean = um.true_divide(arrmean, div, out=arrmean,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
150.0002150.0005780.003284
120.0001480.0019780.006539
9NaNNaNNaN
6NaNNaNNaN
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" - ], - "text/plain": [ - " TKAN GRU LSTM\n", - "15 0.000215 0.000578 0.003284\n", - "12 0.000148 0.001978 0.006539\n", - "9 NaN NaN NaN\n", - "6 NaN NaN NaN\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Training Times\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", - " return _methods._mean(a, axis=axis, dtype=dtype,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
15148.76306842.47829633.661120
12131.65323141.33398731.695352
9NaNNaNNaN
6NaNNaNNaN
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" - ], - "text/plain": [ - " TKAN GRU LSTM\n", - "15 148.763068 42.478296 33.661120\n", - "12 131.653231 41.333987 31.695352\n", - "9 NaN NaN NaN\n", - "6 NaN NaN NaN\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n", - " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n", - " arrmean = um.true_divide(arrmean, div, out=arrmean,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
1514.6747052.7050711.335022
126.6137834.4866613.014970
9NaNNaNNaN
6NaNNaNNaN
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" - ], - "text/plain": [ - " TKAN GRU LSTM\n", - "15 14.674705 2.705071 1.335022\n", - "12 6.613783 4.486661 3.014970\n", - "9 NaN NaN NaN\n", - "6 NaN NaN NaN\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
Model: \"TKAN\"\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1mModel: \"TKAN\"\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
-       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
-       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ tkan_20 (TKAN)                  │ (None, 45, 100)        │        41,750 │\n",
-       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ tkan_21 (TKAN)                  │ (None, 100)            │        67,670 │\n",
-       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_30 (Dense)                │ (None, 9)              │           909 │\n",
-       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
-       "
\n" - ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ tkan_20 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,750\u001b[0m │\n", - "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ tkan_21 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n", - "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_30 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m909\u001b[0m │\n", - "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Total params: 110,329 (430.97 KB)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m110,329\u001b[0m (430.97 KB)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Trainable params: 110,309 (430.89 KB)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m110,309\u001b[0m (430.89 KB)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Non-trainable params: 20 (80.00 B)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m20\u001b[0m (80.00 B)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "104.23471403121948 0.113891478849783\n", - "89.38614058494568 0.10592790984288752\n", - "106.43402981758118 0.10938055442300718\n", - "135.6341907978058 0.11293714586336041\n", - "125.27908515930176 0.11555530130229706\n" - ] - }, - { - "data": { - "text/html": [ - "
Model: \"GRU\"\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1mModel: \"GRU\"\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
-       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
-       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ gru_20 (GRU)                    │ (None, 45, 100)        │        36,300 │\n",
-       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ gru_21 (GRU)                    │ (None, 100)            │        60,600 │\n",
-       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_35 (Dense)                │ (None, 9)              │           909 │\n",
-       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
-       "
\n" - ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ gru_20 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n", - "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ gru_21 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n", - "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_35 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m909\u001b[0m │\n", - "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Total params: 97,809 (382.07 KB)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m97,809\u001b[0m (382.07 KB)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Trainable params: 97,809 (382.07 KB)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m97,809\u001b[0m (382.07 KB)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Non-trainable params: 0 (0.00 B)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "29.72066307067871 0.1249179784635558\n", - "33.5255126953125 0.07318165013466399\n", - "38.62114071846008 -0.023410033110204745\n", - "35.293763875961304 0.0894183448972399\n", - "43.56008291244507 -0.023466014859413426\n" - ] - }, - { - "data": { - "text/html": [ - "
Model: \"LSTM\"\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1mModel: \"LSTM\"\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
-       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
-       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ lstm_20 (LSTM)                  │ (None, 45, 100)        │        48,000 │\n",
-       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ lstm_21 (LSTM)                  │ (None, 100)            │        80,400 │\n",
-       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_40 (Dense)                │ (None, 9)              │           909 │\n",
-       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
-       "
\n" - ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ lstm_20 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n", - "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ lstm_21 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n", - "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_40 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m909\u001b[0m │\n", - "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Total params: 129,309 (505.11 KB)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m129,309\u001b[0m (505.11 KB)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Trainable params: 129,309 (505.11 KB)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m129,309\u001b[0m (505.11 KB)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Non-trainable params: 0 (0.00 B)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "25.49171733856201 -0.3036703338256579\n", - "27.654677867889404 -0.14078671841214707\n", - "22.398143529891968 -0.0803994153254022\n", - "25.16021418571472 -0.06587595121316724\n", - "23.790441751480103 -0.025310879036213083\n", - "R2 scores\n", - "Means:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", - " return _methods._mean(a, axis=axis, dtype=dtype,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
150.0932910.078430-0.231738
120.0960050.051663-0.362725
90.1115380.048128-0.123209
6NaNNaNNaN
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" - ], - "text/plain": [ - " TKAN GRU LSTM\n", - "15 0.093291 0.078430 -0.231738\n", - "12 0.096005 0.051663 -0.362725\n", - "9 0.111538 0.048128 -0.123209\n", - "6 NaN NaN NaN\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", - " return _methods._mean(a, axis=axis, dtype=dtype,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
150.0677050.0681940.078686
120.0673210.0688310.082159
90.0667170.0688820.074797
6NaNNaNNaN
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" - ], - "text/plain": [ - " TKAN GRU LSTM\n", - "15 0.067705 0.068194 0.078686\n", - "12 0.067321 0.068831 0.082159\n", - "9 0.066717 0.068882 0.074797\n", - "6 NaN NaN NaN\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Std:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n", - " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n", - " arrmean = um.true_divide(arrmean, div, out=arrmean,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
150.0055390.0159250.103254
120.0040600.0562630.219555
90.0034570.0607830.097549
6NaNNaNNaN
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" - ], - "text/plain": [ - " TKAN GRU LSTM\n", - "15 0.005539 0.015925 0.103254\n", - "12 0.004060 0.056263 0.219555\n", - "9 0.003457 0.060783 0.097549\n", - "6 NaN NaN NaN\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n", - " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n", - " arrmean = um.true_divide(arrmean, div, out=arrmean,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
150.0002150.0005780.003284
120.0001480.0019780.006539
90.0001310.0021420.003162
6NaNNaNNaN
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" - ], - "text/plain": [ - " TKAN GRU LSTM\n", - "15 0.000215 0.000578 0.003284\n", - "12 0.000148 0.001978 0.006539\n", - "9 0.000131 0.002142 0.003162\n", - "6 NaN NaN NaN\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Training Times\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", - " return _methods._mean(a, axis=axis, dtype=dtype,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
15148.76306842.47829633.661120
12131.65323141.33398731.695352
9112.19363236.14423324.899039
6NaNNaNNaN
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" - ], - "text/plain": [ - " TKAN GRU LSTM\n", - "15 148.763068 42.478296 33.661120\n", - "12 131.653231 41.333987 31.695352\n", - "9 112.193632 36.144233 24.899039\n", - "6 NaN NaN NaN\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n", - " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n", - " arrmean = um.true_divide(arrmean, div, out=arrmean,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
1514.6747052.7050711.335022
126.6137834.4866613.014970
916.3547424.6898431.760482
6NaNNaNNaN
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" - ], - "text/plain": [ - " TKAN GRU LSTM\n", - "15 14.674705 2.705071 1.335022\n", - "12 6.613783 4.486661 3.014970\n", - "9 16.354742 4.689843 1.760482\n", - "6 NaN NaN NaN\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
Model: \"TKAN\"\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1mModel: \"TKAN\"\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
-       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
-       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ tkan_30 (TKAN)                  │ (None, 45, 100)        │        41,750 │\n",
-       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ tkan_31 (TKAN)                  │ (None, 100)            │        67,670 │\n",
-       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_45 (Dense)                │ (None, 6)              │           606 │\n",
-       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
-       "
\n" - ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ tkan_30 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,750\u001b[0m │\n", - "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ tkan_31 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n", - "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_45 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m6\u001b[0m) │ \u001b[38;5;34m606\u001b[0m │\n", - "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Total params: 110,026 (429.79 KB)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m110,026\u001b[0m (429.79 KB)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Trainable params: 110,006 (429.71 KB)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m110,006\u001b[0m (429.71 KB)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Non-trainable params: 20 (80.00 B)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m20\u001b[0m (80.00 B)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "77.3258445262909 0.1280814725052942\n", - "114.44928884506226 0.14017116370653385\n", - "104.74958324432373 0.126116634620875\n", - "101.73448276519775 0.13543281159237028\n", - "88.49818539619446 0.13133054407465156\n" - ] - }, - { - "data": { - "text/html": [ - "
Model: \"GRU\"\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1mModel: \"GRU\"\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
-       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
-       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ gru_30 (GRU)                    │ (None, 45, 100)        │        36,300 │\n",
-       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ gru_31 (GRU)                    │ (None, 100)            │        60,600 │\n",
-       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_50 (Dense)                │ (None, 6)              │           606 │\n",
-       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
-       "
\n" - ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ gru_30 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n", - "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ gru_31 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n", - "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_50 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m6\u001b[0m) │ \u001b[38;5;34m606\u001b[0m │\n", - "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Total params: 97,506 (380.88 KB)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m97,506\u001b[0m (380.88 KB)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Trainable params: 97,506 (380.88 KB)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m97,506\u001b[0m (380.88 KB)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Non-trainable params: 0 (0.00 B)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "40.4565007686615 0.13588272560079817\n", - "32.21065616607666 0.150907506214383\n", - "41.915128231048584 0.09900284704797373\n", - "40.109740018844604 0.06818254449404637\n", - "40.55113649368286 0.14480364148103195\n" + "26.26027750968933 0.1602740337153421\n", + "30.87123155593872 0.12435047691678952\n", + "30.425692319869995 0.10110813722797496\n", + "29.073087215423584 0.13987663037711742\n", + "27.583446502685547 0.12691294524178082\n", + "30.0588481426239 0.09291668323995494\n", + "30.896180629730225 0.13374339087226397\n", + "26.069262266159058 0.13703060900398636\n", + "26.07837224006653 0.14521618469459105\n", + "28.18198537826538 0.13697203267595567\n" ] }, { @@ -3422,11 +1468,11 @@ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
        "┃ Layer (type)                     Output Shape                  Param # ┃\n",
        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ lstm_30 (LSTM)                  │ (None, 45, 100)        │        48,000 │\n",
+       "│ lstm_40 (LSTM)                  │ (None, 45, 100)        │        48,000 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ lstm_31 (LSTM)                  │ (None, 100)            │        80,400 │\n",
+       "│ lstm_41 (LSTM)                  │ (None, 100)            │        80,400 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_55 (Dense)                │ (None, 6)              │           606 │\n",
+       "│ dense_80 (Dense)                │ (None, 6)              │           606 │\n",
        "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
        "
\n" ], @@ -3434,435 +1480,51 @@ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ lstm_30 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n", + "│ lstm_40 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ lstm_31 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n", + "│ lstm_41 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_55 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m6\u001b[0m) │ \u001b[38;5;34m606\u001b[0m │\n", + "│ dense_80 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m6\u001b[0m) │ \u001b[38;5;34m606\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Total params: 129,006 (503.93 KB)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m129,006\u001b[0m (503.93 KB)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Trainable params: 129,006 (503.93 KB)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m129,006\u001b[0m (503.93 KB)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
 Non-trainable params: 0 (0.00 B)\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "23.99752426147461 0.03173612065905046\n", - "25.493926286697388 -0.05948194958746694\n", - "33.9563672542572 0.12418423914604533\n", - "25.076443910598755 -0.09484002694483258\n", - "23.51260232925415 -0.021994532852705695\n", - "R2 scores\n", - "Means:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", - " return _methods._mean(a, axis=axis, dtype=dtype,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
150.0932910.078430-0.231738
120.0960050.051663-0.362725
90.1115380.048128-0.123209
60.1322270.119756-0.004079
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" - ], - "text/plain": [ - " TKAN GRU LSTM\n", - "15 0.093291 0.078430 -0.231738\n", - "12 0.096005 0.051663 -0.362725\n", - "9 0.111538 0.048128 -0.123209\n", - "6 0.132227 0.119756 -0.004079\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", - " return _methods._mean(a, axis=axis, dtype=dtype,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
150.0677050.0681940.078686
120.0673210.0688310.082159
90.0667170.0688820.074797
60.0661310.0664700.070901
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" - ], - "text/plain": [ - " TKAN GRU LSTM\n", - "15 0.067705 0.068194 0.078686\n", - "12 0.067321 0.068831 0.082159\n", - "9 0.066717 0.068882 0.074797\n", - "6 0.066131 0.066470 0.070901\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Std:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n", - " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n", - " arrmean = um.true_divide(arrmean, div, out=arrmean,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
150.0055390.0159250.103254
120.0040600.0562630.219555
90.0034570.0607830.097549
60.0050740.0314590.076632
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Total params: 129,006 (503.93 KB)\n",
+       "
\n" ], "text/plain": [ - " TKAN GRU LSTM\n", - "15 0.005539 0.015925 0.103254\n", - "12 0.004060 0.056263 0.219555\n", - "9 0.003457 0.060783 0.097549\n", - "6 0.005074 0.031459 0.076632\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m129,006\u001b[0m (503.93 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n", - " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n", - " arrmean = um.true_divide(arrmean, div, out=arrmean,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] + "data": { + "text/html": [ + "
 Trainable params: 129,006 (503.93 KB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m129,006\u001b[0m (503.93 KB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { "data": { "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
150.0002150.0005780.003284
120.0001480.0019780.006539
90.0001310.0021420.003162
60.0001990.0011510.002695
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" + "
 Non-trainable params: 0 (0.00 B)\n",
+       "
\n" ], "text/plain": [ - " TKAN GRU LSTM\n", - "15 0.000215 0.000578 0.003284\n", - "12 0.000148 0.001978 0.006539\n", - "9 0.000131 0.002142 0.003162\n", - "6 0.000199 0.001151 0.002695\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, @@ -3872,186 +1534,81 @@ "name": "stdout", "output_type": "stream", "text": [ - "Training Times\n" + "25.56489086151123 0.07762226752057655\n", + "17.17194938659668 -0.030710740566574007\n", + "22.08793616294861 0.12124885788661292\n", + "18.8874089717865 -0.19992596769356452\n", + "18.460254669189453 -0.03568727060662086\n", + "19.47718334197998 -0.1732379052017987\n", + "18.813477754592896 -0.0515297041407121\n", + "18.96723985671997 -0.00603457861723435\n", + "17.978565216064453 0.06545331792210407\n", + "18.76132583618164 -0.21833925712011495\n" ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", - " return _methods._mean(a, axis=axis, dtype=dtype,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] + "data": { + "text/html": [ + "
Model: \"TKAN\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"TKAN\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { "data": { "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
15148.76306842.47829633.661120
12131.65323141.33398731.695352
9112.19363236.14423324.899039
697.35147739.04863226.407373
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
+       "│ tkan_60 (TKAN)                  │ (None, 45, 100)        │        41,316 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ tkan_61 (TKAN)                  │ (None, 100)            │        67,670 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ dense_90 (Dense)                │ (None, 9)              │           909 │\n",
+       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
+       "
\n" ], "text/plain": [ - " TKAN GRU LSTM\n", - "15 148.763068 42.478296 33.661120\n", - "12 131.653231 41.333987 31.695352\n", - "9 112.193632 36.144233 24.899039\n", - "6 97.351477 39.048632 26.407373\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", + "│ tkan_60 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,316\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ tkan_61 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ dense_90 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m909\u001b[0m │\n", + "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n", - " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n", - " arrmean = um.true_divide(arrmean, div, out=arrmean,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] + "data": { + "text/html": [ + "
 Total params: 109,895 (429.28 KB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m109,895\u001b[0m (429.28 KB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { "data": { "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
1514.6747052.7050711.335022
126.6137834.4866613.014970
916.3547424.6898431.760482
613.0052653.4739113.841358
3NaNNaNNaN
1NaNNaNNaN
\n", - "
" + "
 Trainable params: 109,875 (429.20 KB)\n",
+       "
\n" ], "text/plain": [ - " TKAN GRU LSTM\n", - "15 14.674705 2.705071 1.335022\n", - "12 6.613783 4.486661 3.014970\n", - "9 16.354742 4.689843 1.760482\n", - "6 13.005265 3.473911 3.841358\n", - "3 NaN NaN NaN\n", - "1 NaN NaN NaN" + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m109,875\u001b[0m (429.20 KB)\n" ] }, "metadata": {}, @@ -4060,11 +1617,40 @@ { "data": { "text/html": [ - "
Model: \"TKAN\"\n",
+       "
 Non-trainable params: 20 (80.00 B)\n",
        "
\n" ], "text/plain": [ - "\u001b[1mModel: \"TKAN\"\u001b[0m\n" + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m20\u001b[0m (80.00 B)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "79.56117129325867 0.10413605110577702\n", + "95.19019293785095 0.10785571960353318\n", + "62.058950901031494 0.1022740221679086\n", + "57.313427448272705 0.09659038594513981\n", + "95.39828181266785 0.11798299575914914\n", + "109.50233817100525 0.11471786200051105\n", + "91.66243028640747 0.10141240448038549\n", + "83.4484133720398 0.11271609311366479\n", + "80.45564818382263 0.10729036960326731\n", + "89.72633719444275 0.10601871726343631\n" + ] + }, + { + "data": { + "text/html": [ + "
Model: \"GRU\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"GRU\"\u001b[0m\n" ] }, "metadata": {}, @@ -4076,11 +1662,11 @@ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
        "┃ Layer (type)                     Output Shape                  Param # ┃\n",
        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ tkan_40 (TKAN)                  │ (None, 45, 100)        │        41,750 │\n",
+       "│ gru_60 (GRU)                    │ (None, 45, 100)        │        36,300 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ tkan_41 (TKAN)                  │ (None, 100)            │        67,670 │\n",
+       "│ gru_61 (GRU)                    │ (None, 100)            │        60,600 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_60 (Dense)                │ (None, 3)              │           303 │\n",
+       "│ dense_100 (Dense)               │ (None, 9)              │           909 │\n",
        "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
        "
\n" ], @@ -4088,11 +1674,11 @@ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ tkan_40 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,750\u001b[0m │\n", + "│ gru_60 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ tkan_41 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n", + "│ gru_61 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_60 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m303\u001b[0m │\n", + "│ dense_100 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m909\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, @@ -4102,11 +1688,11 @@ { "data": { "text/html": [ - "
 Total params: 109,723 (428.61 KB)\n",
+       "
 Total params: 97,809 (382.07 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m109,723\u001b[0m (428.61 KB)\n" + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m97,809\u001b[0m (382.07 KB)\n" ] }, "metadata": {}, @@ -4115,11 +1701,11 @@ { "data": { "text/html": [ - "
 Trainable params: 109,703 (428.53 KB)\n",
+       "
 Trainable params: 97,809 (382.07 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m109,703\u001b[0m (428.53 KB)\n" + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m97,809\u001b[0m (382.07 KB)\n" ] }, "metadata": {}, @@ -4128,11 +1714,11 @@ { "data": { "text/html": [ - "
 Non-trainable params: 20 (80.00 B)\n",
+       "
 Non-trainable params: 0 (0.00 B)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m20\u001b[0m (80.00 B)\n" + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, @@ -4142,21 +1728,26 @@ "name": "stdout", "output_type": "stream", "text": [ - "92.45071959495544 0.1916116569043532\n", - "94.640451669693 0.19454637958748155\n", - "73.71659445762634 0.19117099482549302\n", - "95.40752124786377 0.19110640326041053\n", - "85.31889510154724 0.19614372723219617\n" + "22.387739181518555 0.12872665220206386\n", + "30.9358811378479 0.07007529147404491\n", + "30.65935730934143 0.06094348103987125\n", + "22.681428909301758 0.11531462257235203\n", + "30.2576744556427 -0.009184511914249627\n", + "23.401977062225342 0.1145808233763385\n", + "29.867743015289307 -0.0020916084892752293\n", + "21.44379758834839 0.13675731715603687\n", + "32.31744432449341 0.08318358609980472\n", + "36.52803301811218 -0.02434265870166709\n" ] }, { "data": { "text/html": [ - "
Model: \"GRU\"\n",
+       "
Model: \"LSTM\"\n",
        "
\n" ], "text/plain": [ - "\u001b[1mModel: \"GRU\"\u001b[0m\n" + "\u001b[1mModel: \"LSTM\"\u001b[0m\n" ] }, "metadata": {}, @@ -4168,11 +1759,11 @@ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
        "┃ Layer (type)                     Output Shape                  Param # ┃\n",
        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ gru_40 (GRU)                    │ (None, 45, 100)        │        36,300 │\n",
+       "│ lstm_60 (LSTM)                  │ (None, 45, 100)        │        48,000 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ gru_41 (GRU)                    │ (None, 100)            │        60,600 │\n",
+       "│ lstm_61 (LSTM)                  │ (None, 100)            │        80,400 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_65 (Dense)                │ (None, 3)              │           303 │\n",
+       "│ dense_110 (Dense)               │ (None, 9)              │           909 │\n",
        "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
        "
\n" ], @@ -4180,11 +1771,11 @@ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ gru_40 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n", + "│ lstm_60 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ gru_41 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n", + "│ lstm_61 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_65 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m303\u001b[0m │\n", + "│ dense_110 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m909\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, @@ -4194,11 +1785,11 @@ { "data": { "text/html": [ - "
 Total params: 97,203 (379.70 KB)\n",
+       "
 Total params: 129,309 (505.11 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m97,203\u001b[0m (379.70 KB)\n" + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m129,309\u001b[0m (505.11 KB)\n" ] }, "metadata": {}, @@ -4207,11 +1798,11 @@ { "data": { "text/html": [ - "
 Trainable params: 97,203 (379.70 KB)\n",
+       "
 Trainable params: 129,309 (505.11 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m97,203\u001b[0m (379.70 KB)\n" + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m129,309\u001b[0m (505.11 KB)\n" ] }, "metadata": {}, @@ -4234,21 +1825,26 @@ "name": "stdout", "output_type": "stream", "text": [ - "34.47401285171509 0.23517545876629728\n", - "32.61530804634094 0.23820733972132835\n", - "38.93267202377319 0.23710632242677557\n", - "32.61954689025879 0.23525012980273963\n", - "33.25591826438904 0.24304388963816356\n" + "20.542985439300537 -0.058979852588216475\n", + "19.61976170539856 -0.07973171479104875\n", + "19.607908487319946 -0.15736860736964103\n", + "17.53559136390686 -0.09773126198658748\n", + "19.434728860855103 -0.39964713298492527\n", + "20.465200901031494 -0.2428380152707189\n", + "18.985889673233032 -0.5012525957302113\n", + "20.644855260849 -0.2738693890416318\n", + "19.245380401611328 -0.05146921829097126\n", + "19.828939199447632 -0.0008553086817381464\n" ] }, { "data": { "text/html": [ - "
Model: \"LSTM\"\n",
+       "
Model: \"TKAN\"\n",
        "
\n" ], "text/plain": [ - "\u001b[1mModel: \"LSTM\"\u001b[0m\n" + "\u001b[1mModel: \"TKAN\"\u001b[0m\n" ] }, "metadata": {}, @@ -4260,11 +1856,11 @@ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
        "┃ Layer (type)                     Output Shape                  Param # ┃\n",
        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ lstm_40 (LSTM)                  │ (None, 45, 100)        │        48,000 │\n",
+       "│ tkan_80 (TKAN)                  │ (None, 60, 100)        │        41,316 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ lstm_41 (LSTM)                  │ (None, 100)            │        80,400 │\n",
+       "│ tkan_81 (TKAN)                  │ (None, 100)            │        67,670 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_70 (Dense)                │ (None, 3)              │           303 │\n",
+       "│ dense_120 (Dense)               │ (None, 12)             │         1,212 │\n",
        "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
        "
\n" ], @@ -4272,11 +1868,11 @@ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ lstm_40 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n", + "│ tkan_80 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m60\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,316\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ lstm_41 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n", + "│ tkan_81 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_70 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m303\u001b[0m │\n", + "│ dense_120 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m1,212\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, @@ -4286,11 +1882,11 @@ { "data": { "text/html": [ - "
 Total params: 128,703 (502.75 KB)\n",
+       "
 Total params: 110,198 (430.46 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m128,703\u001b[0m (502.75 KB)\n" + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m110,198\u001b[0m (430.46 KB)\n" ] }, "metadata": {}, @@ -4299,11 +1895,11 @@ { "data": { "text/html": [ - "
 Trainable params: 128,703 (502.75 KB)\n",
+       "
 Trainable params: 110,178 (430.38 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m128,703\u001b[0m (502.75 KB)\n" + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m110,178\u001b[0m (430.38 KB)\n" ] }, "metadata": {}, @@ -4312,11 +1908,11 @@ { "data": { "text/html": [ - "
 Non-trainable params: 0 (0.00 B)\n",
+       "
 Non-trainable params: 20 (80.00 B)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m20\u001b[0m (80.00 B)\n" ] }, "metadata": {}, @@ -4326,381 +1922,94 @@ "name": "stdout", "output_type": "stream", "text": [ - "27.456106424331665 0.05745848167891302\n", - "25.54291844367981 0.04933120546474171\n", - "27.04043936729431 0.14269677898499963\n", - "24.304097175598145 0.16428378453298487\n", - "26.050607442855835 0.13263322060964924\n", - "R2 scores\n", - "Means:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", - " return _methods._mean(a, axis=axis, dtype=dtype,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" + "102.03934359550476 0.0950019998069378\n", + "121.27882838249207 0.10264306639669311\n", + "108.56248140335083 0.0990784027938416\n", + "77.42332911491394 0.0915531316015012\n", + "108.50404167175293 0.09352464587821978\n", + "95.75039339065552 0.092634008105403\n", + "102.65606546401978 0.09820662504181471\n", + "130.8202440738678 0.09765221932195063\n", + "113.90040755271912 0.09989627016065132\n", + "118.27212023735046 0.1041245428645089\n" ] }, { "data": { "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
150.0932910.078430-0.231738
120.0960050.051663-0.362725
90.1115380.048128-0.123209
60.1322270.119756-0.004079
30.1929160.2377570.109281
1NaNNaNNaN
\n", - "
" + "
Model: \"GRU\"\n",
+       "
\n" ], "text/plain": [ - " TKAN GRU LSTM\n", - "15 0.093291 0.078430 -0.231738\n", - "12 0.096005 0.051663 -0.362725\n", - "9 0.111538 0.048128 -0.123209\n", - "6 0.132227 0.119756 -0.004079\n", - "3 0.192916 0.237757 0.109281\n", - "1 NaN NaN NaN" + "\u001b[1mModel: \"GRU\"\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", - " return _methods._mean(a, axis=axis, dtype=dtype,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] - }, { "data": { "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
150.0677050.0681940.078686
120.0673210.0688310.082159
90.0667170.0688820.074797
60.0661310.0664700.070901
30.0635600.0616880.066594
1NaNNaNNaN
\n", - "
" + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
+       "│ gru_80 (GRU)                    │ (None, 60, 100)        │        36,300 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ gru_81 (GRU)                    │ (None, 100)            │        60,600 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ dense_130 (Dense)               │ (None, 12)             │         1,212 │\n",
+       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
+       "
\n" ], "text/plain": [ - " TKAN GRU LSTM\n", - "15 0.067705 0.068194 0.078686\n", - "12 0.067321 0.068831 0.082159\n", - "9 0.066717 0.068882 0.074797\n", - "6 0.066131 0.066470 0.070901\n", - "3 0.063560 0.061688 0.066594\n", - "1 NaN NaN NaN" + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", + "│ gru_80 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m60\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ gru_81 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ dense_130 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m1,212\u001b[0m │\n", + "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Std:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n", - " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n", - " arrmean = um.true_divide(arrmean, div, out=arrmean,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] - }, { "data": { "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
150.0055390.0159250.103254
120.0040600.0562630.219555
90.0034570.0607830.097549
60.0050740.0314590.076632
30.0020540.0028820.046833
1NaNNaNNaN
\n", - "
" + "
 Total params: 98,112 (383.25 KB)\n",
+       "
\n" ], "text/plain": [ - " TKAN GRU LSTM\n", - "15 0.005539 0.015925 0.103254\n", - "12 0.004060 0.056263 0.219555\n", - "9 0.003457 0.060783 0.097549\n", - "6 0.005074 0.031459 0.076632\n", - "3 0.002054 0.002882 0.046833\n", - "1 NaN NaN NaN" + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m98,112\u001b[0m (383.25 KB)\n" ] }, "metadata": {}, "output_type": "display_data" }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n", - " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n", - " arrmean = um.true_divide(arrmean, div, out=arrmean,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] - }, { "data": { "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
150.0002150.0005780.003284
120.0001480.0019780.006539
90.0001310.0021420.003162
60.0001990.0011510.002695
30.0000940.0001120.001722
1NaNNaNNaN
\n", - "
" + "
 Trainable params: 98,112 (383.25 KB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m98,112\u001b[0m (383.25 KB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+       "
\n" ], "text/plain": [ - " TKAN GRU LSTM\n", - "15 0.000215 0.000578 0.003284\n", - "12 0.000148 0.001978 0.006539\n", - "9 0.000131 0.002142 0.003162\n", - "6 0.000199 0.001151 0.002695\n", - "3 0.000094 0.000112 0.001722\n", - "1 NaN NaN NaN" + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, @@ -4710,191 +2019,115 @@ "name": "stdout", "output_type": "stream", "text": [ - "Training Times\n" + "35.47221755981445 0.04088546661722716\n", + "36.82418465614319 0.03298667710271531\n", + "29.090290307998657 0.08667586160493074\n", + "27.488890886306763 0.09733002632430683\n", + "30.277069330215454 0.10445165105456156\n", + "26.764411687850952 0.11400616888042837\n", + "35.36870241165161 0.02046360232907878\n", + "26.94648003578186 0.11067681945648022\n", + "35.716086864471436 0.02628936001538192\n", + "32.919156312942505 0.05842798641178421\n" ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", - " return _methods._mean(a, axis=axis, dtype=dtype,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] + "data": { + "text/html": [ + "
Model: \"LSTM\"\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mModel: \"LSTM\"\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { "data": { "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
15148.76306842.47829633.661120
12131.65323141.33398731.695352
9112.19363236.14423324.899039
697.35147739.04863226.407373
388.30683634.37949226.078834
1NaNNaNNaN
\n", - "
" + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
+       "┃ Layer (type)                     Output Shape                  Param # ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
+       "│ lstm_80 (LSTM)                  │ (None, 60, 100)        │        48,000 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ lstm_81 (LSTM)                  │ (None, 100)            │        80,400 │\n",
+       "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+       "│ dense_140 (Dense)               │ (None, 12)             │         1,212 │\n",
+       "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
+       "
\n" ], "text/plain": [ - " TKAN GRU LSTM\n", - "15 148.763068 42.478296 33.661120\n", - "12 131.653231 41.333987 31.695352\n", - "9 112.193632 36.144233 24.899039\n", - "6 97.351477 39.048632 26.407373\n", - "3 88.306836 34.379492 26.078834\n", - "1 NaN NaN NaN" + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", + "│ lstm_80 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m60\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ lstm_81 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n", + "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", + "│ dense_140 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m1,212\u001b[0m │\n", + "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n", - " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n", - " arrmean = um.true_divide(arrmean, div, out=arrmean,\n", - "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n", - " ret = ret.dtype.type(ret / rcount)\n" - ] + "data": { + "text/html": [ + "
 Total params: 129,612 (506.30 KB)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m129,612\u001b[0m (506.30 KB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { "data": { "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
TKANGRULSTM
1514.6747052.7050711.335022
126.6137834.4866613.014970
916.3547424.6898431.760482
613.0052653.4739113.841358
38.1179942.3753971.118862
1NaNNaNNaN
\n", - "
" + "
 Trainable params: 129,612 (506.30 KB)\n",
+       "
\n" ], "text/plain": [ - " TKAN GRU LSTM\n", - "15 14.674705 2.705071 1.335022\n", - "12 6.613783 4.486661 3.014970\n", - "9 16.354742 4.689843 1.760482\n", - "6 13.005265 3.473911 3.841358\n", - "3 8.117994 2.375397 1.118862\n", - "1 NaN NaN NaN" + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m129,612\u001b[0m (506.30 KB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
 Non-trainable params: 0 (0.00 B)\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "26.069539546966553 -0.36695473176783655\n", + "21.807262659072876 -0.25014631415290206\n", + "20.902458667755127 -0.11645941888847307\n", + "22.4891197681427 -0.052407595508665694\n", + "22.018208742141724 -0.1754222997379895\n", + "23.925899028778076 -0.444045746274295\n", + "23.6626615524292 -0.2362643247514915\n", + "20.98034143447876 -0.16545529402468354\n", + "22.008980751037598 -0.23959377820038377\n", + "21.912462949752808 -0.5002557724294028\n" + ] + }, { "data": { "text/html": [ @@ -4914,11 +2147,11 @@ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
        "┃ Layer (type)                     Output Shape                  Param # ┃\n",
        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ tkan_50 (TKAN)                  │ (None, 45, 100)        │        41,750 │\n",
+       "│ tkan_100 (TKAN)                 │ (None, 75, 100)        │        41,316 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ tkan_51 (TKAN)                  │ (None, 100)            │        67,670 │\n",
+       "│ tkan_101 (TKAN)                 │ (None, 100)            │        67,670 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_75 (Dense)                │ (None, 1)              │           101 │\n",
+       "│ dense_150 (Dense)               │ (None, 15)             │         1,515 │\n",
        "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
        "
\n" ], @@ -4926,11 +2159,11 @@ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ tkan_50 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,750\u001b[0m │\n", + "│ tkan_100 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m75\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,316\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ tkan_51 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n", + "│ tkan_101 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_75 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m101\u001b[0m │\n", + "│ dense_150 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m15\u001b[0m) │ \u001b[38;5;34m1,515\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, @@ -4940,11 +2173,11 @@ { "data": { "text/html": [ - "
 Total params: 109,521 (427.82 KB)\n",
+       "
 Total params: 110,501 (431.64 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m109,521\u001b[0m (427.82 KB)\n" + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m110,501\u001b[0m (431.64 KB)\n" ] }, "metadata": {}, @@ -4953,11 +2186,11 @@ { "data": { "text/html": [ - "
 Trainable params: 109,501 (427.74 KB)\n",
+       "
 Trainable params: 110,481 (431.57 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m109,501\u001b[0m (427.74 KB)\n" + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m110,481\u001b[0m (431.57 KB)\n" ] }, "metadata": {}, @@ -4980,11 +2213,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "81.33754324913025 0.3022592965618307\n", - "104.95103526115417 0.302846112010449\n", - "111.78685140609741 0.33146529199419217\n", - "64.58520579338074 0.27988703845286933\n", - "115.72550010681152 0.32013992233948596\n" + "108.7913281917572 0.09994023646490877\n", + "156.19587421417236 0.08586536769033742\n", + "151.76872277259827 0.0953234353911566\n", + "109.25939583778381 0.08350316914250061\n", + "120.67966675758362 0.09383066359946372\n", + "118.93469595909119 0.09027677327433754\n", + "117.88046956062317 0.09486514143615367\n", + "106.22660422325134 0.09703190398065814\n", + "154.6233696937561 0.09741707472392\n", + "124.99523186683655 0.09227563879811171\n" ] }, { @@ -5006,11 +2244,11 @@ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
        "┃ Layer (type)                     Output Shape                  Param # ┃\n",
        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ gru_50 (GRU)                    │ (None, 45, 100)        │        36,300 │\n",
+       "│ gru_100 (GRU)                   │ (None, 75, 100)        │        36,300 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ gru_51 (GRU)                    │ (None, 100)            │        60,600 │\n",
+       "│ gru_101 (GRU)                   │ (None, 100)            │        60,600 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_80 (Dense)                │ (None, 1)              │           101 │\n",
+       "│ dense_160 (Dense)               │ (None, 15)             │         1,515 │\n",
        "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
        "
\n" ], @@ -5018,11 +2256,11 @@ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ gru_50 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n", + "│ gru_100 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m75\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ gru_51 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n", + "│ gru_101 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_80 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m101\u001b[0m │\n", + "│ dense_160 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m15\u001b[0m) │ \u001b[38;5;34m1,515\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, @@ -5032,11 +2270,11 @@ { "data": { "text/html": [ - "
 Total params: 97,001 (378.91 KB)\n",
+       "
 Total params: 98,415 (384.43 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m97,001\u001b[0m (378.91 KB)\n" + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m98,415\u001b[0m (384.43 KB)\n" ] }, "metadata": {}, @@ -5045,11 +2283,11 @@ { "data": { "text/html": [ - "
 Trainable params: 97,001 (378.91 KB)\n",
+       "
 Trainable params: 98,415 (384.43 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m97,001\u001b[0m (378.91 KB)\n" + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m98,415\u001b[0m (384.43 KB)\n" ] }, "metadata": {}, @@ -5072,11 +2310,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "42.188021659851074 0.39814672018793895\n", - "32.78909373283386 0.3956099221775101\n", - "35.24668765068054 0.3982762017641819\n", - "42.99339699745178 0.3956046326974988\n", - "28.0660617351532 0.3989501248082734\n" + "41.711098432540894 0.04633579312704695\n", + "37.61102104187012 0.041959092296909146\n", + "33.6794171333313 0.08124711803653464\n", + "41.39418888092041 0.062273110060791455\n", + "33.090798139572144 0.09644091868374302\n", + "34.905259132385254 0.0688485348099919\n", + "31.54578924179077 0.09070480771844083\n", + "31.649243354797363 0.08222696375881307\n", + "34.72465991973877 0.07408738048462013\n", + "34.364349365234375 0.06764429477726015\n" ] }, { @@ -5098,11 +2341,11 @@ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
        "┃ Layer (type)                     Output Shape                  Param # ┃\n",
        "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-       "│ lstm_50 (LSTM)                  │ (None, 45, 100)        │        48,000 │\n",
+       "│ lstm_100 (LSTM)                 │ (None, 75, 100)        │        48,000 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ lstm_51 (LSTM)                  │ (None, 100)            │        80,400 │\n",
+       "│ lstm_101 (LSTM)                 │ (None, 100)            │        80,400 │\n",
        "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
-       "│ dense_85 (Dense)                │ (None, 1)              │           101 │\n",
+       "│ dense_170 (Dense)               │ (None, 15)             │         1,515 │\n",
        "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
        "
\n" ], @@ -5110,11 +2353,11 @@ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", - "│ lstm_50 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n", + "│ lstm_100 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m75\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ lstm_51 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n", + "│ lstm_101 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", - "│ dense_85 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m101\u001b[0m │\n", + "│ dense_170 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m15\u001b[0m) │ \u001b[38;5;34m1,515\u001b[0m │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, @@ -5124,11 +2367,11 @@ { "data": { "text/html": [ - "
 Total params: 128,501 (501.96 KB)\n",
+       "
 Total params: 129,915 (507.48 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m128,501\u001b[0m (501.96 KB)\n" + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m129,915\u001b[0m (507.48 KB)\n" ] }, "metadata": {}, @@ -5137,11 +2380,11 @@ { "data": { "text/html": [ - "
 Trainable params: 128,501 (501.96 KB)\n",
+       "
 Trainable params: 129,915 (507.48 KB)\n",
        "
\n" ], "text/plain": [ - "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m128,501\u001b[0m (501.96 KB)\n" + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m129,915\u001b[0m (507.48 KB)\n" ] }, "metadata": {}, @@ -5164,11 +2407,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "25.248199701309204 0.3954618995452197\n", - "25.895612478256226 0.36461360004863774\n", - "22.943801403045654 0.39227780049103544\n", - "25.566108465194702 0.3725527679003343\n", - "24.88447666168213 0.3704866575757282\n", + "27.019399642944336 -0.02014558059192651\n", + "24.994324684143066 -0.034652205897714984\n", + "29.875847578048706 -0.2883331388024311\n", + "27.68077278137207 -0.19307423134624493\n", + "32.10626721382141 -0.1967002645401843\n", + "26.022549629211426 -0.10374175022709843\n", + "28.117710828781128 -0.0705739546737177\n", + "30.498241662979126 -0.25929543857022325\n", + "24.907368898391724 -0.09258147039272031\n", + "26.723052978515625 -0.15915120445580425\n", "R2 scores\n", "Means:\n" ] @@ -5201,40 +2449,40 @@ " \n", " \n", " \n", - " 15\n", - " 0.093291\n", - " 0.078430\n", - " -0.231738\n", + " 1\n", + " 0.310078\n", + " 0.394863\n", + " 0.381780\n", " \n", " \n", - " 12\n", - " 0.096005\n", - " 0.051663\n", - " -0.362725\n", + " 3\n", + " 0.190549\n", + " 0.236880\n", + " 0.113571\n", " \n", " \n", - " 9\n", - " 0.111538\n", - " 0.048128\n", - " -0.123209\n", + " 6\n", + " 0.128743\n", + " 0.129840\n", + " -0.045114\n", " \n", " \n", - " 6\n", - " 0.132227\n", - " 0.119756\n", - " -0.004079\n", + " 9\n", + " 0.107099\n", + " 0.067396\n", + " -0.186374\n", " \n", " \n", - " 3\n", - " 0.192916\n", - " 0.237757\n", - " 0.109281\n", + " 12\n", + " 0.097431\n", + " 0.069219\n", + " -0.254701\n", " \n", " \n", - " 1\n", - " 0.307320\n", - " 0.397318\n", - " 0.379079\n", + " 15\n", + " 0.093033\n", + " 0.071177\n", + " -0.141825\n", " \n", " \n", "\n", @@ -5242,12 +2490,12 @@ ], "text/plain": [ " TKAN GRU LSTM\n", - "15 0.093291 0.078430 -0.231738\n", - "12 0.096005 0.051663 -0.362725\n", - "9 0.111538 0.048128 -0.123209\n", - "6 0.132227 0.119756 -0.004079\n", - "3 0.192916 0.237757 0.109281\n", - "1 0.307320 0.397318 0.379079" + "1 0.310078 0.394863 0.381780\n", + "3 0.190549 0.236880 0.113571\n", + "6 0.128743 0.129840 -0.045114\n", + "9 0.107099 0.067396 -0.186374\n", + "12 0.097431 0.069219 -0.254701\n", + "15 0.093033 0.071177 -0.141825" ] }, "metadata": {}, @@ -5281,40 +2529,40 @@ " \n", " \n", " \n", - " 15\n", - " 0.067705\n", - " 0.068194\n", - " 0.078686\n", + " 1\n", + " 0.058833\n", + " 0.055101\n", + " 0.055693\n", " \n", " \n", - " 12\n", - " 0.067321\n", - " 0.068831\n", - " 0.082159\n", + " 3\n", + " 0.063659\n", + " 0.061720\n", + " 0.066402\n", " \n", " \n", - " 9\n", - " 0.066717\n", - " 0.068882\n", - " 0.074797\n", + " 6\n", + " 0.066264\n", + " 0.066109\n", + " 0.072302\n", " \n", " \n", - " 6\n", - " 0.066131\n", - " 0.066470\n", - " 0.070901\n", + " 9\n", + " 0.066893\n", + " 0.068206\n", + " 0.076763\n", " \n", " \n", - " 3\n", - " 0.063560\n", - " 0.061688\n", - " 0.066594\n", + " 12\n", + " 0.067267\n", + " 0.068214\n", + " 0.079032\n", " \n", " \n", - " 1\n", - " 0.058948\n", - " 0.054990\n", - " 0.055813\n", + " 15\n", + " 0.067711\n", + " 0.068456\n", + " 0.075833\n", " \n", " \n", "\n", @@ -5322,12 +2570,12 @@ ], "text/plain": [ " TKAN GRU LSTM\n", - "15 0.067705 0.068194 0.078686\n", - "12 0.067321 0.068831 0.082159\n", - "9 0.066717 0.068882 0.074797\n", - "6 0.066131 0.066470 0.070901\n", - "3 0.063560 0.061688 0.066594\n", - "1 0.058948 0.054990 0.055813" + "1 0.058833 0.055101 0.055693\n", + "3 0.063659 0.061720 0.066402\n", + "6 0.066264 0.066109 0.072302\n", + "9 0.066893 0.068206 0.076763\n", + "12 0.067267 0.068214 0.079032\n", + "15 0.067711 0.068456 0.075833" ] }, "metadata": {}, @@ -5368,40 +2616,40 @@ " \n", " \n", " \n", - " 15\n", - " 0.005539\n", - " 0.015925\n", - " 0.103254\n", + " 1\n", + " 0.013617\n", + " 0.007498\n", + " 0.009371\n", " \n", " \n", - " 12\n", - " 0.004060\n", - " 0.056263\n", - " 0.219555\n", + " 3\n", + " 0.007820\n", + " 0.006761\n", + " 0.063202\n", " \n", " \n", - " 9\n", - " 0.003457\n", - " 0.060783\n", - " 0.097549\n", + " 6\n", + " 0.008831\n", + " 0.018997\n", + " 0.112792\n", " \n", " \n", - " 6\n", - " 0.005074\n", - " 0.031459\n", - " 0.076632\n", + " 9\n", + " 0.006202\n", + " 0.057011\n", + " 0.156331\n", " \n", " \n", - " 3\n", - " 0.002054\n", - " 0.002882\n", - " 0.046833\n", + " 12\n", + " 0.004002\n", + " 0.035374\n", + " 0.135507\n", " \n", " \n", - " 1\n", - " 0.017581\n", - " 0.001423\n", - " 0.012396\n", + " 15\n", + " 0.004925\n", + " 0.016791\n", + " 0.087433\n", " \n", " \n", "\n", @@ -5409,12 +2657,12 @@ ], "text/plain": [ " TKAN GRU LSTM\n", - "15 0.005539 0.015925 0.103254\n", - "12 0.004060 0.056263 0.219555\n", - "9 0.003457 0.060783 0.097549\n", - "6 0.005074 0.031459 0.076632\n", - "3 0.002054 0.002882 0.046833\n", - "1 0.017581 0.001423 0.012396" + "1 0.013617 0.007498 0.009371\n", + "3 0.007820 0.006761 0.063202\n", + "6 0.008831 0.018997 0.112792\n", + "9 0.006202 0.057011 0.156331\n", + "12 0.004002 0.035374 0.135507\n", + "15 0.004925 0.016791 0.087433" ] }, "metadata": {}, @@ -5448,40 +2696,40 @@ " \n", " \n", " \n", - " 15\n", - " 0.000215\n", - " 0.000578\n", - " 0.003284\n", + " 1\n", + " 0.000581\n", + " 0.000340\n", + " 0.000421\n", " \n", " \n", - " 12\n", - " 0.000148\n", - " 0.001978\n", - " 0.006539\n", + " 3\n", + " 0.000316\n", + " 0.000273\n", + " 0.002282\n", " \n", " \n", - " 9\n", - " 0.000131\n", - " 0.002142\n", - " 0.003162\n", + " 6\n", + " 0.000347\n", + " 0.000707\n", + " 0.003831\n", " \n", " \n", - " 6\n", - " 0.000199\n", - " 0.001151\n", - " 0.002695\n", + " 9\n", + " 0.000237\n", + " 0.002033\n", + " 0.004908\n", " \n", " \n", - " 3\n", - " 0.000094\n", - " 0.000112\n", - " 0.001722\n", + " 12\n", + " 0.000157\n", + " 0.001278\n", + " 0.004174\n", " \n", " \n", - " 1\n", - " 0.000747\n", - " 0.000065\n", - " 0.000558\n", + " 15\n", + " 0.000187\n", + " 0.000609\n", + " 0.002844\n", " \n", " \n", "\n", @@ -5489,12 +2737,12 @@ ], "text/plain": [ " TKAN GRU LSTM\n", - "15 0.000215 0.000578 0.003284\n", - "12 0.000148 0.001978 0.006539\n", - "9 0.000131 0.002142 0.003162\n", - "6 0.000199 0.001151 0.002695\n", - "3 0.000094 0.000112 0.001722\n", - "1 0.000747 0.000065 0.000558" + "1 0.000581 0.000340 0.000421\n", + "3 0.000316 0.000273 0.002282\n", + "6 0.000347 0.000707 0.003831\n", + "9 0.000237 0.002033 0.004908\n", + "12 0.000157 0.001278 0.004174\n", + "15 0.000187 0.000609 0.002844" ] }, "metadata": {}, @@ -5535,40 +2783,40 @@ " \n", " \n", " \n", - " 15\n", - " 148.763068\n", - " 42.478296\n", - " 33.661120\n", + " 1\n", + " 76.861277\n", + " 26.988356\n", + " 18.895929\n", " \n", " \n", - " 12\n", - " 131.653231\n", - " 41.333987\n", - " 31.695352\n", + " 3\n", + " 78.182759\n", + " 28.601693\n", + " 20.284139\n", " \n", " \n", - " 9\n", - " 112.193632\n", - " 36.144233\n", - " 24.899039\n", + " 6\n", + " 77.580002\n", + " 28.549838\n", + " 19.617023\n", " \n", " \n", - " 6\n", - " 97.351477\n", - " 39.048632\n", - " 26.407373\n", + " 9\n", + " 84.431719\n", + " 28.048108\n", + " 19.591124\n", " \n", " \n", - " 3\n", - " 88.306836\n", - " 34.379492\n", - " 26.078834\n", + " 12\n", + " 107.920725\n", + " 31.686749\n", + " 22.577694\n", " \n", " \n", - " 1\n", - " 95.677227\n", - " 36.256652\n", - " 24.907640\n", + " 15\n", + " 126.935536\n", + " 35.467582\n", + " 27.794554\n", " \n", " \n", "\n", @@ -5576,12 +2824,12 @@ ], "text/plain": [ " TKAN GRU LSTM\n", - "15 148.763068 42.478296 33.661120\n", - "12 131.653231 41.333987 31.695352\n", - "9 112.193632 36.144233 24.899039\n", - "6 97.351477 39.048632 26.407373\n", - "3 88.306836 34.379492 26.078834\n", - "1 95.677227 36.256652 24.907640" + "1 76.861277 26.988356 18.895929\n", + "3 78.182759 28.601693 20.284139\n", + "6 77.580002 28.549838 19.617023\n", + "9 84.431719 28.048108 19.591124\n", + "12 107.920725 31.686749 22.577694\n", + "15 126.935536 35.467582 27.794554" ] }, "metadata": {}, @@ -5615,40 +2863,40 @@ " \n", " \n", " \n", - " 15\n", - " 14.674705\n", - " 2.705071\n", - " 1.335022\n", + " 1\n", + " 11.082262\n", + " 3.945290\n", + " 1.723379\n", " \n", " \n", - " 12\n", - " 6.613783\n", - " 4.486661\n", - " 3.014970\n", + " 3\n", + " 10.159973\n", + " 3.052013\n", + " 1.950043\n", " \n", " \n", - " 9\n", - " 16.354742\n", - " 4.689843\n", - " 1.760482\n", + " 6\n", + " 8.965218\n", + " 1.885116\n", + " 2.320250\n", " \n", " \n", - " 6\n", - " 13.005265\n", - " 3.473911\n", - " 3.841358\n", + " 9\n", + " 14.904037\n", + " 4.894200\n", + " 0.870263\n", " \n", " \n", - " 3\n", - " 8.117994\n", - " 2.375397\n", - " 1.118862\n", + " 12\n", + " 14.096107\n", + " 3.808997\n", + " 1.490844\n", " \n", " \n", - " 1\n", - " 19.594882\n", - " 5.669121\n", - " 1.037579\n", + " 15\n", + " 18.705339\n", + " 3.458216\n", + " 2.267511\n", " \n", " \n", "\n", @@ -5656,12 +2904,12 @@ ], "text/plain": [ " TKAN GRU LSTM\n", - "15 14.674705 2.705071 1.335022\n", - "12 6.613783 4.486661 3.014970\n", - "9 16.354742 4.689843 1.760482\n", - "6 13.005265 3.473911 3.841358\n", - "3 8.117994 2.375397 1.118862\n", - "1 19.594882 5.669121 1.037579" + "1 11.082262 3.945290 1.723379\n", + "3 10.159973 3.052013 1.950043\n", + "6 8.965218 1.885116 2.320250\n", + "9 14.904037 4.894200 0.870263\n", + "12 14.096107 3.808997 1.490844\n", + "15 18.705339 3.458216 2.267511" ] }, "metadata": {}, @@ -5669,7 +2917,7 @@ } ], "source": [ - "n_aheads = [1, 3, 6, 9, 12, 15][::-1]\n", + "n_aheads = [1, 3, 6, 9, 12, 15]\n", "models = [\n", " \"TKAN\",\n", " \"GRU\",\n", @@ -5685,12 +2933,12 @@ " \n", " for model_id in models:\n", " \n", - " for run in range(5):\n", + " for run in range(10):\n", "\n", " if model_id == 'TKAN':\n", " model = Sequential([\n", " Input(shape=X_train.shape[1:]),\n", - " TKAN(100, sub_kan_output_dim = 20, sub_kan_input_dim = 20, return_sequences=True),\n", + " TKAN(100, return_sequences=True),\n", " TKAN(100, sub_kan_output_dim = 20, sub_kan_input_dim = 20, return_sequences=False),\n", " Dense(units=n_ahead, activation='linear')\n", " ], name = model_id)\n", @@ -5712,7 +2960,7 @@ " raise ValueError\n", " \n", " optimizer = keras.optimizers.Adam(0.001)\n", - " model.compile(optimizer=optimizer, loss='mean_squared_error')\n", + " model.compile(optimizer=optimizer, loss='mean_squared_error', jit_compile=True)\n", " if run==0:\n", " model.summary()\n", " \n", @@ -5732,26 +2980,18 @@ " del model\n", " del optimizer\n", " \n", - " \n", - " print('R2 scores')\n", - " print('Means:')\n", - " display(pd.DataFrame({model_id: {n_ahead: np.mean(results[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in results.keys()}))\n", - " display(pd.DataFrame({model_id: {n_ahead: np.mean(results_rmse[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in results_rmse.keys()}))\n", - " print('Std:')\n", - " display(pd.DataFrame({model_id: {n_ahead: np.std(results[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in results.keys()}))\n", - " display(pd.DataFrame({model_id: {n_ahead: np.std(results_rmse[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in results_rmse.keys()}))\n", - " print('Training Times')\n", - " display(pd.DataFrame({model_id: {n_ahead: np.mean(time_results[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in time_results.keys()}))\n", - " display(pd.DataFrame({model_id: {n_ahead: np.std(time_results[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in time_results.keys()}))" + "\n", + "print('R2 scores')\n", + "print('Means:')\n", + "display(pd.DataFrame({model_id: {n_ahead: np.mean(results[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in results.keys()}))\n", + "display(pd.DataFrame({model_id: {n_ahead: np.mean(results_rmse[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in results_rmse.keys()}))\n", + "print('Std:')\n", + "display(pd.DataFrame({model_id: {n_ahead: np.std(results[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in results.keys()}))\n", + "display(pd.DataFrame({model_id: {n_ahead: np.std(results_rmse[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in results_rmse.keys()}))\n", + "print('Training Times')\n", + "display(pd.DataFrame({model_id: {n_ahead: np.mean(time_results[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in time_results.keys()}))\n", + "display(pd.DataFrame({model_id: {n_ahead: np.std(time_results[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in time_results.keys()}))" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8ffccbbf-8ca4-4166-97bc-24c0f5ca02d0", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -5770,7 +3010,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.2" + "version": "3.11.9" } }, "nbformat": 4,