diff --git a/README.md b/README.md
index 103df83..9189f4d 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,8 @@ The implementation is tested to be compatatible with Tensorflow, Jax and Torch.
It is the original implementation of the [paper](https://arxiv.org/abs/2405.07344)
The KAN part implementation has been inspired from [efficient_kan](https://github.com/Blealtan/efficient-kan), and is available [here](https://github.com/remigenet/keras_efficient_kan) and works similarly to it, thus not exactly like the [original implementation](https://github.com/KindXiaoming/pykan).
+In case of performance consideration, the best setup tested used [jax docker image](https://hub.docker.com/r/bitnami/jax/) followed by installing jax using ```pip install "jax[cuda12]"```, this is what is used in the example section where you can compare the TKAN vs LSTM vs GRU time and performance.
+I also discourage using as is the example for torch, it seems that currently when running test using torch backend with keras is much slower than torch directly, even for GRU or LSTM.
![TKAN representation](image/TKAN.drawio.png)
diff --git a/examples/simple_example_tkan.ipynb b/examples/simple_example_tkan.ipynb
index b75576d..7ccb161 100644
--- a/examples/simple_example_tkan.ipynb
+++ b/examples/simple_example_tkan.ipynb
@@ -5,113 +5,88 @@
"id": "bb186818-bd1d-46ed-a018-27efa013b206",
"metadata": {},
"source": [
- "# "
+ "# TKAN example and comparison with benchmarks\n",
+ "\n",
+ "All test have been run on a RTX 4070 with an Core™ i7-6700K on vast.ai using this [jax docker image](https://hub.docker.com/r/bitnami/jax/)\n",
+ "\n",
+ "tkan version: 0.4.1"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "bc3f1ac2-1785-4e08-89a5-0bc5e31ce2b7",
- "metadata": {},
+ "metadata": {
+ "scrolled": true
+ },
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (2.2.2)\n",
- "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (1.26.4)\n",
- "Requirement already satisfied: matplotlib in /usr/local/lib/python3.11/dist-packages (3.9.1)\n",
- "Requirement already satisfied: pyarrow in /usr/local/lib/python3.11/dist-packages (17.0.0)\n",
- "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.11/dist-packages (1.5.1)\n",
- "Requirement already satisfied: tkan in /usr/local/lib/python3.11/dist-packages (0.4.0)\n",
- "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas) (2.9.0.post0)\n",
- "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas) (2024.1)\n",
- "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas) (2024.1)\n",
- "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib) (1.2.1)\n",
- "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib) (0.12.1)\n",
- "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib) (4.53.1)\n",
- "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib) (1.4.5)\n",
- "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib) (24.1)\n",
- "Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib) (10.4.0)\n",
- "Requirement already satisfied: pyparsing>=2.3.1 in /usr/lib/python3/dist-packages (from matplotlib) (2.4.7)\n",
- "Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (1.14.0)\n",
- "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (1.4.2)\n",
- "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (3.5.0)\n",
- "Requirement already satisfied: keras<4.0,>=3.0.0 in /usr/local/lib/python3.11/dist-packages (from tkan) (3.4.1)\n",
- "Requirement already satisfied: keras_efficient_kan<0.2.0,>=0.1.4 in /usr/local/lib/python3.11/dist-packages (from tkan) (0.1.4)\n",
- "Requirement already satisfied: absl-py in /usr/local/lib/python3.11/dist-packages (from keras<4.0,>=3.0.0->tkan) (2.1.0)\n",
- "Requirement already satisfied: rich in /usr/local/lib/python3.11/dist-packages (from keras<4.0,>=3.0.0->tkan) (13.7.1)\n",
- "Requirement already satisfied: namex in /usr/local/lib/python3.11/dist-packages (from keras<4.0,>=3.0.0->tkan) (0.0.8)\n",
- "Requirement already satisfied: h5py in /usr/local/lib/python3.11/dist-packages (from keras<4.0,>=3.0.0->tkan) (3.11.0)\n",
- "Requirement already satisfied: optree in /usr/local/lib/python3.11/dist-packages (from keras<4.0,>=3.0.0->tkan) (0.12.1)\n",
- "Requirement already satisfied: ml-dtypes in /usr/local/lib/python3.11/dist-packages (from keras<4.0,>=3.0.0->tkan) (0.4.0)\n",
- "Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
- "Requirement already satisfied: typing-extensions>=4.5.0 in /usr/local/lib/python3.11/dist-packages (from optree->keras<4.0,>=3.0.0->tkan) (4.12.2)\n",
- "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.11/dist-packages (from rich->keras<4.0,>=3.0.0->tkan) (3.0.0)\n",
- "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.11/dist-packages (from rich->keras<4.0,>=3.0.0->tkan) (2.18.0)\n",
- "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.11/dist-packages (from markdown-it-py>=2.2.0->rich->keras<4.0,>=3.0.0->tkan) (0.1.2)\n",
+ "\u001b[33mDEPRECATION: Loading egg at /opt/bitnami/python/lib/python3.11/site-packages/pip-23.3.2-py3.11.egg is deprecated. pip 24.3 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n",
+ "\u001b[0mRequirement already satisfied: pandas in /opt/bitnami/python/lib/python3.11/site-packages (2.2.2)\n",
+ "Requirement already satisfied: numpy in /opt/bitnami/python/lib/python3.11/site-packages (1.26.4)\n",
+ "Requirement already satisfied: matplotlib in /opt/bitnami/python/lib/python3.11/site-packages (3.9.1)\n",
+ "Requirement already satisfied: pyarrow in /opt/bitnami/python/lib/python3.11/site-packages (17.0.0)\n",
+ "Requirement already satisfied: scikit-learn in /opt/bitnami/python/lib/python3.11/site-packages (1.5.1)\n",
+ "Requirement already satisfied: tkan in /opt/bitnami/python/lib/python3.11/site-packages (0.4.1)\n",
+ "Requirement already satisfied: jax[cuda12] in /opt/bitnami/python/lib/python3.11/site-packages (0.4.30.dev20240618+f4158ac)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /opt/bitnami/python/lib/python3.11/site-packages (from pandas) (2.9.0.post0)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /opt/bitnami/python/lib/python3.11/site-packages (from pandas) (2024.1)\n",
+ "Requirement already satisfied: tzdata>=2022.7 in /opt/bitnami/python/lib/python3.11/site-packages (from pandas) (2024.1)\n",
+ "Requirement already satisfied: contourpy>=1.0.1 in /opt/bitnami/python/lib/python3.11/site-packages (from matplotlib) (1.2.1)\n",
+ "Requirement already satisfied: cycler>=0.10 in /opt/bitnami/python/lib/python3.11/site-packages (from matplotlib) (0.12.1)\n",
+ "Requirement already satisfied: fonttools>=4.22.0 in /opt/bitnami/python/lib/python3.11/site-packages (from matplotlib) (4.53.1)\n",
+ "Requirement already satisfied: kiwisolver>=1.3.1 in /opt/bitnami/python/lib/python3.11/site-packages (from matplotlib) (1.4.5)\n",
+ "Requirement already satisfied: packaging>=20.0 in /opt/bitnami/python/lib/python3.11/site-packages (from matplotlib) (24.1)\n",
+ "Requirement already satisfied: pillow>=8 in /opt/bitnami/python/lib/python3.11/site-packages (from matplotlib) (10.4.0)\n",
+ "Requirement already satisfied: pyparsing>=2.3.1 in /opt/bitnami/python/lib/python3.11/site-packages (from matplotlib) (3.1.2)\n",
+ "Requirement already satisfied: scipy>=1.6.0 in /opt/bitnami/python/lib/python3.11/site-packages (from scikit-learn) (1.14.0)\n",
+ "Requirement already satisfied: joblib>=1.2.0 in /opt/bitnami/python/lib/python3.11/site-packages (from scikit-learn) (1.4.2)\n",
+ "Requirement already satisfied: threadpoolctl>=3.1.0 in /opt/bitnami/python/lib/python3.11/site-packages (from scikit-learn) (3.5.0)\n",
+ "Requirement already satisfied: keras<4.0,>=3.0.0 in /opt/bitnami/python/lib/python3.11/site-packages (from tkan) (3.4.1)\n",
+ "Requirement already satisfied: keras_efficient_kan<0.2.0,>=0.1.4 in /opt/bitnami/python/lib/python3.11/site-packages (from tkan) (0.1.4)\n",
+ "Requirement already satisfied: jaxlib<=0.4.30,>=0.4.27 in /opt/bitnami/python/lib/python3.11/site-packages (from jax[cuda12]) (0.4.30)\n",
+ "Requirement already satisfied: ml-dtypes>=0.2.0 in /opt/bitnami/python/lib/python3.11/site-packages (from jax[cuda12]) (0.4.0)\n",
+ "Requirement already satisfied: opt-einsum in /opt/bitnami/python/lib/python3.11/site-packages (from jax[cuda12]) (3.3.0)\n",
+ "Requirement already satisfied: jax-cuda12-plugin<=0.4.30,>=0.4.30 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (0.4.30)\n",
+ "Requirement already satisfied: jax-cuda12-pjrt==0.4.30 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin<=0.4.30,>=0.4.30->jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (0.4.30)\n",
+ "Requirement already satisfied: nvidia-cublas-cu12>=12.1.3.1 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.3.2)\n",
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12>=12.1.105 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.82)\n",
+ "Requirement already satisfied: nvidia-cuda-nvcc-cu12>=12.1.105 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.82)\n",
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12>=12.1.105 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.82)\n",
+ "Requirement already satisfied: nvidia-cudnn-cu12<10.0,>=9.0 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (9.2.1.18)\n",
+ "Requirement already satisfied: nvidia-cufft-cu12>=11.0.2.54 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (11.2.3.61)\n",
+ "Requirement already satisfied: nvidia-cusolver-cu12>=11.4.5.107 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (11.6.3.83)\n",
+ "Requirement already satisfied: nvidia-cusparse-cu12>=12.1.0.106 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.1.3)\n",
+ "Requirement already satisfied: nvidia-nccl-cu12>=2.18.1 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (2.22.3)\n",
+ "Requirement already satisfied: nvidia-nvjitlink-cu12>=12.1.105 in /opt/bitnami/python/lib/python3.11/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.82)\n",
+ "Requirement already satisfied: absl-py in /opt/bitnami/python/lib/python3.11/site-packages (from keras<4.0,>=3.0.0->tkan) (2.1.0)\n",
+ "Requirement already satisfied: rich in /opt/bitnami/python/lib/python3.11/site-packages (from keras<4.0,>=3.0.0->tkan) (13.7.1)\n",
+ "Requirement already satisfied: namex in /opt/bitnami/python/lib/python3.11/site-packages (from keras<4.0,>=3.0.0->tkan) (0.0.8)\n",
+ "Requirement already satisfied: h5py in /opt/bitnami/python/lib/python3.11/site-packages (from keras<4.0,>=3.0.0->tkan) (3.11.0)\n",
+ "Requirement already satisfied: optree in /opt/bitnami/python/lib/python3.11/site-packages (from keras<4.0,>=3.0.0->tkan) (0.12.1)\n",
+ "Requirement already satisfied: six>=1.5 in /opt/bitnami/python/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
+ "Requirement already satisfied: typing-extensions>=4.5.0 in /opt/bitnami/python/lib/python3.11/site-packages (from optree->keras<4.0,>=3.0.0->tkan) (4.12.2)\n",
+ "Requirement already satisfied: markdown-it-py>=2.2.0 in /opt/bitnami/python/lib/python3.11/site-packages (from rich->keras<4.0,>=3.0.0->tkan) (3.0.0)\n",
+ "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /opt/bitnami/python/lib/python3.11/site-packages (from rich->keras<4.0,>=3.0.0->tkan) (2.18.0)\n",
+ "Requirement already satisfied: mdurl~=0.1 in /opt/bitnami/python/lib/python3.11/site-packages (from markdown-it-py>=2.2.0->rich->keras<4.0,>=3.0.0->tkan) (0.1.2)\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
"\u001b[0m"
]
}
],
"source": [
- "!pip install pandas numpy matplotlib pyarrow scikit-learn tkan "
+ "!pip install pandas numpy matplotlib pyarrow scikit-learn tkan \"jax[cuda12]\""
]
},
{
"cell_type": "code",
"execution_count": 2,
- "id": "4bf2cb3c-56e6-49a6-91b8-47e616b850dc",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Requirement already satisfied: jax[cuda12] in /usr/local/lib/python3.11/dist-packages (0.4.30)\n",
- "Requirement already satisfied: jaxlib<=0.4.30,>=0.4.27 in /usr/local/lib/python3.11/dist-packages (from jax[cuda12]) (0.4.30)\n",
- "Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from jax[cuda12]) (0.4.0)\n",
- "Requirement already satisfied: numpy>=1.22 in /usr/local/lib/python3.11/dist-packages (from jax[cuda12]) (1.26.4)\n",
- "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.11/dist-packages (from jax[cuda12]) (3.3.0)\n",
- "Requirement already satisfied: scipy>=1.9 in /usr/local/lib/python3.11/dist-packages (from jax[cuda12]) (1.14.0)\n",
- "Requirement already satisfied: jax-cuda12-plugin<=0.4.30,>=0.4.30 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (0.4.30)\n",
- "Requirement already satisfied: jax-cuda12-pjrt==0.4.30 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin<=0.4.30,>=0.4.30->jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (0.4.30)\n",
- "Requirement already satisfied: nvidia-cublas-cu12>=12.1.3.1 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.3.2)\n",
- "Requirement already satisfied: nvidia-cuda-cupti-cu12>=12.1.105 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.82)\n",
- "Requirement already satisfied: nvidia-cuda-nvcc-cu12>=12.1.105 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.82)\n",
- "Requirement already satisfied: nvidia-cuda-runtime-cu12>=12.1.105 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.82)\n",
- "Requirement already satisfied: nvidia-cudnn-cu12<10.0,>=9.0 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (9.2.1.18)\n",
- "Requirement already satisfied: nvidia-cufft-cu12>=11.0.2.54 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (11.2.3.61)\n",
- "Requirement already satisfied: nvidia-cusolver-cu12>=11.4.5.107 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (11.6.3.83)\n",
- "Requirement already satisfied: nvidia-cusparse-cu12>=12.1.0.106 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.1.3)\n",
- "Requirement already satisfied: nvidia-nccl-cu12>=2.18.1 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (2.22.3)\n",
- "Requirement already satisfied: nvidia-nvjitlink-cu12>=12.1.105 in /usr/local/lib/python3.11/dist-packages (from jax-cuda12-plugin[with_cuda]<=0.4.30,>=0.4.30; extra == \"cuda12\"->jax[cuda12]) (12.5.82)\n",
- "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
- "\u001b[0m"
- ]
- }
- ],
- "source": [
- "!pip install -U \"jax[cuda12]\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
"id": "34213122-fabb-4d55-a918-a337be21b974",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "2024-07-23 18:15:27.842138: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
- "2024-07-23 18:15:27.861965: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
- "2024-07-23 18:15:27.867851: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"import os\n",
"BACKEND = 'jax' # You can use any backend here \n",
@@ -166,7 +141,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 3,
"id": "735c5c88-924a-426e-b10f-a05e7b0556f2",
"metadata": {},
"outputs": [
@@ -571,7 +546,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 4,
"id": "34c9f03b-11bf-4c88-abcb-5fc8e43c6008",
"metadata": {},
"outputs": [],
@@ -694,17 +669,10 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 5,
"id": "27b4f64a-08da-4cfc-97c9-1d5145887c23",
"metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "2024-07-23 18:15:35.115393: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.\n"
- ]
- },
{
"data": {
"text/html": [
@@ -724,11 +692,11 @@
"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ tkan (TKAN) │ (None, 75, 100) │ 41,750 │\n",
+ "│ tkan (TKAN) │ (None, 45, 100) │ 41,316 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ tkan_1 (TKAN) │ (None, 100) │ 67,670 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense (Dense) │ (None, 15) │ 1,515 │\n",
+ "│ dense (Dense) │ (None, 1) │ 101 │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"
\n"
],
@@ -736,11 +704,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ tkan (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m75\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,750\u001b[0m │\n",
+ "│ tkan (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,316\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ tkan_1 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m15\u001b[0m) │ \u001b[38;5;34m1,515\u001b[0m │\n",
+ "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m101\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
@@ -750,11 +718,11 @@
{
"data": {
"text/html": [
- " Total params: 110,935 (433.34 KB)\n",
+ " Total params: 109,087 (426.12 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m110,935\u001b[0m (433.34 KB)\n"
+ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m109,087\u001b[0m (426.12 KB)\n"
]
},
"metadata": {},
@@ -763,11 +731,11 @@
{
"data": {
"text/html": [
- " Trainable params: 110,915 (433.26 KB)\n",
+ " Trainable params: 109,067 (426.04 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m110,915\u001b[0m (433.26 KB)\n"
+ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m109,067\u001b[0m (426.04 KB)\n"
]
},
"metadata": {},
@@ -790,11 +758,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "124.26042032241821 0.09158093535460268\n",
- "162.43516397476196 0.10293650922567563\n",
- "165.25300359725952 0.09519336968632934\n",
- "146.79839873313904 0.08672428978473426\n",
- "145.06835103034973 0.09001762052527772\n"
+ "87.7924542427063 0.2883694212812108\n",
+ "58.81150460243225 0.2996397661396135\n",
+ "93.38526654243469 0.3341084410571994\n",
+ "69.46869540214539 0.31051983221133617\n",
+ "68.03620338439941 0.2921106524335578\n",
+ "69.86754989624023 0.31632189994316484\n",
+ "83.32379245758057 0.3110331089568017\n",
+ "87.04948496818542 0.32794331666897736\n",
+ "84.89363670349121 0.31305216100227273\n",
+ "65.98418259620667 0.30768323851050017\n"
]
},
{
@@ -816,11 +789,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ gru (GRU) │ (None, 75, 100) │ 36,300 │\n",
+ "│ gru (GRU) │ (None, 45, 100) │ 36,300 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ gru_1 (GRU) │ (None, 100) │ 60,600 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_5 (Dense) │ (None, 15) │ 1,515 │\n",
+ "│ dense_10 (Dense) │ (None, 1) │ 101 │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"
\n"
],
@@ -828,11 +801,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ gru (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m75\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n",
+ "│ gru (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ gru_1 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_5 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m15\u001b[0m) │ \u001b[38;5;34m1,515\u001b[0m │\n",
+ "│ dense_10 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m101\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
@@ -842,11 +815,11 @@
{
"data": {
"text/html": [
- " Total params: 98,415 (384.43 KB)\n",
+ " Total params: 97,001 (378.91 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m98,415\u001b[0m (384.43 KB)\n"
+ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m97,001\u001b[0m (378.91 KB)\n"
]
},
"metadata": {},
@@ -855,11 +828,11 @@
{
"data": {
"text/html": [
- " Trainable params: 98,415 (384.43 KB)\n",
+ " Trainable params: 97,001 (378.91 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m98,415\u001b[0m (384.43 KB)\n"
+ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m97,001\u001b[0m (378.91 KB)\n"
]
},
"metadata": {},
@@ -882,11 +855,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "40.458553075790405 0.09533817072522596\n",
- "41.53169059753418 0.06848331293892884\n",
- "41.19035625457764 0.08615624519190934\n",
- "47.83834481239319 0.05216959504533443\n",
- "41.37253522872925 0.09000230869662042\n"
+ "25.800300121307373 0.39845659417931767\n",
+ "33.090811014175415 0.3765754299721896\n",
+ "29.157642126083374 0.3996904390926539\n",
+ "20.830080032348633 0.39383111789434266\n",
+ "24.594220638275146 0.39854787088819454\n",
+ "25.829734086990356 0.3980608345819734\n",
+ "20.8519549369812 0.39818872394114213\n",
+ "29.55962371826172 0.39876597794912183\n",
+ "28.695900917053223 0.4014753210019817\n",
+ "31.473294496536255 0.385033271890177\n"
]
},
{
@@ -908,11 +886,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ lstm (LSTM) │ (None, 75, 100) │ 48,000 │\n",
+ "│ lstm (LSTM) │ (None, 45, 100) │ 48,000 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ lstm_1 (LSTM) │ (None, 100) │ 80,400 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_10 (Dense) │ (None, 15) │ 1,515 │\n",
+ "│ dense_20 (Dense) │ (None, 1) │ 101 │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"
\n"
],
@@ -920,11 +898,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ lstm (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m75\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n",
+ "│ lstm (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
"│ lstm_1 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_10 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m15\u001b[0m) │ \u001b[38;5;34m1,515\u001b[0m │\n",
+ "│ dense_20 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m101\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
@@ -934,11 +912,11 @@
{
"data": {
"text/html": [
- " Total params: 129,915 (507.48 KB)\n",
+ " Total params: 128,501 (501.96 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m129,915\u001b[0m (507.48 KB)\n"
+ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m128,501\u001b[0m (501.96 KB)\n"
]
},
"metadata": {},
@@ -947,11 +925,11 @@
{
"data": {
"text/html": [
- " Trainable params: 129,915 (507.48 KB)\n",
+ " Trainable params: 128,501 (501.96 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m129,915\u001b[0m (507.48 KB)\n"
+ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m128,501\u001b[0m (501.96 KB)\n"
]
},
"metadata": {},
@@ -974,583 +952,220 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "34.312363147735596 -0.19760337708902084\n",
- "31.26186227798462 -0.17731950497927035\n",
- "34.37222385406494 -0.37332071192906224\n",
- "35.09816241264343 -0.08747675367134147\n",
- "33.26098704338074 -0.3229705947886879\n",
- "R2 scores\n",
- "Means:\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n",
- " return _methods._mean(a, axis=axis, dtype=dtype,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
+ "17.663942098617554 0.3868841900601451\n",
+ "18.93087387084961 0.36429593525523807\n",
+ "17.406195163726807 0.3900695412878702\n",
+ "20.44808530807495 0.3786097961170115\n",
+ "16.755755186080933 0.392573888221295\n",
+ "20.213451862335205 0.3840845450430137\n",
+ "22.768882513046265 0.36630601832303555\n",
+ "18.173134088516235 0.3832138605836025\n",
+ "17.591335773468018 0.3918248620808171\n",
+ "19.007630586624146 0.37993474576853103\n"
]
},
{
"data": {
"text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 0.093291 | \n",
- " 0.07843 | \n",
- " -0.231738 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
+ "Model: \"TKAN\"\n",
+ "
\n"
],
"text/plain": [
- " TKAN GRU LSTM\n",
- "15 0.093291 0.07843 -0.231738\n",
- "12 NaN NaN NaN\n",
- "9 NaN NaN NaN\n",
- "6 NaN NaN NaN\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
+ "\u001b[1mModel: \"TKAN\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n",
- " return _methods._mean(a, axis=axis, dtype=dtype,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
- },
{
"data": {
"text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 0.067705 | \n",
- " 0.068194 | \n",
- " 0.078686 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
+ "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
+ "│ tkan_20 (TKAN) │ (None, 45, 100) │ 41,316 │\n",
+ "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+ "│ tkan_21 (TKAN) │ (None, 100) │ 67,670 │\n",
+ "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+ "│ dense_30 (Dense) │ (None, 3) │ 303 │\n",
+ "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
+ "
\n"
],
"text/plain": [
- " TKAN GRU LSTM\n",
- "15 0.067705 0.068194 0.078686\n",
- "12 NaN NaN NaN\n",
- "9 NaN NaN NaN\n",
- "6 NaN NaN NaN\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
+ "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
+ "│ tkan_20 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,316\u001b[0m │\n",
+ "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+ "│ tkan_21 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n",
+ "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+ "│ dense_30 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m303\u001b[0m │\n",
+ "└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Std:\n"
- ]
+ "data": {
+ "text/html": [
+ " Total params: 109,289 (426.91 KB)\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m109,289\u001b[0m (426.91 KB)\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
},
{
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n",
- " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n",
- " arrmean = um.true_divide(arrmean, div, out=arrmean,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
+ "data": {
+ "text/html": [
+ " Trainable params: 109,269 (426.83 KB)\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m109,269\u001b[0m (426.83 KB)\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
},
{
"data": {
"text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 0.005539 | \n",
- " 0.015925 | \n",
- " 0.103254 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
+ " Non-trainable params: 20 (80.00 B)\n",
+ "
\n"
],
"text/plain": [
- " TKAN GRU LSTM\n",
- "15 0.005539 0.015925 0.103254\n",
- "12 NaN NaN NaN\n",
- "9 NaN NaN NaN\n",
- "6 NaN NaN NaN\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
+ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m20\u001b[0m (80.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
- "name": "stderr",
+ "name": "stdout",
"output_type": "stream",
"text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n",
- " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n",
- " arrmean = um.true_divide(arrmean, div, out=arrmean,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
+ "96.62368559837341 0.2039716435474497\n",
+ "74.19592547416687 0.1932276162314687\n",
+ "79.30283999443054 0.186078923268888\n",
+ "67.78045701980591 0.1888747778655048\n",
+ "90.85375475883484 0.20038166436812602\n",
+ "83.50374150276184 0.198828800729096\n",
+ "79.80795526504517 0.1848691783440326\n",
+ "59.527796268463135 0.18538958325425756\n",
+ "73.05465388298035 0.17813436933288326\n",
+ "77.17678189277649 0.185736715040992\n"
]
},
{
"data": {
"text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 0.000215 | \n",
- " 0.000578 | \n",
- " 0.003284 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
+ "Model: \"GRU\"\n",
+ "
\n"
],
"text/plain": [
- " TKAN GRU LSTM\n",
- "15 0.000215 0.000578 0.003284\n",
- "12 NaN NaN NaN\n",
- "9 NaN NaN NaN\n",
- "6 NaN NaN NaN\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
+ "\u001b[1mModel: \"GRU\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Training Times\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n",
- " return _methods._mean(a, axis=axis, dtype=dtype,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
+ "data": {
+ "text/html": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
+ "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
+ "│ gru_20 (GRU) │ (None, 45, 100) │ 36,300 │\n",
+ "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+ "│ gru_21 (GRU) │ (None, 100) │ 60,600 │\n",
+ "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+ "│ dense_40 (Dense) │ (None, 3) │ 303 │\n",
+ "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
+ "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
+ "│ gru_20 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n",
+ "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+ "│ gru_21 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n",
+ "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+ "│ dense_40 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m303\u001b[0m │\n",
+ "└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
},
{
"data": {
"text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 148.763068 | \n",
- " 42.478296 | \n",
- " 33.66112 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
+ " Total params: 97,203 (379.70 KB)\n",
+ "
\n"
],
"text/plain": [
- " TKAN GRU LSTM\n",
- "15 148.763068 42.478296 33.66112\n",
- "12 NaN NaN NaN\n",
- "9 NaN NaN NaN\n",
- "6 NaN NaN NaN\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
+ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m97,203\u001b[0m (379.70 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n",
- " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n",
- " arrmean = um.true_divide(arrmean, div, out=arrmean,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
+ "data": {
+ "text/html": [
+ " Trainable params: 97,203 (379.70 KB)\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m97,203\u001b[0m (379.70 KB)\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
},
{
"data": {
"text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 14.674705 | \n",
- " 2.705071 | \n",
- " 1.335022 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
+ " Non-trainable params: 0 (0.00 B)\n",
+ "
\n"
],
"text/plain": [
- " TKAN GRU LSTM\n",
- "15 14.674705 2.705071 1.335022\n",
- "12 NaN NaN NaN\n",
- "9 NaN NaN NaN\n",
- "6 NaN NaN NaN\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
+ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "32.23410701751709 0.24445524304420677\n",
+ "34.76943516731262 0.2201955661466759\n",
+ "25.166553735733032 0.23882125234574236\n",
+ "28.084814310073853 0.2344194906558783\n",
+ "27.79819416999817 0.235279604982632\n",
+ "31.362367868423462 0.23122380522590277\n",
+ "24.73177719116211 0.2426694875126042\n",
+ "26.578373432159424 0.23963082071966313\n",
+ "27.062561988830566 0.24032075216240234\n",
+ "28.228745222091675 0.24178121308010916\n"
+ ]
+ },
{
"data": {
"text/html": [
- "Model: \"TKAN\"\n",
+ "Model: \"LSTM\"\n",
"
\n"
],
"text/plain": [
- "\u001b[1mModel: \"TKAN\"\u001b[0m\n"
+ "\u001b[1mModel: \"LSTM\"\u001b[0m\n"
]
},
"metadata": {},
@@ -1562,11 +1177,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ tkan_10 (TKAN) │ (None, 60, 100) │ 41,750 │\n",
+ "│ lstm_20 (LSTM) │ (None, 45, 100) │ 48,000 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ tkan_11 (TKAN) │ (None, 100) │ 67,670 │\n",
+ "│ lstm_21 (LSTM) │ (None, 100) │ 80,400 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_15 (Dense) │ (None, 12) │ 1,212 │\n",
+ "│ dense_50 (Dense) │ (None, 3) │ 303 │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"
\n"
],
@@ -1574,11 +1189,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ tkan_10 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m60\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,750\u001b[0m │\n",
+ "│ lstm_20 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ tkan_11 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n",
+ "│ lstm_21 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_15 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m1,212\u001b[0m │\n",
+ "│ dense_50 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m303\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
@@ -1588,11 +1203,11 @@
{
"data": {
"text/html": [
- " Total params: 110,632 (432.16 KB)\n",
+ " Total params: 128,703 (502.75 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m110,632\u001b[0m (432.16 KB)\n"
+ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m128,703\u001b[0m (502.75 KB)\n"
]
},
"metadata": {},
@@ -1601,11 +1216,11 @@
{
"data": {
"text/html": [
- " Trainable params: 110,612 (432.08 KB)\n",
+ " Trainable params: 128,703 (502.75 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m110,612\u001b[0m (432.08 KB)\n"
+ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m128,703\u001b[0m (502.75 KB)\n"
]
},
"metadata": {},
@@ -1614,11 +1229,11 @@
{
"data": {
"text/html": [
- " Non-trainable params: 20 (80.00 B)\n",
+ " Non-trainable params: 0 (0.00 B)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m20\u001b[0m (80.00 B)\n"
+ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
@@ -1628,21 +1243,26 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "131.22357177734375 0.09087923067578507\n",
- "141.15374612808228 0.10076559080540619\n",
- "121.34675312042236 0.09357512681452108\n",
- "129.00918769836426 0.10084381545148026\n",
- "135.5328950881958 0.09396223859149976\n"
+ "20.623400688171387 0.14482767971630928\n",
+ "21.055927991867065 0.07395821838488396\n",
+ "24.374985218048096 0.022805938949401933\n",
+ "21.713353157043457 0.0749404329796987\n",
+ "21.184755563735962 0.1522887184364525\n",
+ "18.015154600143433 0.16326070883053923\n",
+ "18.17797017097473 0.1804372350265144\n",
+ "17.993833541870117 0.18266789728362973\n",
+ "18.728025197982788 0.14342017896777437\n",
+ "20.97398853302002 -0.002893617498768264\n"
]
},
{
"data": {
"text/html": [
- "Model: \"GRU\"\n",
+ "Model: \"TKAN\"\n",
"
\n"
],
"text/plain": [
- "\u001b[1mModel: \"GRU\"\u001b[0m\n"
+ "\u001b[1mModel: \"TKAN\"\u001b[0m\n"
]
},
"metadata": {},
@@ -1654,11 +1274,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ gru_10 (GRU) │ (None, 60, 100) │ 36,300 │\n",
+ "│ tkan_40 (TKAN) │ (None, 45, 100) │ 41,316 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ gru_11 (GRU) │ (None, 100) │ 60,600 │\n",
+ "│ tkan_41 (TKAN) │ (None, 100) │ 67,670 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_20 (Dense) │ (None, 12) │ 1,212 │\n",
+ "│ dense_60 (Dense) │ (None, 6) │ 606 │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"
\n"
],
@@ -1666,11 +1286,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ gru_10 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m60\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n",
+ "│ tkan_40 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,316\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ gru_11 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n",
+ "│ tkan_41 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_20 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m1,212\u001b[0m │\n",
+ "│ dense_60 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m6\u001b[0m) │ \u001b[38;5;34m606\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
@@ -1680,11 +1300,11 @@
{
"data": {
"text/html": [
- " Total params: 98,112 (383.25 KB)\n",
+ " Total params: 109,592 (428.09 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m98,112\u001b[0m (383.25 KB)\n"
+ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m109,592\u001b[0m (428.09 KB)\n"
]
},
"metadata": {},
@@ -1693,11 +1313,11 @@
{
"data": {
"text/html": [
- " Trainable params: 98,112 (383.25 KB)\n",
+ " Trainable params: 109,572 (428.02 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m98,112\u001b[0m (383.25 KB)\n"
+ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m109,572\u001b[0m (428.02 KB)\n"
]
},
"metadata": {},
@@ -1706,11 +1326,11 @@
{
"data": {
"text/html": [
- " Non-trainable params: 0 (0.00 B)\n",
+ " Non-trainable params: 20 (80.00 B)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
+ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m20\u001b[0m (80.00 B)\n"
]
},
"metadata": {},
@@ -1720,21 +1340,26 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "38.70786952972412 0.08947185062429679\n",
- "35.21390247344971 0.10739726688893424\n",
- "48.478371143341064 -0.05339822582392689\n",
- "40.725359201431274 0.06710397964579173\n",
- "43.54443335533142 0.04774073080153538\n"
+ "65.0717043876648 0.12467616278770983\n",
+ "90.68930077552795 0.13929252527937125\n",
+ "76.80289697647095 0.12103943728777532\n",
+ "84.18350672721863 0.13730972361175564\n",
+ "59.469563484191895 0.10800808332018592\n",
+ "80.92852973937988 0.13700481460829453\n",
+ "85.58175778388977 0.1324956475675861\n",
+ "78.75041937828064 0.13073665118121205\n",
+ "80.78611373901367 0.12963646594644282\n",
+ "73.53622269630432 0.12723135400013533\n"
]
},
{
"data": {
"text/html": [
- "Model: \"LSTM\"\n",
+ "Model: \"GRU\"\n",
"
\n"
],
"text/plain": [
- "\u001b[1mModel: \"LSTM\"\u001b[0m\n"
+ "\u001b[1mModel: \"GRU\"\u001b[0m\n"
]
},
"metadata": {},
@@ -1746,11 +1371,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ lstm_10 (LSTM) │ (None, 60, 100) │ 48,000 │\n",
+ "│ gru_40 (GRU) │ (None, 45, 100) │ 36,300 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ lstm_11 (LSTM) │ (None, 100) │ 80,400 │\n",
+ "│ gru_41 (GRU) │ (None, 100) │ 60,600 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_25 (Dense) │ (None, 12) │ 1,212 │\n",
+ "│ dense_70 (Dense) │ (None, 6) │ 606 │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"
\n"
],
@@ -1758,11 +1383,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ lstm_10 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m60\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n",
+ "│ gru_40 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ lstm_11 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n",
+ "│ gru_41 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_25 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m1,212\u001b[0m │\n",
+ "│ dense_70 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m6\u001b[0m) │ \u001b[38;5;34m606\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
@@ -1772,11 +1397,11 @@
{
"data": {
"text/html": [
- " Total params: 129,612 (506.30 KB)\n",
+ " Total params: 97,506 (380.88 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m129,612\u001b[0m (506.30 KB)\n"
+ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m97,506\u001b[0m (380.88 KB)\n"
]
},
"metadata": {},
@@ -1785,11 +1410,11 @@
{
"data": {
"text/html": [
- " Trainable params: 129,612 (506.30 KB)\n",
+ " Trainable params: 97,506 (380.88 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m129,612\u001b[0m (506.30 KB)\n"
+ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m97,506\u001b[0m (380.88 KB)\n"
]
},
"metadata": {},
@@ -1812,1595 +1437,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "28.626214027404785 -0.1006527298304624\n",
- "35.30322313308716 -0.24341912522819487\n",
- "27.737242698669434 -0.5102580838684908\n",
- "32.53866934776306 -0.7130992597534002\n",
- "34.27140927314758 -0.24619645358326023\n",
- "R2 scores\n",
- "Means:\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n",
- " return _methods._mean(a, axis=axis, dtype=dtype,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 0.093291 | \n",
- " 0.078430 | \n",
- " -0.231738 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 0.096005 | \n",
- " 0.051663 | \n",
- " -0.362725 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " TKAN GRU LSTM\n",
- "15 0.093291 0.078430 -0.231738\n",
- "12 0.096005 0.051663 -0.362725\n",
- "9 NaN NaN NaN\n",
- "6 NaN NaN NaN\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n",
- " return _methods._mean(a, axis=axis, dtype=dtype,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 0.067705 | \n",
- " 0.068194 | \n",
- " 0.078686 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 0.067321 | \n",
- " 0.068831 | \n",
- " 0.082159 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " TKAN GRU LSTM\n",
- "15 0.067705 0.068194 0.078686\n",
- "12 0.067321 0.068831 0.082159\n",
- "9 NaN NaN NaN\n",
- "6 NaN NaN NaN\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Std:\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n",
- " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n",
- " arrmean = um.true_divide(arrmean, div, out=arrmean,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 0.005539 | \n",
- " 0.015925 | \n",
- " 0.103254 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 0.004060 | \n",
- " 0.056263 | \n",
- " 0.219555 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " TKAN GRU LSTM\n",
- "15 0.005539 0.015925 0.103254\n",
- "12 0.004060 0.056263 0.219555\n",
- "9 NaN NaN NaN\n",
- "6 NaN NaN NaN\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n",
- " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n",
- " arrmean = um.true_divide(arrmean, div, out=arrmean,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 0.000215 | \n",
- " 0.000578 | \n",
- " 0.003284 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 0.000148 | \n",
- " 0.001978 | \n",
- " 0.006539 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " TKAN GRU LSTM\n",
- "15 0.000215 0.000578 0.003284\n",
- "12 0.000148 0.001978 0.006539\n",
- "9 NaN NaN NaN\n",
- "6 NaN NaN NaN\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Training Times\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n",
- " return _methods._mean(a, axis=axis, dtype=dtype,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 148.763068 | \n",
- " 42.478296 | \n",
- " 33.661120 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 131.653231 | \n",
- " 41.333987 | \n",
- " 31.695352 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " TKAN GRU LSTM\n",
- "15 148.763068 42.478296 33.661120\n",
- "12 131.653231 41.333987 31.695352\n",
- "9 NaN NaN NaN\n",
- "6 NaN NaN NaN\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n",
- " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n",
- " arrmean = um.true_divide(arrmean, div, out=arrmean,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 14.674705 | \n",
- " 2.705071 | \n",
- " 1.335022 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 6.613783 | \n",
- " 4.486661 | \n",
- " 3.014970 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " TKAN GRU LSTM\n",
- "15 14.674705 2.705071 1.335022\n",
- "12 6.613783 4.486661 3.014970\n",
- "9 NaN NaN NaN\n",
- "6 NaN NaN NaN\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "Model: \"TKAN\"\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1mModel: \"TKAN\"\u001b[0m\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
- "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
- "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ tkan_20 (TKAN) │ (None, 45, 100) │ 41,750 │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ tkan_21 (TKAN) │ (None, 100) │ 67,670 │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_30 (Dense) │ (None, 9) │ 909 │\n",
- "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
- "
\n"
- ],
- "text/plain": [
- "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
- "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
- "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ tkan_20 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,750\u001b[0m │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ tkan_21 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_30 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m909\u001b[0m │\n",
- "└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- " Total params: 110,329 (430.97 KB)\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m110,329\u001b[0m (430.97 KB)\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- " Trainable params: 110,309 (430.89 KB)\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m110,309\u001b[0m (430.89 KB)\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- " Non-trainable params: 20 (80.00 B)\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m20\u001b[0m (80.00 B)\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "104.23471403121948 0.113891478849783\n",
- "89.38614058494568 0.10592790984288752\n",
- "106.43402981758118 0.10938055442300718\n",
- "135.6341907978058 0.11293714586336041\n",
- "125.27908515930176 0.11555530130229706\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "Model: \"GRU\"\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1mModel: \"GRU\"\u001b[0m\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
- "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
- "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ gru_20 (GRU) │ (None, 45, 100) │ 36,300 │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ gru_21 (GRU) │ (None, 100) │ 60,600 │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_35 (Dense) │ (None, 9) │ 909 │\n",
- "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
- "
\n"
- ],
- "text/plain": [
- "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
- "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
- "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ gru_20 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ gru_21 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_35 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m909\u001b[0m │\n",
- "└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- " Total params: 97,809 (382.07 KB)\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m97,809\u001b[0m (382.07 KB)\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- " Trainable params: 97,809 (382.07 KB)\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m97,809\u001b[0m (382.07 KB)\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- " Non-trainable params: 0 (0.00 B)\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "29.72066307067871 0.1249179784635558\n",
- "33.5255126953125 0.07318165013466399\n",
- "38.62114071846008 -0.023410033110204745\n",
- "35.293763875961304 0.0894183448972399\n",
- "43.56008291244507 -0.023466014859413426\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "Model: \"LSTM\"\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1mModel: \"LSTM\"\u001b[0m\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
- "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
- "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ lstm_20 (LSTM) │ (None, 45, 100) │ 48,000 │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ lstm_21 (LSTM) │ (None, 100) │ 80,400 │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_40 (Dense) │ (None, 9) │ 909 │\n",
- "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
- "
\n"
- ],
- "text/plain": [
- "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
- "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
- "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ lstm_20 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ lstm_21 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_40 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m909\u001b[0m │\n",
- "└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- " Total params: 129,309 (505.11 KB)\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m129,309\u001b[0m (505.11 KB)\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- " Trainable params: 129,309 (505.11 KB)\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m129,309\u001b[0m (505.11 KB)\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- " Non-trainable params: 0 (0.00 B)\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "25.49171733856201 -0.3036703338256579\n",
- "27.654677867889404 -0.14078671841214707\n",
- "22.398143529891968 -0.0803994153254022\n",
- "25.16021418571472 -0.06587595121316724\n",
- "23.790441751480103 -0.025310879036213083\n",
- "R2 scores\n",
- "Means:\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n",
- " return _methods._mean(a, axis=axis, dtype=dtype,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 0.093291 | \n",
- " 0.078430 | \n",
- " -0.231738 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 0.096005 | \n",
- " 0.051663 | \n",
- " -0.362725 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 0.111538 | \n",
- " 0.048128 | \n",
- " -0.123209 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " TKAN GRU LSTM\n",
- "15 0.093291 0.078430 -0.231738\n",
- "12 0.096005 0.051663 -0.362725\n",
- "9 0.111538 0.048128 -0.123209\n",
- "6 NaN NaN NaN\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n",
- " return _methods._mean(a, axis=axis, dtype=dtype,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 0.067705 | \n",
- " 0.068194 | \n",
- " 0.078686 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 0.067321 | \n",
- " 0.068831 | \n",
- " 0.082159 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 0.066717 | \n",
- " 0.068882 | \n",
- " 0.074797 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " TKAN GRU LSTM\n",
- "15 0.067705 0.068194 0.078686\n",
- "12 0.067321 0.068831 0.082159\n",
- "9 0.066717 0.068882 0.074797\n",
- "6 NaN NaN NaN\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Std:\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n",
- " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n",
- " arrmean = um.true_divide(arrmean, div, out=arrmean,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 0.005539 | \n",
- " 0.015925 | \n",
- " 0.103254 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 0.004060 | \n",
- " 0.056263 | \n",
- " 0.219555 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 0.003457 | \n",
- " 0.060783 | \n",
- " 0.097549 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " TKAN GRU LSTM\n",
- "15 0.005539 0.015925 0.103254\n",
- "12 0.004060 0.056263 0.219555\n",
- "9 0.003457 0.060783 0.097549\n",
- "6 NaN NaN NaN\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n",
- " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n",
- " arrmean = um.true_divide(arrmean, div, out=arrmean,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 0.000215 | \n",
- " 0.000578 | \n",
- " 0.003284 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 0.000148 | \n",
- " 0.001978 | \n",
- " 0.006539 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 0.000131 | \n",
- " 0.002142 | \n",
- " 0.003162 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " TKAN GRU LSTM\n",
- "15 0.000215 0.000578 0.003284\n",
- "12 0.000148 0.001978 0.006539\n",
- "9 0.000131 0.002142 0.003162\n",
- "6 NaN NaN NaN\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Training Times\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n",
- " return _methods._mean(a, axis=axis, dtype=dtype,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 148.763068 | \n",
- " 42.478296 | \n",
- " 33.661120 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 131.653231 | \n",
- " 41.333987 | \n",
- " 31.695352 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 112.193632 | \n",
- " 36.144233 | \n",
- " 24.899039 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " TKAN GRU LSTM\n",
- "15 148.763068 42.478296 33.661120\n",
- "12 131.653231 41.333987 31.695352\n",
- "9 112.193632 36.144233 24.899039\n",
- "6 NaN NaN NaN\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n",
- " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n",
- " arrmean = um.true_divide(arrmean, div, out=arrmean,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 14.674705 | \n",
- " 2.705071 | \n",
- " 1.335022 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 6.613783 | \n",
- " 4.486661 | \n",
- " 3.014970 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 16.354742 | \n",
- " 4.689843 | \n",
- " 1.760482 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " TKAN GRU LSTM\n",
- "15 14.674705 2.705071 1.335022\n",
- "12 6.613783 4.486661 3.014970\n",
- "9 16.354742 4.689843 1.760482\n",
- "6 NaN NaN NaN\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "Model: \"TKAN\"\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1mModel: \"TKAN\"\u001b[0m\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
- "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
- "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ tkan_30 (TKAN) │ (None, 45, 100) │ 41,750 │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ tkan_31 (TKAN) │ (None, 100) │ 67,670 │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_45 (Dense) │ (None, 6) │ 606 │\n",
- "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
- "
\n"
- ],
- "text/plain": [
- "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
- "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
- "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ tkan_30 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,750\u001b[0m │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ tkan_31 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_45 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m6\u001b[0m) │ \u001b[38;5;34m606\u001b[0m │\n",
- "└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- " Total params: 110,026 (429.79 KB)\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m110,026\u001b[0m (429.79 KB)\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- " Trainable params: 110,006 (429.71 KB)\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m110,006\u001b[0m (429.71 KB)\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- " Non-trainable params: 20 (80.00 B)\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m20\u001b[0m (80.00 B)\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "77.3258445262909 0.1280814725052942\n",
- "114.44928884506226 0.14017116370653385\n",
- "104.74958324432373 0.126116634620875\n",
- "101.73448276519775 0.13543281159237028\n",
- "88.49818539619446 0.13133054407465156\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "Model: \"GRU\"\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1mModel: \"GRU\"\u001b[0m\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
- "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
- "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ gru_30 (GRU) │ (None, 45, 100) │ 36,300 │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ gru_31 (GRU) │ (None, 100) │ 60,600 │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_50 (Dense) │ (None, 6) │ 606 │\n",
- "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
- "
\n"
- ],
- "text/plain": [
- "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
- "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
- "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ gru_30 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ gru_31 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n",
- "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_50 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m6\u001b[0m) │ \u001b[38;5;34m606\u001b[0m │\n",
- "└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- " Total params: 97,506 (380.88 KB)\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m97,506\u001b[0m (380.88 KB)\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- " Trainable params: 97,506 (380.88 KB)\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m97,506\u001b[0m (380.88 KB)\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- " Non-trainable params: 0 (0.00 B)\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "40.4565007686615 0.13588272560079817\n",
- "32.21065616607666 0.150907506214383\n",
- "41.915128231048584 0.09900284704797373\n",
- "40.109740018844604 0.06818254449404637\n",
- "40.55113649368286 0.14480364148103195\n"
+ "26.26027750968933 0.1602740337153421\n",
+ "30.87123155593872 0.12435047691678952\n",
+ "30.425692319869995 0.10110813722797496\n",
+ "29.073087215423584 0.13987663037711742\n",
+ "27.583446502685547 0.12691294524178082\n",
+ "30.0588481426239 0.09291668323995494\n",
+ "30.896180629730225 0.13374339087226397\n",
+ "26.069262266159058 0.13703060900398636\n",
+ "26.07837224006653 0.14521618469459105\n",
+ "28.18198537826538 0.13697203267595567\n"
]
},
{
@@ -3422,11 +1468,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ lstm_30 (LSTM) │ (None, 45, 100) │ 48,000 │\n",
+ "│ lstm_40 (LSTM) │ (None, 45, 100) │ 48,000 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ lstm_31 (LSTM) │ (None, 100) │ 80,400 │\n",
+ "│ lstm_41 (LSTM) │ (None, 100) │ 80,400 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_55 (Dense) │ (None, 6) │ 606 │\n",
+ "│ dense_80 (Dense) │ (None, 6) │ 606 │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"
\n"
],
@@ -3434,435 +1480,51 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ lstm_30 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n",
+ "│ lstm_40 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ lstm_31 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n",
+ "│ lstm_41 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_55 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m6\u001b[0m) │ \u001b[38;5;34m606\u001b[0m │\n",
+ "│ dense_80 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m6\u001b[0m) │ \u001b[38;5;34m606\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- " Total params: 129,006 (503.93 KB)\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m129,006\u001b[0m (503.93 KB)\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- " Trainable params: 129,006 (503.93 KB)\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m129,006\u001b[0m (503.93 KB)\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- " Non-trainable params: 0 (0.00 B)\n",
- "
\n"
- ],
- "text/plain": [
- "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "23.99752426147461 0.03173612065905046\n",
- "25.493926286697388 -0.05948194958746694\n",
- "33.9563672542572 0.12418423914604533\n",
- "25.076443910598755 -0.09484002694483258\n",
- "23.51260232925415 -0.021994532852705695\n",
- "R2 scores\n",
- "Means:\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n",
- " return _methods._mean(a, axis=axis, dtype=dtype,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 0.093291 | \n",
- " 0.078430 | \n",
- " -0.231738 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 0.096005 | \n",
- " 0.051663 | \n",
- " -0.362725 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 0.111538 | \n",
- " 0.048128 | \n",
- " -0.123209 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " 0.132227 | \n",
- " 0.119756 | \n",
- " -0.004079 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " TKAN GRU LSTM\n",
- "15 0.093291 0.078430 -0.231738\n",
- "12 0.096005 0.051663 -0.362725\n",
- "9 0.111538 0.048128 -0.123209\n",
- "6 0.132227 0.119756 -0.004079\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n",
- " return _methods._mean(a, axis=axis, dtype=dtype,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 0.067705 | \n",
- " 0.068194 | \n",
- " 0.078686 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 0.067321 | \n",
- " 0.068831 | \n",
- " 0.082159 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 0.066717 | \n",
- " 0.068882 | \n",
- " 0.074797 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " 0.066131 | \n",
- " 0.066470 | \n",
- " 0.070901 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " TKAN GRU LSTM\n",
- "15 0.067705 0.068194 0.078686\n",
- "12 0.067321 0.068831 0.082159\n",
- "9 0.066717 0.068882 0.074797\n",
- "6 0.066131 0.066470 0.070901\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Std:\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n",
- " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n",
- " arrmean = um.true_divide(arrmean, div, out=arrmean,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 0.005539 | \n",
- " 0.015925 | \n",
- " 0.103254 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 0.004060 | \n",
- " 0.056263 | \n",
- " 0.219555 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 0.003457 | \n",
- " 0.060783 | \n",
- " 0.097549 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " 0.005074 | \n",
- " 0.031459 | \n",
- " 0.076632 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " Total params: 129,006 (503.93 KB)\n",
+ "
\n"
],
"text/plain": [
- " TKAN GRU LSTM\n",
- "15 0.005539 0.015925 0.103254\n",
- "12 0.004060 0.056263 0.219555\n",
- "9 0.003457 0.060783 0.097549\n",
- "6 0.005074 0.031459 0.076632\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
+ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m129,006\u001b[0m (503.93 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n",
- " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n",
- " arrmean = um.true_divide(arrmean, div, out=arrmean,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
+ "data": {
+ "text/html": [
+ " Trainable params: 129,006 (503.93 KB)\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m129,006\u001b[0m (503.93 KB)\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
},
{
"data": {
"text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 0.000215 | \n",
- " 0.000578 | \n",
- " 0.003284 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 0.000148 | \n",
- " 0.001978 | \n",
- " 0.006539 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 0.000131 | \n",
- " 0.002142 | \n",
- " 0.003162 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " 0.000199 | \n",
- " 0.001151 | \n",
- " 0.002695 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
+ " Non-trainable params: 0 (0.00 B)\n",
+ "
\n"
],
"text/plain": [
- " TKAN GRU LSTM\n",
- "15 0.000215 0.000578 0.003284\n",
- "12 0.000148 0.001978 0.006539\n",
- "9 0.000131 0.002142 0.003162\n",
- "6 0.000199 0.001151 0.002695\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
+ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
@@ -3872,186 +1534,81 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Training Times\n"
+ "25.56489086151123 0.07762226752057655\n",
+ "17.17194938659668 -0.030710740566574007\n",
+ "22.08793616294861 0.12124885788661292\n",
+ "18.8874089717865 -0.19992596769356452\n",
+ "18.460254669189453 -0.03568727060662086\n",
+ "19.47718334197998 -0.1732379052017987\n",
+ "18.813477754592896 -0.0515297041407121\n",
+ "18.96723985671997 -0.00603457861723435\n",
+ "17.978565216064453 0.06545331792210407\n",
+ "18.76132583618164 -0.21833925712011495\n"
]
},
{
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n",
- " return _methods._mean(a, axis=axis, dtype=dtype,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
+ "data": {
+ "text/html": [
+ "Model: \"TKAN\"\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1mModel: \"TKAN\"\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
},
{
"data": {
"text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 148.763068 | \n",
- " 42.478296 | \n",
- " 33.661120 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 131.653231 | \n",
- " 41.333987 | \n",
- " 31.695352 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 112.193632 | \n",
- " 36.144233 | \n",
- " 24.899039 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " 97.351477 | \n",
- " 39.048632 | \n",
- " 26.407373 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
+ "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
+ "│ tkan_60 (TKAN) │ (None, 45, 100) │ 41,316 │\n",
+ "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+ "│ tkan_61 (TKAN) │ (None, 100) │ 67,670 │\n",
+ "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+ "│ dense_90 (Dense) │ (None, 9) │ 909 │\n",
+ "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
+ "
\n"
],
"text/plain": [
- " TKAN GRU LSTM\n",
- "15 148.763068 42.478296 33.661120\n",
- "12 131.653231 41.333987 31.695352\n",
- "9 112.193632 36.144233 24.899039\n",
- "6 97.351477 39.048632 26.407373\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
+ "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
+ "│ tkan_60 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,316\u001b[0m │\n",
+ "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+ "│ tkan_61 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n",
+ "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+ "│ dense_90 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m909\u001b[0m │\n",
+ "└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n",
- " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n",
- " arrmean = um.true_divide(arrmean, div, out=arrmean,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
+ "data": {
+ "text/html": [
+ " Total params: 109,895 (429.28 KB)\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m109,895\u001b[0m (429.28 KB)\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
},
{
"data": {
"text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 14.674705 | \n",
- " 2.705071 | \n",
- " 1.335022 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 6.613783 | \n",
- " 4.486661 | \n",
- " 3.014970 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 16.354742 | \n",
- " 4.689843 | \n",
- " 1.760482 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " 13.005265 | \n",
- " 3.473911 | \n",
- " 3.841358 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
+ " Trainable params: 109,875 (429.20 KB)\n",
+ "
\n"
],
"text/plain": [
- " TKAN GRU LSTM\n",
- "15 14.674705 2.705071 1.335022\n",
- "12 6.613783 4.486661 3.014970\n",
- "9 16.354742 4.689843 1.760482\n",
- "6 13.005265 3.473911 3.841358\n",
- "3 NaN NaN NaN\n",
- "1 NaN NaN NaN"
+ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m109,875\u001b[0m (429.20 KB)\n"
]
},
"metadata": {},
@@ -4060,11 +1617,40 @@
{
"data": {
"text/html": [
- "Model: \"TKAN\"\n",
+ " Non-trainable params: 20 (80.00 B)\n",
"
\n"
],
"text/plain": [
- "\u001b[1mModel: \"TKAN\"\u001b[0m\n"
+ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m20\u001b[0m (80.00 B)\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "79.56117129325867 0.10413605110577702\n",
+ "95.19019293785095 0.10785571960353318\n",
+ "62.058950901031494 0.1022740221679086\n",
+ "57.313427448272705 0.09659038594513981\n",
+ "95.39828181266785 0.11798299575914914\n",
+ "109.50233817100525 0.11471786200051105\n",
+ "91.66243028640747 0.10141240448038549\n",
+ "83.4484133720398 0.11271609311366479\n",
+ "80.45564818382263 0.10729036960326731\n",
+ "89.72633719444275 0.10601871726343631\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "Model: \"GRU\"\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1mModel: \"GRU\"\u001b[0m\n"
]
},
"metadata": {},
@@ -4076,11 +1662,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ tkan_40 (TKAN) │ (None, 45, 100) │ 41,750 │\n",
+ "│ gru_60 (GRU) │ (None, 45, 100) │ 36,300 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ tkan_41 (TKAN) │ (None, 100) │ 67,670 │\n",
+ "│ gru_61 (GRU) │ (None, 100) │ 60,600 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_60 (Dense) │ (None, 3) │ 303 │\n",
+ "│ dense_100 (Dense) │ (None, 9) │ 909 │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"
\n"
],
@@ -4088,11 +1674,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ tkan_40 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,750\u001b[0m │\n",
+ "│ gru_60 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ tkan_41 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n",
+ "│ gru_61 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_60 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m303\u001b[0m │\n",
+ "│ dense_100 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m909\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
@@ -4102,11 +1688,11 @@
{
"data": {
"text/html": [
- " Total params: 109,723 (428.61 KB)\n",
+ " Total params: 97,809 (382.07 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m109,723\u001b[0m (428.61 KB)\n"
+ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m97,809\u001b[0m (382.07 KB)\n"
]
},
"metadata": {},
@@ -4115,11 +1701,11 @@
{
"data": {
"text/html": [
- " Trainable params: 109,703 (428.53 KB)\n",
+ " Trainable params: 97,809 (382.07 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m109,703\u001b[0m (428.53 KB)\n"
+ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m97,809\u001b[0m (382.07 KB)\n"
]
},
"metadata": {},
@@ -4128,11 +1714,11 @@
{
"data": {
"text/html": [
- " Non-trainable params: 20 (80.00 B)\n",
+ " Non-trainable params: 0 (0.00 B)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m20\u001b[0m (80.00 B)\n"
+ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
@@ -4142,21 +1728,26 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "92.45071959495544 0.1916116569043532\n",
- "94.640451669693 0.19454637958748155\n",
- "73.71659445762634 0.19117099482549302\n",
- "95.40752124786377 0.19110640326041053\n",
- "85.31889510154724 0.19614372723219617\n"
+ "22.387739181518555 0.12872665220206386\n",
+ "30.9358811378479 0.07007529147404491\n",
+ "30.65935730934143 0.06094348103987125\n",
+ "22.681428909301758 0.11531462257235203\n",
+ "30.2576744556427 -0.009184511914249627\n",
+ "23.401977062225342 0.1145808233763385\n",
+ "29.867743015289307 -0.0020916084892752293\n",
+ "21.44379758834839 0.13675731715603687\n",
+ "32.31744432449341 0.08318358609980472\n",
+ "36.52803301811218 -0.02434265870166709\n"
]
},
{
"data": {
"text/html": [
- "Model: \"GRU\"\n",
+ "Model: \"LSTM\"\n",
"
\n"
],
"text/plain": [
- "\u001b[1mModel: \"GRU\"\u001b[0m\n"
+ "\u001b[1mModel: \"LSTM\"\u001b[0m\n"
]
},
"metadata": {},
@@ -4168,11 +1759,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ gru_40 (GRU) │ (None, 45, 100) │ 36,300 │\n",
+ "│ lstm_60 (LSTM) │ (None, 45, 100) │ 48,000 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ gru_41 (GRU) │ (None, 100) │ 60,600 │\n",
+ "│ lstm_61 (LSTM) │ (None, 100) │ 80,400 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_65 (Dense) │ (None, 3) │ 303 │\n",
+ "│ dense_110 (Dense) │ (None, 9) │ 909 │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"
\n"
],
@@ -4180,11 +1771,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ gru_40 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n",
+ "│ lstm_60 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ gru_41 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n",
+ "│ lstm_61 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_65 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m303\u001b[0m │\n",
+ "│ dense_110 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m9\u001b[0m) │ \u001b[38;5;34m909\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
@@ -4194,11 +1785,11 @@
{
"data": {
"text/html": [
- " Total params: 97,203 (379.70 KB)\n",
+ " Total params: 129,309 (505.11 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m97,203\u001b[0m (379.70 KB)\n"
+ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m129,309\u001b[0m (505.11 KB)\n"
]
},
"metadata": {},
@@ -4207,11 +1798,11 @@
{
"data": {
"text/html": [
- " Trainable params: 97,203 (379.70 KB)\n",
+ " Trainable params: 129,309 (505.11 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m97,203\u001b[0m (379.70 KB)\n"
+ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m129,309\u001b[0m (505.11 KB)\n"
]
},
"metadata": {},
@@ -4234,21 +1825,26 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "34.47401285171509 0.23517545876629728\n",
- "32.61530804634094 0.23820733972132835\n",
- "38.93267202377319 0.23710632242677557\n",
- "32.61954689025879 0.23525012980273963\n",
- "33.25591826438904 0.24304388963816356\n"
+ "20.542985439300537 -0.058979852588216475\n",
+ "19.61976170539856 -0.07973171479104875\n",
+ "19.607908487319946 -0.15736860736964103\n",
+ "17.53559136390686 -0.09773126198658748\n",
+ "19.434728860855103 -0.39964713298492527\n",
+ "20.465200901031494 -0.2428380152707189\n",
+ "18.985889673233032 -0.5012525957302113\n",
+ "20.644855260849 -0.2738693890416318\n",
+ "19.245380401611328 -0.05146921829097126\n",
+ "19.828939199447632 -0.0008553086817381464\n"
]
},
{
"data": {
"text/html": [
- "Model: \"LSTM\"\n",
+ "Model: \"TKAN\"\n",
"
\n"
],
"text/plain": [
- "\u001b[1mModel: \"LSTM\"\u001b[0m\n"
+ "\u001b[1mModel: \"TKAN\"\u001b[0m\n"
]
},
"metadata": {},
@@ -4260,11 +1856,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ lstm_40 (LSTM) │ (None, 45, 100) │ 48,000 │\n",
+ "│ tkan_80 (TKAN) │ (None, 60, 100) │ 41,316 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ lstm_41 (LSTM) │ (None, 100) │ 80,400 │\n",
+ "│ tkan_81 (TKAN) │ (None, 100) │ 67,670 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_70 (Dense) │ (None, 3) │ 303 │\n",
+ "│ dense_120 (Dense) │ (None, 12) │ 1,212 │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"
\n"
],
@@ -4272,11 +1868,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ lstm_40 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n",
+ "│ tkan_80 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m60\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,316\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ lstm_41 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n",
+ "│ tkan_81 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_70 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m303\u001b[0m │\n",
+ "│ dense_120 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m1,212\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
@@ -4286,11 +1882,11 @@
{
"data": {
"text/html": [
- " Total params: 128,703 (502.75 KB)\n",
+ " Total params: 110,198 (430.46 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m128,703\u001b[0m (502.75 KB)\n"
+ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m110,198\u001b[0m (430.46 KB)\n"
]
},
"metadata": {},
@@ -4299,11 +1895,11 @@
{
"data": {
"text/html": [
- " Trainable params: 128,703 (502.75 KB)\n",
+ " Trainable params: 110,178 (430.38 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m128,703\u001b[0m (502.75 KB)\n"
+ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m110,178\u001b[0m (430.38 KB)\n"
]
},
"metadata": {},
@@ -4312,11 +1908,11 @@
{
"data": {
"text/html": [
- " Non-trainable params: 0 (0.00 B)\n",
+ " Non-trainable params: 20 (80.00 B)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
+ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m20\u001b[0m (80.00 B)\n"
]
},
"metadata": {},
@@ -4326,381 +1922,94 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "27.456106424331665 0.05745848167891302\n",
- "25.54291844367981 0.04933120546474171\n",
- "27.04043936729431 0.14269677898499963\n",
- "24.304097175598145 0.16428378453298487\n",
- "26.050607442855835 0.13263322060964924\n",
- "R2 scores\n",
- "Means:\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n",
- " return _methods._mean(a, axis=axis, dtype=dtype,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
+ "102.03934359550476 0.0950019998069378\n",
+ "121.27882838249207 0.10264306639669311\n",
+ "108.56248140335083 0.0990784027938416\n",
+ "77.42332911491394 0.0915531316015012\n",
+ "108.50404167175293 0.09352464587821978\n",
+ "95.75039339065552 0.092634008105403\n",
+ "102.65606546401978 0.09820662504181471\n",
+ "130.8202440738678 0.09765221932195063\n",
+ "113.90040755271912 0.09989627016065132\n",
+ "118.27212023735046 0.1041245428645089\n"
]
},
{
"data": {
"text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 0.093291 | \n",
- " 0.078430 | \n",
- " -0.231738 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 0.096005 | \n",
- " 0.051663 | \n",
- " -0.362725 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 0.111538 | \n",
- " 0.048128 | \n",
- " -0.123209 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " 0.132227 | \n",
- " 0.119756 | \n",
- " -0.004079 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 0.192916 | \n",
- " 0.237757 | \n",
- " 0.109281 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
+ "Model: \"GRU\"\n",
+ "
\n"
],
"text/plain": [
- " TKAN GRU LSTM\n",
- "15 0.093291 0.078430 -0.231738\n",
- "12 0.096005 0.051663 -0.362725\n",
- "9 0.111538 0.048128 -0.123209\n",
- "6 0.132227 0.119756 -0.004079\n",
- "3 0.192916 0.237757 0.109281\n",
- "1 NaN NaN NaN"
+ "\u001b[1mModel: \"GRU\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n",
- " return _methods._mean(a, axis=axis, dtype=dtype,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
- },
{
"data": {
"text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 0.067705 | \n",
- " 0.068194 | \n",
- " 0.078686 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 0.067321 | \n",
- " 0.068831 | \n",
- " 0.082159 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 0.066717 | \n",
- " 0.068882 | \n",
- " 0.074797 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " 0.066131 | \n",
- " 0.066470 | \n",
- " 0.070901 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 0.063560 | \n",
- " 0.061688 | \n",
- " 0.066594 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
+ "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
+ "│ gru_80 (GRU) │ (None, 60, 100) │ 36,300 │\n",
+ "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+ "│ gru_81 (GRU) │ (None, 100) │ 60,600 │\n",
+ "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+ "│ dense_130 (Dense) │ (None, 12) │ 1,212 │\n",
+ "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
+ "
\n"
],
"text/plain": [
- " TKAN GRU LSTM\n",
- "15 0.067705 0.068194 0.078686\n",
- "12 0.067321 0.068831 0.082159\n",
- "9 0.066717 0.068882 0.074797\n",
- "6 0.066131 0.066470 0.070901\n",
- "3 0.063560 0.061688 0.066594\n",
- "1 NaN NaN NaN"
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
+ "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
+ "│ gru_80 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m60\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n",
+ "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+ "│ gru_81 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n",
+ "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+ "│ dense_130 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m1,212\u001b[0m │\n",
+ "└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Std:\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n",
- " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n",
- " arrmean = um.true_divide(arrmean, div, out=arrmean,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
- },
{
"data": {
"text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 0.005539 | \n",
- " 0.015925 | \n",
- " 0.103254 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 0.004060 | \n",
- " 0.056263 | \n",
- " 0.219555 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 0.003457 | \n",
- " 0.060783 | \n",
- " 0.097549 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " 0.005074 | \n",
- " 0.031459 | \n",
- " 0.076632 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 0.002054 | \n",
- " 0.002882 | \n",
- " 0.046833 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
+ " Total params: 98,112 (383.25 KB)\n",
+ "
\n"
],
"text/plain": [
- " TKAN GRU LSTM\n",
- "15 0.005539 0.015925 0.103254\n",
- "12 0.004060 0.056263 0.219555\n",
- "9 0.003457 0.060783 0.097549\n",
- "6 0.005074 0.031459 0.076632\n",
- "3 0.002054 0.002882 0.046833\n",
- "1 NaN NaN NaN"
+ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m98,112\u001b[0m (383.25 KB)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n",
- " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n",
- " arrmean = um.true_divide(arrmean, div, out=arrmean,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
- },
{
"data": {
"text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 0.000215 | \n",
- " 0.000578 | \n",
- " 0.003284 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 0.000148 | \n",
- " 0.001978 | \n",
- " 0.006539 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 0.000131 | \n",
- " 0.002142 | \n",
- " 0.003162 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " 0.000199 | \n",
- " 0.001151 | \n",
- " 0.002695 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 0.000094 | \n",
- " 0.000112 | \n",
- " 0.001722 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
+ " Trainable params: 98,112 (383.25 KB)\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m98,112\u001b[0m (383.25 KB)\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " Non-trainable params: 0 (0.00 B)\n",
+ "
\n"
],
"text/plain": [
- " TKAN GRU LSTM\n",
- "15 0.000215 0.000578 0.003284\n",
- "12 0.000148 0.001978 0.006539\n",
- "9 0.000131 0.002142 0.003162\n",
- "6 0.000199 0.001151 0.002695\n",
- "3 0.000094 0.000112 0.001722\n",
- "1 NaN NaN NaN"
+ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
@@ -4710,191 +2019,115 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Training Times\n"
+ "35.47221755981445 0.04088546661722716\n",
+ "36.82418465614319 0.03298667710271531\n",
+ "29.090290307998657 0.08667586160493074\n",
+ "27.488890886306763 0.09733002632430683\n",
+ "30.277069330215454 0.10445165105456156\n",
+ "26.764411687850952 0.11400616888042837\n",
+ "35.36870241165161 0.02046360232907878\n",
+ "26.94648003578186 0.11067681945648022\n",
+ "35.716086864471436 0.02628936001538192\n",
+ "32.919156312942505 0.05842798641178421\n"
]
},
{
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n",
- " return _methods._mean(a, axis=axis, dtype=dtype,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:129: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
+ "data": {
+ "text/html": [
+ "Model: \"LSTM\"\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1mModel: \"LSTM\"\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
},
{
"data": {
"text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 148.763068 | \n",
- " 42.478296 | \n",
- " 33.661120 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 131.653231 | \n",
- " 41.333987 | \n",
- " 31.695352 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 112.193632 | \n",
- " 36.144233 | \n",
- " 24.899039 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " 97.351477 | \n",
- " 39.048632 | \n",
- " 26.407373 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 88.306836 | \n",
- " 34.379492 | \n",
- " 26.078834 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
+ "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
+ "│ lstm_80 (LSTM) │ (None, 60, 100) │ 48,000 │\n",
+ "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+ "│ lstm_81 (LSTM) │ (None, 100) │ 80,400 │\n",
+ "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+ "│ dense_140 (Dense) │ (None, 12) │ 1,212 │\n",
+ "└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
+ "
\n"
],
"text/plain": [
- " TKAN GRU LSTM\n",
- "15 148.763068 42.478296 33.661120\n",
- "12 131.653231 41.333987 31.695352\n",
- "9 112.193632 36.144233 24.899039\n",
- "6 97.351477 39.048632 26.407373\n",
- "3 88.306836 34.379492 26.078834\n",
- "1 NaN NaN NaN"
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
+ "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
+ "│ lstm_80 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m60\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n",
+ "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+ "│ lstm_81 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n",
+ "├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
+ "│ dense_140 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m12\u001b[0m) │ \u001b[38;5;34m1,212\u001b[0m │\n",
+ "└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:206: RuntimeWarning: Degrees of freedom <= 0 for slice\n",
- " ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:163: RuntimeWarning: invalid value encountered in divide\n",
- " arrmean = um.true_divide(arrmean, div, out=arrmean,\n",
- "/usr/local/lib/python3.11/dist-packages/numpy/core/_methods.py:198: RuntimeWarning: invalid value encountered in scalar divide\n",
- " ret = ret.dtype.type(ret / rcount)\n"
- ]
+ "data": {
+ "text/html": [
+ " Total params: 129,612 (506.30 KB)\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m129,612\u001b[0m (506.30 KB)\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
},
{
"data": {
"text/html": [
- "\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " TKAN | \n",
- " GRU | \n",
- " LSTM | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 15 | \n",
- " 14.674705 | \n",
- " 2.705071 | \n",
- " 1.335022 | \n",
- "
\n",
- " \n",
- " 12 | \n",
- " 6.613783 | \n",
- " 4.486661 | \n",
- " 3.014970 | \n",
- "
\n",
- " \n",
- " 9 | \n",
- " 16.354742 | \n",
- " 4.689843 | \n",
- " 1.760482 | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " 13.005265 | \n",
- " 3.473911 | \n",
- " 3.841358 | \n",
- "
\n",
- " \n",
- " 3 | \n",
- " 8.117994 | \n",
- " 2.375397 | \n",
- " 1.118862 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " NaN | \n",
- " NaN | \n",
- " NaN | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
+ " Trainable params: 129,612 (506.30 KB)\n",
+ "
\n"
],
"text/plain": [
- " TKAN GRU LSTM\n",
- "15 14.674705 2.705071 1.335022\n",
- "12 6.613783 4.486661 3.014970\n",
- "9 16.354742 4.689843 1.760482\n",
- "6 13.005265 3.473911 3.841358\n",
- "3 8.117994 2.375397 1.118862\n",
- "1 NaN NaN NaN"
+ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m129,612\u001b[0m (506.30 KB)\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ " Non-trainable params: 0 (0.00 B)\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "26.069539546966553 -0.36695473176783655\n",
+ "21.807262659072876 -0.25014631415290206\n",
+ "20.902458667755127 -0.11645941888847307\n",
+ "22.4891197681427 -0.052407595508665694\n",
+ "22.018208742141724 -0.1754222997379895\n",
+ "23.925899028778076 -0.444045746274295\n",
+ "23.6626615524292 -0.2362643247514915\n",
+ "20.98034143447876 -0.16545529402468354\n",
+ "22.008980751037598 -0.23959377820038377\n",
+ "21.912462949752808 -0.5002557724294028\n"
+ ]
+ },
{
"data": {
"text/html": [
@@ -4914,11 +2147,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ tkan_50 (TKAN) │ (None, 45, 100) │ 41,750 │\n",
+ "│ tkan_100 (TKAN) │ (None, 75, 100) │ 41,316 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ tkan_51 (TKAN) │ (None, 100) │ 67,670 │\n",
+ "│ tkan_101 (TKAN) │ (None, 100) │ 67,670 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_75 (Dense) │ (None, 1) │ 101 │\n",
+ "│ dense_150 (Dense) │ (None, 15) │ 1,515 │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"
\n"
],
@@ -4926,11 +2159,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ tkan_50 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,750\u001b[0m │\n",
+ "│ tkan_100 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m75\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m41,316\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ tkan_51 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n",
+ "│ tkan_101 (\u001b[38;5;33mTKAN\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m67,670\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_75 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m101\u001b[0m │\n",
+ "│ dense_150 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m15\u001b[0m) │ \u001b[38;5;34m1,515\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
@@ -4940,11 +2173,11 @@
{
"data": {
"text/html": [
- " Total params: 109,521 (427.82 KB)\n",
+ " Total params: 110,501 (431.64 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m109,521\u001b[0m (427.82 KB)\n"
+ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m110,501\u001b[0m (431.64 KB)\n"
]
},
"metadata": {},
@@ -4953,11 +2186,11 @@
{
"data": {
"text/html": [
- " Trainable params: 109,501 (427.74 KB)\n",
+ " Trainable params: 110,481 (431.57 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m109,501\u001b[0m (427.74 KB)\n"
+ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m110,481\u001b[0m (431.57 KB)\n"
]
},
"metadata": {},
@@ -4980,11 +2213,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "81.33754324913025 0.3022592965618307\n",
- "104.95103526115417 0.302846112010449\n",
- "111.78685140609741 0.33146529199419217\n",
- "64.58520579338074 0.27988703845286933\n",
- "115.72550010681152 0.32013992233948596\n"
+ "108.7913281917572 0.09994023646490877\n",
+ "156.19587421417236 0.08586536769033742\n",
+ "151.76872277259827 0.0953234353911566\n",
+ "109.25939583778381 0.08350316914250061\n",
+ "120.67966675758362 0.09383066359946372\n",
+ "118.93469595909119 0.09027677327433754\n",
+ "117.88046956062317 0.09486514143615367\n",
+ "106.22660422325134 0.09703190398065814\n",
+ "154.6233696937561 0.09741707472392\n",
+ "124.99523186683655 0.09227563879811171\n"
]
},
{
@@ -5006,11 +2244,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ gru_50 (GRU) │ (None, 45, 100) │ 36,300 │\n",
+ "│ gru_100 (GRU) │ (None, 75, 100) │ 36,300 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ gru_51 (GRU) │ (None, 100) │ 60,600 │\n",
+ "│ gru_101 (GRU) │ (None, 100) │ 60,600 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_80 (Dense) │ (None, 1) │ 101 │\n",
+ "│ dense_160 (Dense) │ (None, 15) │ 1,515 │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"
\n"
],
@@ -5018,11 +2256,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ gru_50 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n",
+ "│ gru_100 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m75\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m36,300\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ gru_51 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n",
+ "│ gru_101 (\u001b[38;5;33mGRU\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m60,600\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_80 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m101\u001b[0m │\n",
+ "│ dense_160 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m15\u001b[0m) │ \u001b[38;5;34m1,515\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
@@ -5032,11 +2270,11 @@
{
"data": {
"text/html": [
- " Total params: 97,001 (378.91 KB)\n",
+ " Total params: 98,415 (384.43 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m97,001\u001b[0m (378.91 KB)\n"
+ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m98,415\u001b[0m (384.43 KB)\n"
]
},
"metadata": {},
@@ -5045,11 +2283,11 @@
{
"data": {
"text/html": [
- " Trainable params: 97,001 (378.91 KB)\n",
+ " Trainable params: 98,415 (384.43 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m97,001\u001b[0m (378.91 KB)\n"
+ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m98,415\u001b[0m (384.43 KB)\n"
]
},
"metadata": {},
@@ -5072,11 +2310,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "42.188021659851074 0.39814672018793895\n",
- "32.78909373283386 0.3956099221775101\n",
- "35.24668765068054 0.3982762017641819\n",
- "42.99339699745178 0.3956046326974988\n",
- "28.0660617351532 0.3989501248082734\n"
+ "41.711098432540894 0.04633579312704695\n",
+ "37.61102104187012 0.041959092296909146\n",
+ "33.6794171333313 0.08124711803653464\n",
+ "41.39418888092041 0.062273110060791455\n",
+ "33.090798139572144 0.09644091868374302\n",
+ "34.905259132385254 0.0688485348099919\n",
+ "31.54578924179077 0.09070480771844083\n",
+ "31.649243354797363 0.08222696375881307\n",
+ "34.72465991973877 0.07408738048462013\n",
+ "34.364349365234375 0.06764429477726015\n"
]
},
{
@@ -5098,11 +2341,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ lstm_50 (LSTM) │ (None, 45, 100) │ 48,000 │\n",
+ "│ lstm_100 (LSTM) │ (None, 75, 100) │ 48,000 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ lstm_51 (LSTM) │ (None, 100) │ 80,400 │\n",
+ "│ lstm_101 (LSTM) │ (None, 100) │ 80,400 │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_85 (Dense) │ (None, 1) │ 101 │\n",
+ "│ dense_170 (Dense) │ (None, 15) │ 1,515 │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n",
"
\n"
],
@@ -5110,11 +2353,11 @@
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
- "│ lstm_50 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m45\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n",
+ "│ lstm_100 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m75\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m48,000\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ lstm_51 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n",
+ "│ lstm_101 (\u001b[38;5;33mLSTM\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m100\u001b[0m) │ \u001b[38;5;34m80,400\u001b[0m │\n",
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n",
- "│ dense_85 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m101\u001b[0m │\n",
+ "│ dense_170 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m15\u001b[0m) │ \u001b[38;5;34m1,515\u001b[0m │\n",
"└─────────────────────────────────┴────────────────────────┴───────────────┘\n"
]
},
@@ -5124,11 +2367,11 @@
{
"data": {
"text/html": [
- " Total params: 128,501 (501.96 KB)\n",
+ " Total params: 129,915 (507.48 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m128,501\u001b[0m (501.96 KB)\n"
+ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m129,915\u001b[0m (507.48 KB)\n"
]
},
"metadata": {},
@@ -5137,11 +2380,11 @@
{
"data": {
"text/html": [
- " Trainable params: 128,501 (501.96 KB)\n",
+ " Trainable params: 129,915 (507.48 KB)\n",
"
\n"
],
"text/plain": [
- "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m128,501\u001b[0m (501.96 KB)\n"
+ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m129,915\u001b[0m (507.48 KB)\n"
]
},
"metadata": {},
@@ -5164,11 +2407,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "25.248199701309204 0.3954618995452197\n",
- "25.895612478256226 0.36461360004863774\n",
- "22.943801403045654 0.39227780049103544\n",
- "25.566108465194702 0.3725527679003343\n",
- "24.88447666168213 0.3704866575757282\n",
+ "27.019399642944336 -0.02014558059192651\n",
+ "24.994324684143066 -0.034652205897714984\n",
+ "29.875847578048706 -0.2883331388024311\n",
+ "27.68077278137207 -0.19307423134624493\n",
+ "32.10626721382141 -0.1967002645401843\n",
+ "26.022549629211426 -0.10374175022709843\n",
+ "28.117710828781128 -0.0705739546737177\n",
+ "30.498241662979126 -0.25929543857022325\n",
+ "24.907368898391724 -0.09258147039272031\n",
+ "26.723052978515625 -0.15915120445580425\n",
"R2 scores\n",
"Means:\n"
]
@@ -5201,40 +2449,40 @@
" \n",
" \n",
" \n",
- " 15 | \n",
- " 0.093291 | \n",
- " 0.078430 | \n",
- " -0.231738 | \n",
+ " 1 | \n",
+ " 0.310078 | \n",
+ " 0.394863 | \n",
+ " 0.381780 | \n",
"
\n",
" \n",
- " 12 | \n",
- " 0.096005 | \n",
- " 0.051663 | \n",
- " -0.362725 | \n",
+ " 3 | \n",
+ " 0.190549 | \n",
+ " 0.236880 | \n",
+ " 0.113571 | \n",
"
\n",
" \n",
- " 9 | \n",
- " 0.111538 | \n",
- " 0.048128 | \n",
- " -0.123209 | \n",
+ " 6 | \n",
+ " 0.128743 | \n",
+ " 0.129840 | \n",
+ " -0.045114 | \n",
"
\n",
" \n",
- " 6 | \n",
- " 0.132227 | \n",
- " 0.119756 | \n",
- " -0.004079 | \n",
+ " 9 | \n",
+ " 0.107099 | \n",
+ " 0.067396 | \n",
+ " -0.186374 | \n",
"
\n",
" \n",
- " 3 | \n",
- " 0.192916 | \n",
- " 0.237757 | \n",
- " 0.109281 | \n",
+ " 12 | \n",
+ " 0.097431 | \n",
+ " 0.069219 | \n",
+ " -0.254701 | \n",
"
\n",
" \n",
- " 1 | \n",
- " 0.307320 | \n",
- " 0.397318 | \n",
- " 0.379079 | \n",
+ " 15 | \n",
+ " 0.093033 | \n",
+ " 0.071177 | \n",
+ " -0.141825 | \n",
"
\n",
" \n",
"\n",
@@ -5242,12 +2490,12 @@
],
"text/plain": [
" TKAN GRU LSTM\n",
- "15 0.093291 0.078430 -0.231738\n",
- "12 0.096005 0.051663 -0.362725\n",
- "9 0.111538 0.048128 -0.123209\n",
- "6 0.132227 0.119756 -0.004079\n",
- "3 0.192916 0.237757 0.109281\n",
- "1 0.307320 0.397318 0.379079"
+ "1 0.310078 0.394863 0.381780\n",
+ "3 0.190549 0.236880 0.113571\n",
+ "6 0.128743 0.129840 -0.045114\n",
+ "9 0.107099 0.067396 -0.186374\n",
+ "12 0.097431 0.069219 -0.254701\n",
+ "15 0.093033 0.071177 -0.141825"
]
},
"metadata": {},
@@ -5281,40 +2529,40 @@
" \n",
" \n",
" \n",
- " 15 | \n",
- " 0.067705 | \n",
- " 0.068194 | \n",
- " 0.078686 | \n",
+ " 1 | \n",
+ " 0.058833 | \n",
+ " 0.055101 | \n",
+ " 0.055693 | \n",
"
\n",
" \n",
- " 12 | \n",
- " 0.067321 | \n",
- " 0.068831 | \n",
- " 0.082159 | \n",
+ " 3 | \n",
+ " 0.063659 | \n",
+ " 0.061720 | \n",
+ " 0.066402 | \n",
"
\n",
" \n",
- " 9 | \n",
- " 0.066717 | \n",
- " 0.068882 | \n",
- " 0.074797 | \n",
+ " 6 | \n",
+ " 0.066264 | \n",
+ " 0.066109 | \n",
+ " 0.072302 | \n",
"
\n",
" \n",
- " 6 | \n",
- " 0.066131 | \n",
- " 0.066470 | \n",
- " 0.070901 | \n",
+ " 9 | \n",
+ " 0.066893 | \n",
+ " 0.068206 | \n",
+ " 0.076763 | \n",
"
\n",
" \n",
- " 3 | \n",
- " 0.063560 | \n",
- " 0.061688 | \n",
- " 0.066594 | \n",
+ " 12 | \n",
+ " 0.067267 | \n",
+ " 0.068214 | \n",
+ " 0.079032 | \n",
"
\n",
" \n",
- " 1 | \n",
- " 0.058948 | \n",
- " 0.054990 | \n",
- " 0.055813 | \n",
+ " 15 | \n",
+ " 0.067711 | \n",
+ " 0.068456 | \n",
+ " 0.075833 | \n",
"
\n",
" \n",
"\n",
@@ -5322,12 +2570,12 @@
],
"text/plain": [
" TKAN GRU LSTM\n",
- "15 0.067705 0.068194 0.078686\n",
- "12 0.067321 0.068831 0.082159\n",
- "9 0.066717 0.068882 0.074797\n",
- "6 0.066131 0.066470 0.070901\n",
- "3 0.063560 0.061688 0.066594\n",
- "1 0.058948 0.054990 0.055813"
+ "1 0.058833 0.055101 0.055693\n",
+ "3 0.063659 0.061720 0.066402\n",
+ "6 0.066264 0.066109 0.072302\n",
+ "9 0.066893 0.068206 0.076763\n",
+ "12 0.067267 0.068214 0.079032\n",
+ "15 0.067711 0.068456 0.075833"
]
},
"metadata": {},
@@ -5368,40 +2616,40 @@
" \n",
" \n",
" \n",
- " 15 | \n",
- " 0.005539 | \n",
- " 0.015925 | \n",
- " 0.103254 | \n",
+ " 1 | \n",
+ " 0.013617 | \n",
+ " 0.007498 | \n",
+ " 0.009371 | \n",
"
\n",
" \n",
- " 12 | \n",
- " 0.004060 | \n",
- " 0.056263 | \n",
- " 0.219555 | \n",
+ " 3 | \n",
+ " 0.007820 | \n",
+ " 0.006761 | \n",
+ " 0.063202 | \n",
"
\n",
" \n",
- " 9 | \n",
- " 0.003457 | \n",
- " 0.060783 | \n",
- " 0.097549 | \n",
+ " 6 | \n",
+ " 0.008831 | \n",
+ " 0.018997 | \n",
+ " 0.112792 | \n",
"
\n",
" \n",
- " 6 | \n",
- " 0.005074 | \n",
- " 0.031459 | \n",
- " 0.076632 | \n",
+ " 9 | \n",
+ " 0.006202 | \n",
+ " 0.057011 | \n",
+ " 0.156331 | \n",
"
\n",
" \n",
- " 3 | \n",
- " 0.002054 | \n",
- " 0.002882 | \n",
- " 0.046833 | \n",
+ " 12 | \n",
+ " 0.004002 | \n",
+ " 0.035374 | \n",
+ " 0.135507 | \n",
"
\n",
" \n",
- " 1 | \n",
- " 0.017581 | \n",
- " 0.001423 | \n",
- " 0.012396 | \n",
+ " 15 | \n",
+ " 0.004925 | \n",
+ " 0.016791 | \n",
+ " 0.087433 | \n",
"
\n",
" \n",
"\n",
@@ -5409,12 +2657,12 @@
],
"text/plain": [
" TKAN GRU LSTM\n",
- "15 0.005539 0.015925 0.103254\n",
- "12 0.004060 0.056263 0.219555\n",
- "9 0.003457 0.060783 0.097549\n",
- "6 0.005074 0.031459 0.076632\n",
- "3 0.002054 0.002882 0.046833\n",
- "1 0.017581 0.001423 0.012396"
+ "1 0.013617 0.007498 0.009371\n",
+ "3 0.007820 0.006761 0.063202\n",
+ "6 0.008831 0.018997 0.112792\n",
+ "9 0.006202 0.057011 0.156331\n",
+ "12 0.004002 0.035374 0.135507\n",
+ "15 0.004925 0.016791 0.087433"
]
},
"metadata": {},
@@ -5448,40 +2696,40 @@
" \n",
" \n",
" \n",
- " 15 | \n",
- " 0.000215 | \n",
- " 0.000578 | \n",
- " 0.003284 | \n",
+ " 1 | \n",
+ " 0.000581 | \n",
+ " 0.000340 | \n",
+ " 0.000421 | \n",
"
\n",
" \n",
- " 12 | \n",
- " 0.000148 | \n",
- " 0.001978 | \n",
- " 0.006539 | \n",
+ " 3 | \n",
+ " 0.000316 | \n",
+ " 0.000273 | \n",
+ " 0.002282 | \n",
"
\n",
" \n",
- " 9 | \n",
- " 0.000131 | \n",
- " 0.002142 | \n",
- " 0.003162 | \n",
+ " 6 | \n",
+ " 0.000347 | \n",
+ " 0.000707 | \n",
+ " 0.003831 | \n",
"
\n",
" \n",
- " 6 | \n",
- " 0.000199 | \n",
- " 0.001151 | \n",
- " 0.002695 | \n",
+ " 9 | \n",
+ " 0.000237 | \n",
+ " 0.002033 | \n",
+ " 0.004908 | \n",
"
\n",
" \n",
- " 3 | \n",
- " 0.000094 | \n",
- " 0.000112 | \n",
- " 0.001722 | \n",
+ " 12 | \n",
+ " 0.000157 | \n",
+ " 0.001278 | \n",
+ " 0.004174 | \n",
"
\n",
" \n",
- " 1 | \n",
- " 0.000747 | \n",
- " 0.000065 | \n",
- " 0.000558 | \n",
+ " 15 | \n",
+ " 0.000187 | \n",
+ " 0.000609 | \n",
+ " 0.002844 | \n",
"
\n",
" \n",
"\n",
@@ -5489,12 +2737,12 @@
],
"text/plain": [
" TKAN GRU LSTM\n",
- "15 0.000215 0.000578 0.003284\n",
- "12 0.000148 0.001978 0.006539\n",
- "9 0.000131 0.002142 0.003162\n",
- "6 0.000199 0.001151 0.002695\n",
- "3 0.000094 0.000112 0.001722\n",
- "1 0.000747 0.000065 0.000558"
+ "1 0.000581 0.000340 0.000421\n",
+ "3 0.000316 0.000273 0.002282\n",
+ "6 0.000347 0.000707 0.003831\n",
+ "9 0.000237 0.002033 0.004908\n",
+ "12 0.000157 0.001278 0.004174\n",
+ "15 0.000187 0.000609 0.002844"
]
},
"metadata": {},
@@ -5535,40 +2783,40 @@
" \n",
" \n",
" \n",
- " 15 | \n",
- " 148.763068 | \n",
- " 42.478296 | \n",
- " 33.661120 | \n",
+ " 1 | \n",
+ " 76.861277 | \n",
+ " 26.988356 | \n",
+ " 18.895929 | \n",
"
\n",
" \n",
- " 12 | \n",
- " 131.653231 | \n",
- " 41.333987 | \n",
- " 31.695352 | \n",
+ " 3 | \n",
+ " 78.182759 | \n",
+ " 28.601693 | \n",
+ " 20.284139 | \n",
"
\n",
" \n",
- " 9 | \n",
- " 112.193632 | \n",
- " 36.144233 | \n",
- " 24.899039 | \n",
+ " 6 | \n",
+ " 77.580002 | \n",
+ " 28.549838 | \n",
+ " 19.617023 | \n",
"
\n",
" \n",
- " 6 | \n",
- " 97.351477 | \n",
- " 39.048632 | \n",
- " 26.407373 | \n",
+ " 9 | \n",
+ " 84.431719 | \n",
+ " 28.048108 | \n",
+ " 19.591124 | \n",
"
\n",
" \n",
- " 3 | \n",
- " 88.306836 | \n",
- " 34.379492 | \n",
- " 26.078834 | \n",
+ " 12 | \n",
+ " 107.920725 | \n",
+ " 31.686749 | \n",
+ " 22.577694 | \n",
"
\n",
" \n",
- " 1 | \n",
- " 95.677227 | \n",
- " 36.256652 | \n",
- " 24.907640 | \n",
+ " 15 | \n",
+ " 126.935536 | \n",
+ " 35.467582 | \n",
+ " 27.794554 | \n",
"
\n",
" \n",
"\n",
@@ -5576,12 +2824,12 @@
],
"text/plain": [
" TKAN GRU LSTM\n",
- "15 148.763068 42.478296 33.661120\n",
- "12 131.653231 41.333987 31.695352\n",
- "9 112.193632 36.144233 24.899039\n",
- "6 97.351477 39.048632 26.407373\n",
- "3 88.306836 34.379492 26.078834\n",
- "1 95.677227 36.256652 24.907640"
+ "1 76.861277 26.988356 18.895929\n",
+ "3 78.182759 28.601693 20.284139\n",
+ "6 77.580002 28.549838 19.617023\n",
+ "9 84.431719 28.048108 19.591124\n",
+ "12 107.920725 31.686749 22.577694\n",
+ "15 126.935536 35.467582 27.794554"
]
},
"metadata": {},
@@ -5615,40 +2863,40 @@
" \n",
" \n",
" \n",
- " 15 | \n",
- " 14.674705 | \n",
- " 2.705071 | \n",
- " 1.335022 | \n",
+ " 1 | \n",
+ " 11.082262 | \n",
+ " 3.945290 | \n",
+ " 1.723379 | \n",
"
\n",
" \n",
- " 12 | \n",
- " 6.613783 | \n",
- " 4.486661 | \n",
- " 3.014970 | \n",
+ " 3 | \n",
+ " 10.159973 | \n",
+ " 3.052013 | \n",
+ " 1.950043 | \n",
"
\n",
" \n",
- " 9 | \n",
- " 16.354742 | \n",
- " 4.689843 | \n",
- " 1.760482 | \n",
+ " 6 | \n",
+ " 8.965218 | \n",
+ " 1.885116 | \n",
+ " 2.320250 | \n",
"
\n",
" \n",
- " 6 | \n",
- " 13.005265 | \n",
- " 3.473911 | \n",
- " 3.841358 | \n",
+ " 9 | \n",
+ " 14.904037 | \n",
+ " 4.894200 | \n",
+ " 0.870263 | \n",
"
\n",
" \n",
- " 3 | \n",
- " 8.117994 | \n",
- " 2.375397 | \n",
- " 1.118862 | \n",
+ " 12 | \n",
+ " 14.096107 | \n",
+ " 3.808997 | \n",
+ " 1.490844 | \n",
"
\n",
" \n",
- " 1 | \n",
- " 19.594882 | \n",
- " 5.669121 | \n",
- " 1.037579 | \n",
+ " 15 | \n",
+ " 18.705339 | \n",
+ " 3.458216 | \n",
+ " 2.267511 | \n",
"
\n",
" \n",
"\n",
@@ -5656,12 +2904,12 @@
],
"text/plain": [
" TKAN GRU LSTM\n",
- "15 14.674705 2.705071 1.335022\n",
- "12 6.613783 4.486661 3.014970\n",
- "9 16.354742 4.689843 1.760482\n",
- "6 13.005265 3.473911 3.841358\n",
- "3 8.117994 2.375397 1.118862\n",
- "1 19.594882 5.669121 1.037579"
+ "1 11.082262 3.945290 1.723379\n",
+ "3 10.159973 3.052013 1.950043\n",
+ "6 8.965218 1.885116 2.320250\n",
+ "9 14.904037 4.894200 0.870263\n",
+ "12 14.096107 3.808997 1.490844\n",
+ "15 18.705339 3.458216 2.267511"
]
},
"metadata": {},
@@ -5669,7 +2917,7 @@
}
],
"source": [
- "n_aheads = [1, 3, 6, 9, 12, 15][::-1]\n",
+ "n_aheads = [1, 3, 6, 9, 12, 15]\n",
"models = [\n",
" \"TKAN\",\n",
" \"GRU\",\n",
@@ -5685,12 +2933,12 @@
" \n",
" for model_id in models:\n",
" \n",
- " for run in range(5):\n",
+ " for run in range(10):\n",
"\n",
" if model_id == 'TKAN':\n",
" model = Sequential([\n",
" Input(shape=X_train.shape[1:]),\n",
- " TKAN(100, sub_kan_output_dim = 20, sub_kan_input_dim = 20, return_sequences=True),\n",
+ " TKAN(100, return_sequences=True),\n",
" TKAN(100, sub_kan_output_dim = 20, sub_kan_input_dim = 20, return_sequences=False),\n",
" Dense(units=n_ahead, activation='linear')\n",
" ], name = model_id)\n",
@@ -5712,7 +2960,7 @@
" raise ValueError\n",
" \n",
" optimizer = keras.optimizers.Adam(0.001)\n",
- " model.compile(optimizer=optimizer, loss='mean_squared_error')\n",
+ " model.compile(optimizer=optimizer, loss='mean_squared_error', jit_compile=True)\n",
" if run==0:\n",
" model.summary()\n",
" \n",
@@ -5732,26 +2980,18 @@
" del model\n",
" del optimizer\n",
" \n",
- " \n",
- " print('R2 scores')\n",
- " print('Means:')\n",
- " display(pd.DataFrame({model_id: {n_ahead: np.mean(results[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in results.keys()}))\n",
- " display(pd.DataFrame({model_id: {n_ahead: np.mean(results_rmse[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in results_rmse.keys()}))\n",
- " print('Std:')\n",
- " display(pd.DataFrame({model_id: {n_ahead: np.std(results[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in results.keys()}))\n",
- " display(pd.DataFrame({model_id: {n_ahead: np.std(results_rmse[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in results_rmse.keys()}))\n",
- " print('Training Times')\n",
- " display(pd.DataFrame({model_id: {n_ahead: np.mean(time_results[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in time_results.keys()}))\n",
- " display(pd.DataFrame({model_id: {n_ahead: np.std(time_results[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in time_results.keys()}))"
+ "\n",
+ "print('R2 scores')\n",
+ "print('Means:')\n",
+ "display(pd.DataFrame({model_id: {n_ahead: np.mean(results[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in results.keys()}))\n",
+ "display(pd.DataFrame({model_id: {n_ahead: np.mean(results_rmse[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in results_rmse.keys()}))\n",
+ "print('Std:')\n",
+ "display(pd.DataFrame({model_id: {n_ahead: np.std(results[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in results.keys()}))\n",
+ "display(pd.DataFrame({model_id: {n_ahead: np.std(results_rmse[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in results_rmse.keys()}))\n",
+ "print('Training Times')\n",
+ "display(pd.DataFrame({model_id: {n_ahead: np.mean(time_results[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in time_results.keys()}))\n",
+ "display(pd.DataFrame({model_id: {n_ahead: np.std(time_results[model_id][n_ahead]) for n_ahead in n_aheads} for model_id in time_results.keys()}))"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "8ffccbbf-8ca4-4166-97bc-24c0f5ca02d0",
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
@@ -5770,7 +3010,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.2"
+ "version": "3.11.9"
}
},
"nbformat": 4,