diff --git a/notebooks/Kalman_Filter_Gradient.ipynb b/notebooks/Kalman_Filter_Gradient.ipynb
new file mode 100644
index 00000000..bec60971
--- /dev/null
+++ b/notebooks/Kalman_Filter_Gradient.ipynb
@@ -0,0 +1,1341 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "69ae14a1",
+   "metadata": {},
+   "source": [
+    "### Imports"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "90979a41",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "WARNING (pytensor.configdefaults): g++ not available, if using conda: `conda install gxx`\n",
+      "WARNING (pytensor.configdefaults): g++ not detected! PyTensor will be unable to compile C-implementations and will default to Python. Performance may be severely degraded. To remove this warning, set PyTensor flags cxx to an empty string.\n"
+     ]
+    }
+   ],
+   "source": [
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import pytensor\n",
+    "import pytensor.tensor as pt\n",
+    "import matplotlib.pyplot as plt\n",
+    "from pytensor.compile.builders import OpFromGraph\n",
+    "from pytensor.graph.basic import explicit_graph_inputs\n",
+    "from time import perf_counter\n",
+    "from collections import defaultdict\n",
+    "import pymc_extras as pmx\n",
+    "from pymc_extras.statespace import structural as sts"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a0d008fc",
+   "metadata": {},
+   "source": [
+    "### Generate a random dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "fdb156d6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mod = (\n",
+    "    sts.LevelTrendComponent(order=2, innovations_order=[0, 1], name='level') +\n",
+    "    sts.AutoregressiveComponent(order=1, name='ar') +\n",
+    "    sts.MeasurementError(name='obs_error')\n",
+    ").build(verbose=False)\n",
+    "\n",
+    "param_values = {\n",
+    "    'initial_level': np.array([10, 0.1]),\n",
+    "    'sigma_level': np.array([1e-2]),\n",
+    "    'params_ar': np.array([0.95]),\n",
+    "    'sigma_ar': np.array(1e-2),\n",
+    "    'sigma_obs_error': np.array(1e-2),\n",
+    "}\n",
+    "\n",
+    "data_fn = pmx.statespace.compile_statespace(mod, steps=100)\n",
+    "hidden_state_data, obs_data = data_fn(**param_values)\n",
+    "\n",
+    "matrices = mod._unpack_statespace_with_placeholders()\n",
+    "\n",
+    "matrix_fn = pytensor.function(list(explicit_graph_inputs(matrices)),\n",
+    "                              matrices)\n",
+    "a0, P0, c, d, T, Z, R, H, Q = matrix_fn(**param_values, initial_state_cov=np.eye(mod.k_states))"
+   ]
+  },
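+  {
+   "cell_type": "markdown",
+   "id": "0f1e2d3c",
+   "metadata": {},
+   "source": [
+    "As a quick sanity check (a small addition, not part of the original workflow), we can look at the shapes of the system matrices before declaring their symbolic counterparts:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0f1e2d3d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Inspect the statespace matrix shapes returned by matrix_fn.\n",
+    "for name, arr in zip(['a0', 'P0', 'c', 'd', 'T', 'Z', 'R', 'H', 'Q'],\n",
+    "                     [a0, P0, c, d, T, Z, R, H, Q]):\n",
+    "    print(f\"{name}: {np.asarray(arr).shape}\")"
+   ]
+  },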
+  {
+   "cell_type": "markdown",
+   "id": "51b7e885",
+   "metadata": {},
+   "source": [
+    "### Symbolic variables"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "3661408d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Symbolic parameters\n",
+    "A_sym = pt.matrix(\"A\")  # (n, n)\n",
+    "H_sym = pt.matrix(\"H\")  # (n, n)\n",
+    "Q_sym = pt.matrix(\"Q\")  # (n, n)\n",
+    "R_sym = pt.matrix(\"R\")  # (n, n)\n",
+    "T_sym = pt.matrix(\"T\")  # (n, n)\n",
+    "Z_sym = pt.matrix(\"Z\")  # (n, n)\n",
+    "y_sym = pt.matrix(\"y\")  # (T, n): observations\n",
+    "\n",
+    "a0_sym = pt.vector(\"a0\")  # (n,)\n",
+    "P0_sym = pt.matrix(\"P0\")  # (n, n)\n",
+    "\n",
+    "data_sym = pt.matrix('data_sym')  # [T, obs_dim]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "19e6a32d",
+   "metadata": {},
+   "source": [
+    "## Kalman filter with classic gradient"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "4fb4cef1",
+   "metadata": {},
+   "source": [
+    "### The Loss\n",
+    "\n",
+    "The negative log-likelihood loss is given in the paper as the following expression:\n",
+    "\n",
+    "$$\n",
+    "L_{NLL} = \\sum_n \\left( l_{n|n} + l_{n|n-1} \\right)\n",
+    "$$\n",
+    "\n",
+    "where:\n",
+    "\n",
+    "$$\n",
+    "\\begin{align}\n",
+    "&l_{n|n} = 0 \\\\\n",
+    "&l_{n|n-1} = \\log \\det(F_n) + v_n^T F_n^{-1} v_n\n",
+    "\\end{align}\n",
+    "$$"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "35351096",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def predict(a, P, T, Q):\n",
+    "    a_hat = T @ a  # x_n|n-1\n",
+    "    P_hat = T @ P @ T.T + Q  # P_n|n-1\n",
+    "    return a_hat, P_hat\n",
+    "\n",
+    "def update(y, a, P, Z, H):\n",
+    "    v = y - Z.dot(a)  # innovation v_n (z_n in the paper)\n",
+    "    PZT = P.dot(Z.T)\n",
+    "\n",
+    "    F = Z.dot(PZT) + H  # F_n (S_n in the paper)\n",
+    "    F_inv = pt.linalg.inv(F)  # F_n^(-1)\n",
+    "    K = PZT.dot(F_inv)  # K_n\n",
+    "\n",
+    "    I_KZ = pt.eye(a.shape[0]) - K.dot(Z)\n",
+    "    a_filtered = a + K.dot(v)  # x_n|n\n",
+    "    P_filtered = I_KZ @ P  # P_n|n\n",
+    "\n",
+    "    inner_term = v.T @ F_inv @ v\n",
+    "    _, F_logdet = pt.linalg.slogdet(F)  # log det F_n\n",
+    "    ll = (F_logdet + inner_term).ravel()[0]  # Loss\n",
+    "\n",
+    "    return [a_filtered, P_filtered, Z.dot(a), F, ll]\n",
+    "\n",
+    "def kalman_step(y, a, P, T, Z, H, Q):\n",
+    "    a_filtered, P_filtered, obs_mu, obs_cov, ll = update(y=y, a=a, P=P, Z=Z, H=H)\n",
+    "    a_hat, P_hat = predict(a=a_filtered, P=P_filtered, T=T, Q=Q)\n",
+    "    return [a_filtered, a_hat, obs_mu, P_filtered, P_hat, obs_cov, ll]\n",
+    "\n",
+    "\n",
+    "outputs_info = [None, a0_sym, None, None, P0_sym, None, None]\n",
+    "\n",
+    "results_seq, updates = pytensor.scan(\n",
+    "    kalman_step,\n",
+    "    sequences=[data_sym],\n",
+    "    outputs_info=outputs_info,\n",
+    "    non_sequences=[T_sym, Z_sym, H_sym, Q_sym],\n",
+    "    strict=False,\n",
+    ")\n",
+    "\n",
+    "# --- Loss ---\n",
+    "a_upd_seq, a_pred_seq, y_hat_seq, P_upd_seq, P_pred_seq, obs_cov, ll_seq = results_seq\n",
+    "loss = pt.sum(ll_seq)"
+   ]
+  },
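+  {
+   "cell_type": "markdown",
+   "id": "1a2b3c4d",
+   "metadata": {},
+   "source": [
+    "A quick usage sketch (added, not part of the original flow): compile the loss and evaluate it once on the simulated data, to confirm the forward pass runs. The call signature mirrors the benchmark cells further down."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1a2b3c4e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Compile and evaluate the NLL once on the generated dataset.\n",
+    "f_loss = pytensor.function(\n",
+    "    inputs=[data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n",
+    "    outputs=loss,\n",
+    ")\n",
+    "print(f_loss(obs_data[:, np.newaxis], a0, P0, T, Z, H, R @ Q @ R.T))"
+   ]
+  },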
"&\\frac{dl_{n|n-1}}{da_{n|n-1}} = -2 Z_n^{T}F_n^{-1} v_n \\quad &\\text{(equation 30)} \\\\\n", + "&\\frac{dL}{da_{n|n}} = T_n^T \\frac{dL}{da_{n+1|n}} \\quad &\\text{see gradient with respect to} \\quad a_{n|n}\n", + "\\end{align}\n", + "$$\n", + "\n", + "Givent this two equations, we now have :\n", + "$$\n", + "\\begin{align}\n", + "&\\frac{dL}{da_{n|n-1}} = (I - K_n Z_n)^T T_n^T \\frac{dL}{da_{n+1|n}} - 2 Z_n^{T}F^{-1} v_n\n", + "\\end{align}\n", + "$$\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ee21ef4e", + "metadata": {}, + "outputs": [], + "source": [ + "def grad_a_hat(inp, out, out_grad):\n", + " y, a, P, T, Z, H, Q = inp\n", + " a_hat_grad, _, _ = out_grad\n", + "\n", + " v = y - Z.dot(a) \n", + "\n", + " PZT = P.dot(Z.T)\n", + " F = Z.dot(PZT) + H \n", + " F_inv = pt.linalg.inv(F)\n", + " \n", + " K = PZT.dot(F_inv) \n", + " I_KZ = pt.eye(a.shape[0]) - K.dot(Z)\n", + "\n", + " grad_a_pred = I_KZ.T @ T.T @ a_hat_grad - 2 * Z.T @ F_inv @ v\n", + "\n", + " return grad_a_pred" + ] + }, + { + "cell_type": "markdown", + "id": "293d8d65", + "metadata": {}, + "source": [ + "### Gradient with respect to **$P_{n|n-1}$**\n", + "\n", + "From the article we have :\n", + "\n", + "$$\n", + "\\begin{align}\n", + "&\\frac{dL}{dP_{n|n-1}} = (I - K_n Z_n)^T [\n", + " \\frac{dL}{dP_{n|n}}\n", + " + \\frac{1}{2} \\frac{dL}{da_{n|n}} v_n^T H_n^-1 Z_n\n", + " + \\frac{1}{2} Z_n^T R_n^{-1} v_n (\\frac{dL}{da_{n|n}})^T\n", + " ](I - K_n Z_n) \n", + " + \\frac{dl{n|n-1}}{dP_{n|n-1}} \\quad &\\text{(equation 21)} \\\\\n", + "&\\frac{dl_{n|n-1}}{dP_{n|n-1}} = Z_n^T F_n^{-1} Z_n - Z_n^T F_n^-1 v_n v_n^T F_n^{-1} Z_n \\quad &\\text{(equation 29)} \\\\\n", + "&\\frac{dL}{da_{n|n}} = T_n^T \\frac{dL}{da_{n+1|n}} \\quad &\\text{see gradient with respect to} \\quad a_{n|n} \\\\\n", + "&\\frac{dL}{dP_{n|n}} = T_n^T \\frac{dL}{dP_{n+1|n}} T_n \\quad &\\text{see gradient with respect to} \\quad P_{n|n}\n", + "\\end{align}\n", + "$$\n", + "\n", + "Givent this two equations, we now have :\n", + "$$\n", + "\\begin{align}\n", + "&\\frac{dL}{dP_{n|n-1}} = (I - K_n Z_n)^T [\n", + " T_n^T \\frac{dL}{dP_{n+1|n}} T_n\n", + " + \\frac{1}{2} T_n^T \\frac{dL}{da_{n+1|n}} v_n^T H_n^{-1} Z_n\n", + " + \\frac{1}{2} Z_n^T H_n^{-1} v_n (T_n^T \\frac{dL}{da_{n+1|n}})^T\n", + " ](I - K_n Z_n) \n", + " + Z_n^T F_n^{-1} Z_n \n", + " - Z_n^T F_n^{-1} v_n v_n^T F_n^{-1} Z_n\n", + "\\end{align}\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "8c89b018", + "metadata": {}, + "outputs": [], + "source": [ + "def grad_P_hat(inp, out, out_grad):\n", + " y, a, P, T, Z, H, Q = inp\n", + " a_hat_grad, P_hat_grad, ll_grad = out_grad\n", + "\n", + " v = y - Z.dot(a)\n", + " v = v.dimshuffle(0, 'x')\n", + " a_hat_grad = a_hat_grad.dimshuffle(0, 'x') \n", + "\n", + " P_filtered_grad = T.T @ P_hat_grad @ T\n", + " a_filtered_grad = T.T @ a_hat_grad \n", + "\n", + " PZT = P.dot(Z.T)\n", + " F = Z.dot(PZT) + H\n", + "\n", + " H_inv = pt.linalg.inv(H) \n", + " F_inv = pt.linalg.inv(F)\n", + " \n", + " K = PZT.dot(F_inv) \n", + " I_KZ = pt.eye(a.shape[0]) - K.dot(Z)\n", + "\n", + " grad_P_hat = I_KZ.T @ ( P_filtered_grad + 0.5 * a_filtered_grad @ v.T @ H_inv @ Z + 0.5 * Z.T @ H_inv @ v @ a_filtered_grad.T ) @ I_KZ + Z.T @ F_inv @ Z - Z.T @ F_inv @ v @ v.T @ F_inv @ Z\n", + "\n", + " return grad_P_hat" + ] + }, + { + "cell_type": "markdown", + "id": "f0f2dce4", + "metadata": {}, + "source": [ + "### Gradient with respect to **y**\n", + "\n", + "From the article we have :\n", + "\n", + "$$\n", + 
"\\begin{align}\n", + "&\\frac{dL}{dy_n} = K_n^T\\frac{dL}{da_{n|n}} + \\frac{dl_{n|n-1}}{dy_n} \\quad &\\text{(equation 24)} \\\\\n", + "&\\frac{dl_{n|n-1}}{dy_n} = 2F^{-1}v_n \\quad &\\text{(equation 31)} \\\\\n", + "&\\frac{dL}{da_{n|n}} = T_n^T \\frac{dL}{da_{n+1|n}} \\quad &\\text{see gradient with respect to} \\quad a_{n|n} \\\\\n", + "\\end{align}\n", + "$$\n", + "\n", + "Givent this two equations, we now have :\n", + "$$\n", + "\\begin{align}\n", + "&\\frac{dL}{dy_n} = K_n^TT_n^T\\frac{dL}{da_{n+1|n}} + 2F^{-1}v_n\n", + "\\end{align}\n", + "$$\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "bba53a26", + "metadata": {}, + "outputs": [], + "source": [ + "def grad_y(inp, out, out_grad):\n", + " y, a, P, T, Z, H, Q = inp\n", + " a_hat_grad, P_h_grad, y_grad = out_grad\n", + "\n", + " y_hat = Z.dot(a)\n", + " v = y - y_hat\n", + "\n", + " PZT = P.dot(Z.T)\n", + " F = Z.dot(PZT) + H\n", + " F_inv = pt.linalg.inv(F)\n", + "\n", + " K = PZT.dot(F_inv) \n", + " \n", + " return K.T @ T.T @ a_hat_grad + 2 * F_inv @ v" + ] + }, + { + "cell_type": "markdown", + "id": "d6b48789", + "metadata": {}, + "source": [ + "### Gradient with respect to Q\n", + "\n", + "From the article we have :\n", + "\n", + "$$\n", + "\\begin{align}\n", + "\\frac{dL}{dQ_n} = \\frac{dL}{dP_{n|n-1}} & \\quad \\text{(equation 25)}\n", + "\\end{align}\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c17949b7", + "metadata": {}, + "outputs": [], + "source": [ + "def grad_Q(inp, out, out_grad):\n", + " _, P_h_grad, _ = out_grad\n", + " return P_h_grad" + ] + }, + { + "cell_type": "markdown", + "id": "f0bc0287", + "metadata": {}, + "source": [ + "### Gradient with respect to **H**\n", + "\n", + "From the article we have :\n", + "\n", + "$$\n", + "\\begin{align}\n", + "&\\frac{dL}{dH_n} = K_n^T\\frac{dL}{dP_{n|n}}K_n \n", + "- \\frac{1}{2} K_n^T \\frac{dL}{da_{n|n}} v_n^T F^{-1}\n", + "- \\frac{1}{2} S_n^{-1} v_n (\\frac{dL}{da_{n|n}})^T K_n\n", + "+ \\frac{dl_{n|n-1}}{dH_n} \n", + "\\quad &\\text{(equation 26)} \\\\\n", + "&\\frac{dl_{n|n-1}}{dH_n} = F^{-1} - F_n^{-1} v_n v_n^T F_n^{-1} \n", + "\\quad &\\text{(equation 31)} \\\\\n", + "&\\frac{dL}{da_{n|n}} = T_n^T \\frac{dL}{da_{n+1|n}} \\quad &\\text{see gradient with respect to} \\quad a_{n|n} \\\\\n", + "&\\frac{dL}{dP_{n|n}} = T_n^T \\frac{dL}{dP_{n+1|n}} T_n \\quad &\\text{see gradient with respect to} \\quad P_{n|n}\n", + "\\end{align}\n", + "$$\n", + "\n", + "Givent this two equations, we now have :\n", + "$$\n", + "\\begin{align}\n", + "&\\frac{dL}{dH_n} = K_n^T T_n^T \\frac{dL}{dP_{n+1|n}} T_n K_n \n", + "- \\frac{1}{2} K_n^T T_n^T \\frac{dL}{da_{n+1|n}} v_n^T F^{-1}\n", + "- \\frac{1}{2} F_n^{-1} v_n (T_n^T \\frac{dL}{da_{n+1|n}})^T K_n\n", + "+ F^{-1} - F_n^{-1} v_n v_n^T F_n^{-1}\n", + "\\end{align}\n", + "$$\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "84cb6867", + "metadata": {}, + "outputs": [], + "source": [ + "def grad_H(inp, out, out_grad):\n", + " y, a, P, T, Z, H, Q = inp\n", + " a_hat_grad, P_h_grad, y_grad = out_grad\n", + " \n", + " y_hat = Z.dot(a)\n", + " v = y - y_hat\n", + "\n", + " PZT = P.dot(Z.T)\n", + " F = Z.dot(PZT) + H\n", + " F_inv = pt.linalg.inv(F)\n", + "\n", + " K = PZT.dot(F_inv)\n", + "\n", + " v = v.dimshuffle(0, 'x')\n", + " a_hat_grad = a_hat_grad.dimshuffle(0, 'x') \n", + "\n", + " a_filtered_grad = T.T @ a_hat_grad\n", + " P_filtered_grad = T.T @ P_h_grad @ T\n", + "\n", + " return K.T @ P_filtered_grad @ K - 0.5 * K.T @ a_filtered_grad @ v.T @ F_inv 
+    "        - 0.5 * F_inv @ v @ a_filtered_grad.T @ K\n",
+    "        + F_inv\n",
+    "        - F_inv @ v @ v.T @ F_inv\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "4fa2ffc0",
+   "metadata": {},
+   "source": [
+    "### Gradient with respect to **T**\n",
+    "\n",
+    "This gradient is not given in the article. Here are the steps that lead to the expression:\n",
+    "\n",
+    "1 - Only $x_{n|n-1}$ and $P_{n|n-1}$ depend on $T_n$. Hence:\n",
+    "$$\n",
+    "\\frac{\\partial L}{\\partial T} = \\frac{\\partial L}{\\partial x_{n|n-1}} \\frac{\\partial x_{n|n-1}}{\\partial T} + \\frac{\\partial L}{\\partial P_{n|n-1}} \\frac{\\partial P_{n|n-1}}{\\partial T}\n",
+    "$$\n",
+    "2 - Using equations (11) and (12) of the article in (1), we directly get:\n",
+    "$$\n",
+    "\\frac{\\partial L}{\\partial x_{n|n-1}} \\frac{\\partial x_{n|n-1}}{\\partial T} = \\frac{\\partial L}{\\partial x_{n|n-1}} x_{n-1|n-1}^T\n",
+    "$$\n",
+    "3 - Recognizing the quadratic form $T_n P_{n-1|n-1} T_n^T$ in (1), and using equation (11), we get:\n",
+    "$$\n",
+    "\\frac{\\partial L}{\\partial P_{n|n-1}} \\frac{\\partial P_{n|n-1}}{\\partial T^T} = P_{n-1|n-1} T_n^T \\left(\\frac{\\partial L}{\\partial P_{n|n-1}}\\right)^T + P_{n-1|n-1}^T T_n^T \\frac{\\partial L}{\\partial P_{n|n-1}}\n",
+    "$$\n",
+    "4 - Now transposing to get the dependencies on $T$:\n",
+    "$$\n",
+    "\\frac{\\partial L}{\\partial P_{n|n-1}} \\frac{\\partial P_{n|n-1}}{\\partial T} = \\frac{\\partial L}{\\partial P_{n|n-1}} T_n P_{n-1|n-1}^T + \\left(\\frac{\\partial L}{\\partial P_{n|n-1}}\\right)^T T_n P_{n-1|n-1}\n",
+    "$$\n",
+    "5 - Finally, we have:\n",
+    "$$\n",
+    "\\frac{\\partial L}{\\partial T} = \\frac{\\partial L}{\\partial x_{n|n-1}} x_{n-1|n-1}^T + \\frac{\\partial L}{\\partial P_{n|n-1}} T_n P_{n-1|n-1}^T + \\left(\\frac{\\partial L}{\\partial P_{n|n-1}}\\right)^T T_n P_{n-1|n-1}\n",
+    "$$"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "9a560ed9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def grad_T(inp, out, out_grad):\n",
+    "    y, a, P, T, Z, H, Q = inp\n",
+    "    a_hat_grad, P_h_grad, y_grad = out_grad\n",
+    "\n",
+    "    y_hat = Z.dot(a)\n",
+    "    v = y - y_hat\n",
+    "\n",
+    "    PZT = P.dot(Z.T)\n",
+    "    F = Z.dot(PZT) + H\n",
+    "    F_inv = pt.linalg.inv(F)\n",
+    "\n",
+    "    K = PZT.dot(F_inv)\n",
+    "    I_KZ = pt.eye(a.shape[0]) - K.dot(Z)\n",
+    "\n",
+    "    v = v.dimshuffle(0, 'x')\n",
+    "    a = a.dimshuffle(0, 'x')\n",
+    "    a_hat_grad = a_hat_grad.dimshuffle(0, 'x')\n",
+    "\n",
+    "    # a_filtered / P_filtered are the filtered quantities that feed the\n",
+    "    # prediction step, i.e. the (n-1|n-1) terms of the formula above.\n",
+    "    a_filtered = a + K.dot(v)\n",
+    "    P_filtered = I_KZ @ P\n",
+    "\n",
+    "    return (\n",
+    "        a_hat_grad @ a_filtered.T\n",
+    "        + P_h_grad @ T @ P_filtered.T\n",
+    "        + P_h_grad.T @ T @ P_filtered\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2666eaec",
+   "metadata": {},
+   "source": [
+    "### Gradient with respect to Z\n",
+    "\n",
+    "To obtain this gradient, I used the matrix differential + trace trick. We consider that $Z_n$ influences the loss directly through $v_n$ and $F_n$, and indirectly through backpropagation through $P_{n|n}$ and $a_{n|n}$. Then $\\frac{dL}{dZ_n}$ is the matrix that verifies:\n",
+    "\n",
+    "$$\n",
+    "dL = tr((\\frac{dL}{dv_n})^T dv_n + (\\frac{dL}{dF_n})^T dF_n + (\\frac{dL}{dP_{n|n}})^T dP_{n|n} + (\\frac{dL}{da_{n|n}})^T da_{n|n}) = tr((\\frac{dL}{dZ_n})^T dZ_n)\n",
+    "$$\n",
+    "\n",
+    "We will also use two simple trace identities:\n",
+    "- **transpose invariance**: $tr(A) = tr(A^T)$;\n",
+    "- **cyclicity of the trace**: $tr(ABC) = tr(CAB) = tr(BCA)$."
+   ]
+  },
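+  {
+   "cell_type": "markdown",
+   "id": "5e6f7a8b",
+   "metadata": {},
+   "source": [
+    "Both identities are easy to check numerically; a minimal sketch, added for illustration:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5e6f7a8c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Numeric spot-check of the two trace identities used below.\n",
+    "rng = np.random.default_rng(0)\n",
+    "A, B, C = rng.normal(size=(4, 4)), rng.normal(size=(4, 4)), rng.normal(size=(4, 4))\n",
+    "print(np.allclose(np.trace(A), np.trace(A.T)))                # transpose invariance\n",
+    "print(np.allclose(np.trace(A @ B @ C), np.trace(C @ A @ B)))  # cyclicity"
+   ]
+  },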
" + ] + }, + { + "cell_type": "markdown", + "id": "65986a5b", + "metadata": {}, + "source": [ + "#### **First term** : $(\\frac{dL}{dv_n})^T dv_n$\n", + "\n", + "First noticing that :\n", + "\n", + "$$\n", + "\\begin{align}\n", + "\\frac{dL}{dy_n} &= \\frac{dL}{dv_n}\\frac{dv_n}{dy_n} = \\frac{dL}{dv_n} I = \\frac{dL}{dv_n} \\\\\n", + "dv_n &= - dZ_n a_{n|n-1}\n", + "\\end{align}\n", + "$$\n", + "\n", + "And since we directly got from equation (31) of the paper that :\n", + "\n", + "$$\n", + "\\begin{align}\n", + "\\frac{dL}{dy_n} = \\frac{dl_{n|n-1}}{dy_n} = 2 F_n^{-1} v_n\n", + "\\end{align}\n", + "$$\n", + "\n", + "We ultimately have that :\n", + "\n", + "$$\n", + "\\begin{align}\n", + "(\\frac{dL}{dv_n})^T dv_n &= ( 2 F_n^{-1} v_n )^T (- dZ_n a_{n|n-1}) = - 2 v_n^T F_n^{-T} dZ_n a_{n|n-1}\n", + "\\end{align}\n", + "$$\n", + "\n", + "Now, using trace tricks :\n", + "\n", + "$$\n", + "\\begin{align}\n", + "tr((\\frac{dL}{dv_n})^T dv_n) = tr( - 2 F_n^{-1} v_n dZ_n a_{n|n-1} ) = tr( (- 2 F_n^{-1} v_n a_{n|n-1}^T)^T dZ_n)\n", + "\\end{align}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "f7ae123c", + "metadata": {}, + "source": [ + "#### **Second term** : $(\\frac{dL}{dF_n})^T dF_n$\n", + "\n", + "Starting with differentiations and derivatives :\n", + "\n", + "$$\n", + "\\begin{align}\n", + "dF_n &= dZ_n P_{n|n-1} Z_n^T + Z_n P_{n|n-1} (dZ_n)^T \\\\\n", + "d(F_n^{-1}) &= -F_n^{-1} dF_n F_n^{-1} \\\\\n", + "\\frac{dL}{dF_n} &= F_n^{-1} - F_n^{-1} z_n z_n^T F_n^{-1}\\\\\n", + "\\end{align}\n", + "$$\n", + "\n", + "Now, using the trace trick :\n", + "\n", + "$$\n", + "\\begin{align}\n", + "tr((\\frac{dL}{dF_n})^T dF_n) &= tr((F_n^{-1} - F_n^{-1} z_n z_n^T F_n^{-1})^T (dZ_n P_{n|n-1} Z_n^T + Z_n P_{n|n-1} (dZ_n)^T)) \\\\\n", + "&=tr(F_n^{-T} dZ_n P_{n|n-1} Z_n^T + F_n^{-T} Z_n P_{n|n-1} (dZ_n)^T - F_n^{-T} z_n z_n^T F_n^{-T} dZ_n P_{n|n-1} Z_n^T - F_n^{-T} z_n z_n^T F_n^{-T} Z_n P_{n|n-1} (dZ_n)^T) \\\\\n", + "&=tr(P_{n|n-1} Z_n^T F_n^{-T} dZ_n + P_{n|n-1}^T Z_n^T F_n^{-1} dZ_n - P_{n|n-1} Z_n^T F_n^{-T} z_n z_n^T F_n^{-T} dZ_n - P_{n|n-1} F_n^{-1} z_n z_n^T F_n^{-1} dZ_n) \\\\\n", + "&=tr((F_n^{-1} Z_n P_{n|n-1}^T + F_n^{-T} Z_n P_{n|n-1} - F_n^{-1} z_n z_n^T F_n^{-1} Z_n P_{n|n-1}^T - F_n^{-T} z_n z_n^T F_n^{-T} Z_n P_{n|n-1}^T)^T dZ_n)\n", + "\\end{align}\n", + "$$\n", + "\n", + "Noticing that $P_{n|n-1}^T = P_{n|n-1}$ and $F_n^T = F_n$ :\n", + "\n", + "$$\n", + "\\begin{align}\n", + "tr((\\frac{dL}{dF_n})^T dF_n) &= tr((2 F_n^{-1} Z_n P_{n|n-1} - 2 F_n^{-1} z_n z_n^T F_n^{-1} Z_n P_{n|n-1})^T dZ_n)\n", + "\\end{align}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "4d4c8cd8", + "metadata": {}, + "source": [ + "#### **Third term** : $(\\frac{dL}{dP_{n|n}})^T dP_{n|n}$\n", + "\n", + "Starting with differencciations :\n", + "\n", + "$$\n", + "\\begin{align}\n", + "dP_{n|n} &= - (dK_n Z_n + K_n dZ_n) P_{n|n-1} \\\\\n", + "&= - (P_{n|n-1}((dZ_n)^T S_n^{-1} + Z_n^T d(S_n^{-1})) Z_n + K_n dZ_n) P_{n|n-1} \\\\\n", + "&= - (P_{n|n-1}((dZ_n)^T S_n^{-1} - Z_n^T S_n^{-1} dS_n S_n^{-1}) Z_n + K_n dZ_n) P_{n|n-1} \\\\\n", + "&= - (P_{n|n-1}((dZ_n)^T S_n^{-1} - Z_n^T S_n^{-1} (dZ_n P_{n|n-1} Z_n^T + Z_n P_{n|n-1} (dZ_n)^T) S_n^{-1}) Z_n + K_n dZ_n) P_{n|n-1} \\\\\n", + "&= - P_{n|n-1}(dZ_n)^T S_n^{-1} Z_n P_{n|n-1} \\\\\n", + "&+ P_{n|n-1} Z_n^T S_n^{-1} dZ_n P_{n|n-1} Z_n^T S_n^{-1} Z_n P_{n|n-1} \\\\\n", + "&+ P_{n|n-1} Z_n^T S_n^{-1} Z_n P_{n|n-1} (dZ_n)^T S_n^{-1} Z_n P_{n|n-1} \\\\\n", + "&- K_n dZ_n P_{n|n-1} \\\\\n", + "\\end{align}\n", + "$$\n", + "\n", + "Term by term, 
+    "\n",
+    "$$\n",
+    "\\begin{align}\n",
+    "&tr((\\frac{dL}{dP_{n|n}})^T P_{n|n-1}(dZ_n)^T F_n^{-1} Z_n P_{n|n-1}) = tr(P_{n|n-1}^T Z_n^T F_n^{-T} dZ_n P_{n|n-1}^T \\frac{dL}{dP_{n|n}}) = tr(P_{n|n-1}^T \\frac{dL}{dP_{n|n}} P_{n|n-1}^T Z_n^T F_n^{-T} dZ_n) \\\\\n",
+    "&tr((\\frac{dL}{dP_{n|n}})^T P_{n|n-1} Z_n^T F_n^{-1} dZ_n P_{n|n-1} Z_n^T F_n^{-1} Z_n P_{n|n-1}) = tr(P_{n|n-1} Z_n^T F_n^{-1} Z_n P_{n|n-1} (\\frac{dL}{dP_{n|n}})^T P_{n|n-1} Z_n^T F_n^{-1} dZ_n) \\\\\n",
+    "&tr((\\frac{dL}{dP_{n|n}})^T P_{n|n-1} Z_n^T F_n^{-1} Z_n P_{n|n-1} (dZ_n)^T F_n^{-1} Z_n P_{n|n-1}) = tr(P_{n|n-1}^T Z_n^T F_n^{-T} dZ_n P_{n|n-1}^T Z_n^T F_n^{-T} Z_n P_{n|n-1}^T \\frac{dL}{dP_{n|n}}) \\\\\n",
+    "&= tr(P_{n|n-1}^T Z_n^T F_n^{-T} Z_n P_{n|n-1}^T \\frac{dL}{dP_{n|n}} P_{n|n-1}^T Z_n^T F_n^{-T} dZ_n) \\\\\n",
+    "&tr((\\frac{dL}{dP_{n|n}})^T K_n dZ_n P_{n|n-1}) = tr(P_{n|n-1} (\\frac{dL}{dP_{n|n}})^T K_n dZ_n) \\\\\n",
+    "\\\\\n",
+    "\\end{align}\n",
+    "$$\n",
+    "\n",
+    "So the contribution of $Z_n$ through $P_{n|n}$ is:\n",
+    "\n",
+    "$$\n",
+    "\\begin{align}\n",
+    "&tr((\\frac{dL}{dP_{n|n}})^T dP_{n|n}) = tr((- F_n^{-1} Z_n P_{n|n-1} (\\frac{dL}{dP_{n|n}})^T P_{n|n-1} + K_n^T \\frac{dL}{dP_{n|n}} P_{n|n-1}^T Z_n^T K_n^T + F_n^{-1} Z_n P_{n|n-1} (\\frac{dL}{dP_{n|n}})^T K_n Z_n P_{n|n-1} - K_n^T \\frac{dL}{dP_{n|n}} P_{n|n-1}^T)^T dZ_n)\n",
+    "\\end{align}\n",
+    "$$"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "32c8125d",
+   "metadata": {},
+   "source": [
+    "#### **Fourth term**: $(\\frac{dL}{da_{n|n}})^T da_{n|n}$\n",
+    "\n",
+    "Starting with the differentials:\n",
+    "\n",
+    "$$\n",
+    "\\begin{align}\n",
+    "da_{n|n} &= dK_n v_n + K_n dv_n \\\\\n",
+    "&= P_{n|n-1}((dZ_n)^T F_n^{-1} + Z_n^T d(F_n^{-1})) v_n - K_n dZ_n a_{n|n-1} \\\\\n",
+    "&= P_{n|n-1}((dZ_n)^T F_n^{-1} - Z_n^T F_n^{-1} dF_n F_n^{-1}) v_n - K_n dZ_n a_{n|n-1} \\\\\n",
+    "&= P_{n|n-1}((dZ_n)^T F_n^{-1} - Z_n^T F_n^{-1} (dZ_n P_{n|n-1} Z_n^T + Z_n P_{n|n-1} (dZ_n)^T) F_n^{-1}) v_n - K_n dZ_n a_{n|n-1} \\\\\n",
+    "&= P_{n|n-1} (dZ_n)^T F_n^{-1} v_n \\\\\n",
+    "&- P_{n|n-1} Z_n^T F_n^{-1} dZ_n P_{n|n-1} Z_n^T F_n^{-1} v_n \\\\\n",
+    "&- P_{n|n-1} Z_n^T F_n^{-1} Z_n P_{n|n-1} (dZ_n)^T F_n^{-1} v_n \\\\\n",
+    "&- K_n dZ_n a_{n|n-1} \\\\\n",
+    "\\end{align}\n",
+    "$$\n",
+    "\n",
+    "Term by term, let's use the trace trick to get $dZ_n$ in the right place:\n",
+    "\n",
+    "$$\n",
+    "\\begin{align}\n",
+    "&tr((\\frac{dL}{da_{n|n}})^T P_{n|n-1} (dZ_n)^T F_n^{-1} v_n) = tr(v_n^T F_n^{-T} dZ_n P_{n|n-1}^T \\frac{dL}{da_{n|n}}) = tr(P_{n|n-1}^T \\frac{dL}{da_{n|n}} v_n^T F_n^{-T} dZ_n) \\\\\n",
+    "&tr((\\frac{dL}{da_{n|n}})^T P_{n|n-1} Z_n^T F_n^{-1} dZ_n P_{n|n-1} Z_n^T F_n^{-1} v_n) = tr(P_{n|n-1} Z_n^T F_n^{-1} v_n (\\frac{dL}{da_{n|n}})^T P_{n|n-1} Z_n^T F_n^{-1} dZ_n) \\\\\n",
+    "&tr((\\frac{dL}{da_{n|n}})^T P_{n|n-1} Z_n^T F_n^{-1} Z_n P_{n|n-1} (dZ_n)^T F_n^{-1} v_n) = tr(v_n^T F_n^{-T} dZ_n P_{n|n-1}^T Z_n^T F_n^{-T} Z_n P_{n|n-1}^T \\frac{dL}{da_{n|n}}) \\\\\n",
+    "&= tr(P_{n|n-1}^T Z_n^T F_n^{-T} Z_n P_{n|n-1}^T \\frac{dL}{da_{n|n}} v_n^T F_n^{-T} dZ_n) \\\\\n",
+    "&tr((\\frac{dL}{da_{n|n}})^T K_n dZ_n a_{n|n-1}) = tr(a_{n|n-1} (\\frac{dL}{da_{n|n}})^T K_n dZ_n) \\\\\n",
+    "\\\\\n",
+    "\\end{align}\n",
+    "$$\n",
+    "\n",
+    "So the contribution of $Z_n$ through $a_{n|n}$ is:\n",
+    "$$\n",
+    "\\begin{align}\n",
+    "&tr((\\frac{dL}{da_{n|n}})^T da_{n|n}) = tr((F_n^{-1} v_n (\\frac{dL}{da_{n|n}})^T P_{n|n-1} - K_n^T \\frac{dL}{da_{n|n}} v_n^T K_n^T - F_n^{-1} v_n\n",
+    "(\\frac{dL}{da_{n|n}})^T K_n Z_n P_{n|n-1} - K_n^T \\frac{dL}{da_{n|n}} a_{n|n-1}^T)^T dZ_n) \\\\\n",
+    "\\end{align}\n",
+    "$$"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f9e35a41",
+   "metadata": {},
+   "source": [
+    "#### **To sum up**:\n",
+    "\n",
+    "In the end, the gradient of $L$ with respect to $Z_n$ is:\n",
+    "\n",
+    "$$\n",
+    "\\begin{align}\n",
+    "\\frac{dL}{dZ_n} &= - 2 F_n^{-1} v_n a_{n|n-1}^T\\\\\n",
+    "&+ 2 F_n^{-1} Z_n P_{n|n-1} - 2 F_n^{-1} v_n v_n^T F_n^{-1} Z_n P_{n|n-1}\\\\\n",
+    "&+ F_n^{-1} v_n (\\frac{dL}{da_{n|n}})^T P_{n|n-1} - K_n^T \\frac{dL}{da_{n|n}} v_n^T K_n^T - F_n^{-1} v_n (\\frac{dL}{da_{n|n}})^T K_n Z_n P_{n|n-1} - K_n^T \\frac{dL}{da_{n|n}} a_{n|n-1}^T \\\\\n",
+    "&- F_n^{-1} Z_n P_{n|n-1} (\\frac{dL}{dP_{n|n}})^T P_{n|n-1} + K_n^T \\frac{dL}{dP_{n|n}} P_{n|n-1}^T Z_n^T K_n^T + F_n^{-1} Z_n P_{n|n-1} (\\frac{dL}{dP_{n|n}})^T K_n Z_n P_{n|n-1} - K_n^T \\frac{dL}{dP_{n|n}} P_{n|n-1}^T \\\\\n",
+    "\\end{align}\n",
+    "$$"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "df925ee8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def grad_Z(inp, out, out_grad):\n",
+    "\n",
+    "    y, a, P, T, Z, H, Q = inp\n",
+    "    a_h_grad, P_h_grad, y_grad = out_grad\n",
+    "\n",
+    "    y_hat = Z.dot(a)\n",
+    "    v = y - y_hat\n",
+    "\n",
+    "    PZT = P.dot(Z.T)\n",
+    "    F = Z.dot(PZT) + H\n",
+    "    F_inv = pt.linalg.inv(F)\n",
+    "    K = PZT.dot(F_inv)\n",
+    "\n",
+    "    v = v.dimshuffle(0, 'x')\n",
+    "    a = a.dimshuffle(0, 'x')\n",
+    "    a_h_grad = a_h_grad.dimshuffle(0, 'x')\n",
+    "\n",
+    "    a_filtered_grad = T.T @ a_h_grad\n",
+    "    P_filtered_grad = T.T @ P_h_grad @ T\n",
+    "\n",
+    "    # Contribution via Pnn\n",
+    "\n",
+    "    term_P_1 = - F_inv @ Z @ P @ P_filtered_grad.T @ P\n",
+    "    term_P_2 = K.T @ P_filtered_grad @ P.T @ Z.T @ K.T\n",
+    "    term_P_3 = F_inv @ Z @ P @ P_filtered_grad.T @ K @ Z @ P\n",
+    "    term_P_4 = - K.T @ P_filtered_grad @ P.T\n",
+    "\n",
+    "    contrib_P = term_P_1 + term_P_2 + term_P_3 + term_P_4\n",
+    "\n",
+    "    # Contribution via xnn\n",
+    "\n",
+    "    term_x_1 = F_inv @ v @ a_filtered_grad.T @ P\n",
+    "    term_x_2 = - K.T @ a_filtered_grad @ v.T @ K.T\n",
+    "    term_x_3 = - F_inv @ v @ a_filtered_grad.T @ K @ Z @ P\n",
+    "    term_x_4 = - K.T @ a_filtered_grad @ a.T\n",
+    "\n",
+    "    contrib_x = term_x_1 + term_x_2 + term_x_3 + term_x_4\n",
+    "\n",
+    "    # Contribution via Fn\n",
+    "\n",
+    "    term_F_1 = 2 * F_inv @ Z @ P\n",
+    "    term_F_2 = - 2 * F_inv @ v @ v.T @ F_inv @ Z @ P\n",
+    "\n",
+    "    contrib_F = term_F_1 + term_F_2\n",
+    "\n",
+    "    # Contribution via vn\n",
+    "    contrib_v = - 2 * F_inv @ v @ a.T\n",
+    "\n",
+    "    return contrib_x + contrib_P + contrib_F + contrib_v"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "bd458dee",
+   "metadata": {},
+   "source": [
+    "### Total grad"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "id": "afb362e5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def custom_grad(inp, out, out_grad):\n",
+    "    g_P0 = grad_P_hat(inp, out, out_grad)\n",
+    "    g_a0 = grad_a_hat(inp, out, out_grad)\n",
+    "    g_y = grad_y(inp, out, out_grad)\n",
+    "    g_Z = grad_Z(inp, out, out_grad)\n",
+    "    g_T = grad_T(inp, out, out_grad)\n",
+    "    g_Q = grad_Q(inp, out, out_grad)\n",
+    "    g_H = grad_H(inp, out, out_grad)\n",
+    "\n",
+    "    # One gradient per input, in the same order as kalman_step's inputs:\n",
+    "    # [y, a, P, T, Z, H, Q]\n",
+    "    return [g_y, g_a0, g_P0, g_T, g_Z, g_H, g_Q]"
+   ]
+  },
pt.vector(\"y\")\n", + "\n", + "kalman_step_op = OpFromGraph(\n", + " inputs=[y_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n", + " outputs=kalman_step(y_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym),\n", + " lop_overrides=custom_grad,\n", + " inline=True\n", + ")\n", + "\n", + "outputs_info = [None, a0_sym, None, None, P0_sym, None, None]\n", + "\n", + "results_op, updates = pytensor.scan(\n", + " kalman_step_op,\n", + " sequences=[data_sym],\n", + " outputs_info=outputs_info,\n", + " non_sequences=[T_sym, Z_sym, H_sym, Q_sym],\n", + " strict=False,\n", + ")\n", + "# --- Loss ---\n", + "a_upd_op, a_pred_op, y_hat_op, P_upd_op, P_pred_op, obs_cov, ll_op = results_op\n", + "loss_op = pt.sum(ll_op)" + ] + }, + { + "cell_type": "markdown", + "id": "3f79c5c6", + "metadata": {}, + "source": [ + "## Handmade Numpy Backpropagation " + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "b6eb5d48", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_grad_a0(observations, a0, P0, a_pred_seq, P_pred_seq, Z, H, T):\n", + " # Constant\n", + " SHAPE_a0 = a0.shape[0]\n", + " NB_obs = len(observations)\n", + "\n", + " # Initialisation for the backprop\n", + " PZT = P_pred_seq[-2].dot(Z.T)\n", + " F = Z.dot(PZT) + H\n", + " F_inv = np.linalg.solve(F, np.eye(F.shape[0]))\n", + " \n", + " grad = [0 for _ in range(NB_obs)]\n", + " grad[-1] = - 2 * Z.T @ F_inv @ (observations[-1] - Z @ a_pred_seq[-2])\n", + "\n", + " # Backprop\n", + " for i in range(3, NB_obs+1):\n", + "\n", + " PZT = P_pred_seq[-i].dot(Z.T)\n", + " F = Z.dot(PZT) + H\n", + " F_inv = np.linalg.solve(F, np.eye(F.shape[0]))\n", + "\n", + " K = PZT.dot(F_inv)\n", + " I_KZ = np.eye(SHAPE_a0) - K.dot(Z)\n", + "\n", + " grad[1-i] = I_KZ.T @ T.T @ grad[2-i] - (2 * Z.T @ F_inv @ (observations[1-i] - Z @ a_pred_seq[-i])).T \n", + "\n", + " # Last iter with a0/P0\n", + " PZT = P0.dot(Z.T)\n", + " F = Z.dot(PZT) + H\n", + " F_inv = np.linalg.solve(F, np.eye(F.shape[0]))\n", + "\n", + " K = PZT.dot(F_inv)\n", + " I_KZ = np.eye(SHAPE_a0) - K.dot(Z)\n", + "\n", + " grad[0] = I_KZ.T @ T.T @ grad[1] - (2 * Z.T @ F_inv @ (observations[0] - Z @ a0)).T\n", + "\n", + " return grad" + ] + }, + { + "cell_type": "markdown", + "id": "f0575c2c", + "metadata": {}, + "source": [ + "## Speed observation" + ] + }, + { + "cell_type": "markdown", + "id": "c99fddf9", + "metadata": {}, + "source": [ + "### Benchmark for pytensor computed gradients" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "908946b0", + "metadata": {}, + "outputs": [], + "source": [ + "def benchmark_kalman_gradients(loss, obs_data, a0, P0, T, Z, R, H, Q):\n", + " results = defaultdict(dict)\n", + " exec_time = 0\n", + "\n", + " grad_list = pt.grad(loss, [data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym])\n", + " f_grad = pytensor.function(\n", + " inputs=[data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n", + " outputs=grad_list,\n", + " )\n", + "\n", + " for _ in range(20):\n", + " \n", + " # --- exécution ---\n", + " t0 = perf_counter()\n", + " _ = f_grad(\n", + " obs_data[:, np.newaxis],\n", + " a0,\n", + " P0,\n", + " T,\n", + " Z,\n", + " H,\n", + " R @ Q @ R.T,\n", + " )\n", + " t1 = perf_counter()\n", + " exec_time += (t1 - t0)/20\n", + " \n", + " \n", + " results[\"exec_time\"] = exec_time\n", + "\n", + " return results" + ] + }, + { + "cell_type": "markdown", + "id": "f03a7555", + "metadata": {}, + "source": [ + "### Benchmark for numpy computed gradient" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": 
"a85fe92e", + "metadata": {}, + "outputs": [], + "source": [ + "def benchmark_kalman_gradients_np(a_pred_seq, P_pred_seq, obs_data, a0, P0, T, Z, R, H, Q):\n", + " results = defaultdict(dict)\n", + " forward_pass = 0\n", + " backprop = 0\n", + " kalman_fn = pytensor.function(inputs=[data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n", + " outputs=(a_pred_seq, P_pred_seq))\n", + "\n", + " for _ in range(20):\n", + "\n", + " # --- forward pass ---\n", + " t0 = perf_counter()\n", + " a_pred, P_pred = kalman_fn(obs_data[:, np.newaxis],\n", + " a0,\n", + " P0,\n", + " T,\n", + " Z,\n", + " H,\n", + " R@Q@R.T,)\n", + " t1 = perf_counter()\n", + " forward_pass += (t1 - t0)/20\n", + " \n", + "\n", + " # --- Backprop ---\n", + " t0 = perf_counter()\n", + " _ = compute_grad_a0(\n", + " obs_data,\n", + " a0,\n", + " P0,\n", + " a_pred,\n", + " P_pred,\n", + " Z,\n", + " H,\n", + " T,)\n", + " t1 = perf_counter()\n", + " backprop += (t1 - t0)/20\n", + "\n", + " results[\"Forward pass\"] = forward_pass \n", + " results[\"Backprop\"] = backprop\n", + "\n", + " return results" + ] + }, + { + "cell_type": "markdown", + "id": "b413f411", + "metadata": {}, + "source": [ + "### Comparison" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "27a60fb3", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\jeanv\\miniconda3\\envs\\CausalPy\\Lib\\site-packages\\pytensor\\tensor\\rewriting\\elemwise.py:954: UserWarning: Loop fusion failed because the resulting node would exceed the kernel argument limit.\n", + " warn(\n" + ] + } + ], + "source": [ + "results = benchmark_kalman_gradients(loss, obs_data, a0, P0, T, Z, R, H, Q)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "a413c8e9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "defaultdict(, {'exec_time': 0.11510749500157547})\n" + ] + } + ], + "source": [ + "print(results)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "d35b98d6", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\jeanv\\miniconda3\\envs\\CausalPy\\Lib\\site-packages\\pytensor\\tensor\\rewriting\\elemwise.py:954: UserWarning: Loop fusion failed because the resulting node would exceed the kernel argument limit.\n", + " warn(\n" + ] + } + ], + "source": [ + "results_op = benchmark_kalman_gradients(loss_op, obs_data, a0, P0, T, Z, R, H, Q)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "539c18c2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "defaultdict(, {'exec_time': 0.18277070000040113})\n" + ] + } + ], + "source": [ + "print(results_op)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e633e75", + "metadata": {}, + "outputs": [], + "source": [ + "results_np = benchmark_kalman_gradients_np(a_pred_seq, P_pred_seq, obs_data, a0, P0, T, Z, R, H, Q)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7118dfec", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "defaultdict(, {'Forward pass': 0.016379594999978053, 'Backprop': 0.0034159099999897078})\n" + ] + } + ], + "source": [ + "print(results_np)" + ] + }, + { + "cell_type": "markdown", + "id": "d77cf70b", + "metadata": {}, + "source": [ + "## Error observation" + ] + }, + { + "cell_type": "markdown", + "id": "90fabd6f", + "metadata": {}, + 
"source": [ + "### Comparing the gradient with respect to a0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fbae0189", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\jeanv\\miniconda3\\envs\\CausalPy\\Lib\\site-packages\\pytensor\\tensor\\rewriting\\elemwise.py:954: UserWarning: Loop fusion failed because the resulting node would exceed the kernel argument limit.\n", + " warn(\n", + "c:\\Users\\jeanv\\miniconda3\\envs\\CausalPy\\Lib\\site-packages\\pytensor\\tensor\\rewriting\\elemwise.py:954: UserWarning: Loop fusion failed because the resulting node would exceed the kernel argument limit.\n", + " warn(\n" + ] + } + ], + "source": [ + "# First the classic way with autodiff\n", + "\n", + "grad_list = pt.grad(loss, [a0_sym])\n", + "f_grad = pytensor.function(\n", + " inputs=[data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n", + " outputs=grad_list,\n", + ")\n", + "\n", + "grad_a0 = f_grad(obs_data[:, np.newaxis], a0, P0, T, Z, H, R @ Q @ R.T)\n", + "\n", + "# Now using our OpFromGraph custom gradient\n", + "\n", + "grad_list_op = pt.grad(loss_op, [a0_sym])\n", + "f_grad = pytensor.function(\n", + " inputs=[data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n", + " outputs=grad_list_op,\n", + ")\n", + "\n", + "grad_a0_op = f_grad(obs_data[:, np.newaxis], a0, P0, T, Z, H, R @ Q @ R.T)\n", + "\n", + "# And here using our handmaid numpy backprop\n", + "\n", + "kalman_fn = pytensor.function(inputs=[data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n", + " outputs=(a_pred_seq, P_pred_seq))\n", + "a_pred, P_pred = kalman_fn(obs_data[:, np.newaxis], a0, P0, T, Z, H, R@Q@R.T)\n", + "\n", + "grad_a0_np = compute_grad_a0(obs_data, a0, P0, a_pred, P_pred, Z, H, T)[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3a114b2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Comparison between classic a0 gradient and our custom OpFromGraph : True\n", + "Comparison between classic a0 gradient and our handmade NumPy backprop : True\n" + ] + } + ], + "source": [ + "print(\"Comparison between classic a0 gradient and our custom OpFromGraph :\", np.allclose(grad_a0, grad_a0_op))\n", + "print(\"Comparison between classic a0 gradient and our handmade NumPy backprop :\", np.allclose(grad_a0, grad_a0_np))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "867d5e2f", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\jeanv\\miniconda3\\envs\\CausalPy\\Lib\\site-packages\\pytensor\\tensor\\rewriting\\elemwise.py:954: UserWarning: Loop fusion failed because the resulting node would exceed the kernel argument limit.\n", + " warn(\n", + "c:\\Users\\jeanv\\miniconda3\\envs\\CausalPy\\Lib\\site-packages\\pytensor\\tensor\\rewriting\\elemwise.py:954: UserWarning: Loop fusion failed because the resulting node would exceed the kernel argument limit.\n", + " warn(\n" + ] + } + ], + "source": [ + "# First the classic way with autodiff\n", + "\n", + "grad_list = pt.grad(loss, [data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym])\n", + "f_grad = pytensor.function(\n", + " inputs=[data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n", + " outputs=grad_list,\n", + ")\n", + "\n", + "grad_a0 = f_grad(obs_data[:, np.newaxis], a0, P0, T, Z, H, R @ Q @ R.T)\n", + "\n", + "# Now using our OpFromGraph custom gradient\n", + "\n", + "grad_list_op = pt.grad(loss_op, [data_sym, 
a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym])\n", + "f_grad = pytensor.function(\n", + " inputs=[data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],\n", + " outputs=grad_list_op,\n", + ")\n", + "\n", + "grad_a0_op = f_grad(obs_data[:, np.newaxis], a0, P0, T, Z, H, R @ Q @ R.T)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25f0a57b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Comparison between classic y gradient and our custom OpFromGraph : True\n", + "Comparison between classic a0 gradient and our custom OpFromGraph : True\n", + "Comparison between classic P0 gradient and our custom OpFromGraph : True\n", + "Comparison between classic T gradient and our custom OpFromGraph : True\n", + "Comparison between classic Z gradient and our custom OpFromGraph : True\n", + "Comparison between classic H gradient and our custom OpFromGraph : True\n", + "Comparison between classic Q gradient and our custom OpFromGraph : True\n" + ] + } + ], + "source": [ + "print(\"Comparison between classic y gradient and our custom OpFromGraph :\", np.allclose(grad_a0[0], grad_a0_op[0]))\n", + "print(\"Comparison between classic a0 gradient and our custom OpFromGraph :\", np.allclose(grad_a0[1], grad_a0_op[1]))\n", + "print(\"Comparison between classic P0 gradient and our custom OpFromGraph :\", np.allclose((grad_a0[2] + grad_a0[2].T)/2, grad_a0_op[2]))\n", + "print(\"Comparison between classic T gradient and our custom OpFromGraph :\", np.allclose(grad_a0[3], grad_a0_op[3]))\n", + "print(\"Comparison between classic Z gradient and our custom OpFromGraph :\", np.allclose(grad_a0[4], grad_a0_op[4]))\n", + "print(\"Comparison between classic H gradient and our custom OpFromGraph :\", np.allclose(grad_a0[5], grad_a0_op[5]))\n", + "print(\"Comparison between classic Q gradient and our custom OpFromGraph :\", np.allclose((grad_a0[6] + grad_a0[6].T)/2, grad_a0_op[6]))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "CausalPy", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}