
#332 Closed form gradients for Kalman filter #557


Open · wants to merge 6 commits into base: main

Conversation

JeanVanDyk

In the notebook section you'll find the notebook I used to compare execution time between autodiff and the analytic gradients.


@jessegrabowski added the enhancements (New feature or request) and statespace labels on Aug 2, 2025
@ricardoV94
Member

Suggestion: run the custom gradient vs. the default autodiff in PyTensor and time it. Print the compiled graph (`.dprint()` on the function) to see whether what PyTensor autodiff is doing isn't already similar.

I see the original paper compares against PyTorch, which IMO isn't very clever AD, especially in terms of memory optimization. I'm not sure their 38x speedup / memory improvement also holds against PyTensor (or JAX).
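
A minimal sketch of the comparison being suggested, on a toy graph (the toy loss and input size here are illustrative, not from the PR):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from timeit import timeit

x = pt.dvector("x")
loss = pt.sum(x ** 2)

# Compile the default autodiff gradient
f_auto = pytensor.function([x], pt.grad(loss, x))

# Inspect what the compiled graph actually does
f_auto.dprint()

# Time it; a closed-form gradient function would be compiled and timed the same way
x_val = np.random.normal(size=10_000)
print(timeit(lambda: f_auto(x_val), number=100))
```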

@JeanVanDyk
Author

Thanks @ricardoV94!

I’ve tried looking into it, but it quickly becomes messy… Do you have a particular method for comparing the dprint output? I saw that you can name operations to make things clearer, but that doesn’t seem very efficient given the hundreds of lines I get.

I also tried timing it, and the gradient with the closed-form expression actually performs worse. I'm inclined to think autodiff is still being used under the hood, especially since the runtime scales in pretty much the same way for both forms as a function of the number of states.
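
One quick way to test that suspicion is to list the Ops in the compiled graph; a sketch, assuming `f_grad` is the compiled gradient function and the closed-form gradient is implemented as its own Op:

```python
# List the distinct Op types in the compiled gradient function
ops = {type(node.op).__name__ for node in f_grad.maker.fgraph.toposort()}
print(sorted(ops))
# If only generic Ops (Scan, Elemwise, Dot22, ...) show up and no custom
# gradient Op appears, autodiff is likely still doing the work.
```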

@ricardoV94
Member

Can you share the timing code? That's easier to give feedback on.

Re the dprint: that was just curiosity; just paste it after you compile (yeah, it will be long).

@JeanVanDyk
Author

You can find it at the end of the notebook; here is the benchmark function:

```python
from collections import defaultdict
from time import perf_counter

import numpy as np
import pytensor
import pytensor.tensor as pt

# The *_sym variables are the symbolic graph inputs defined earlier in the notebook.
def benchmark_kalman_gradients(loss, obs_data, a0, P0, T, Z, R, H, Q):
    results = defaultdict(dict)
    exec_time = 0.0

    grad_list = pt.grad(loss, [data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym])
    f_grad = pytensor.function(
        inputs=[data_sym, a0_sym, P0_sym, T_sym, Z_sym, H_sym, Q_sym],
        outputs=grad_list,
    )

    for _ in range(20):
        # --- execution ---
        t0 = perf_counter()
        _ = f_grad(
            obs_data[:, np.newaxis],
            a0,
            P0,
            T,
            Z,
            H,
            R @ Q @ R.T,  # state-noise covariance mapped into the state space
        )
        t1 = perf_counter()
        exec_time += (t1 - t0) / 20  # running mean over the 20 runs

    results["exec_time"] = exec_time
    return results
```
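
For reference, a hypothetical call, assuming the symbolic inputs (`data_sym`, `a0_sym`, ...) and the NumPy system matrices were built earlier in the notebook:

```python
res = benchmark_kalman_gradients(loss, obs_data, a0, P0, T, Z, R, H, Q)
print(f"mean gradient time over 20 runs: {res['exec_time']:.4f} s")
```

Since `pytensor.function` compiles up front, the loop measures steady-state execution only; reporting compile time separately would also be informative here.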
