replace custom cuda kernels with einsum #319

Open
sbailey opened this issue Sep 16, 2024 · 0 comments

sbailey (Collaborator) commented Sep 16, 2024

For long-term maintenance simplicity, consider replacing the custom CUDA kernels `redrock.zscan.batch_dot_product_3d3d` and `batch_dot_product_3d2d` with `einsum` magic, as suggested by @dmargala:

For example, batched `A.T.dot(A)` and `A.T.dot(b)` would be:

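```python
import cupy as cp

cp.einsum("...ji,...jk", A, A)  # batched A.T.dot(A): (..., nrows, ncols) -> (..., ncols, ncols)
cp.einsum("...ji,...j", A, b)   # batched A.T.dot(b): -> (..., ncols)
```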
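In `einsum`'s implicit mode (no `->`), the non-contracted labels are placed after the broadcast `...` dimensions in alphabetical order, so `"...ji,...jk"` yields `...ik` and `"...ji,...j"` yields `...i`. The sketch below checks this on small random arrays using NumPy, assuming `cupy.einsum` accepts the same subscripts (CuPy mirrors the NumPy interface); the array sizes are arbitrary.

```python
import numpy as np

# Toy batch: 4 independent (10 x 3) matrices A and length-10 vectors b.
rng = np.random.default_rng(0)
A = rng.standard_normal((4, 10, 3))
b = rng.standard_normal((4, 10))

# Implicit output: the repeated label j is summed over; the remaining
# labels are sorted alphabetically and appended after "...".
ATA_implicit = np.einsum("...ji,...jk", A, A)
ATb_implicit = np.einsum("...ji,...j", A, b)

# The same contractions with the output spelled out explicitly.
ATA = np.einsum("...ji,...jk->...ik", A, A)
ATb = np.einsum("...ji,...j->...i", A, b)

print(ATA.shape, ATb.shape)  # (4, 3, 3) (4, 3)
```

Writing the explicit `->...ik` form in the final code may be worth it for readability, since it documents the output layout without relying on the alphabetical-ordering rule.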

These aren't drop-in replacements for the call signature of `batch_dot_product_3d3d`, but I think we are using it for that `A.T.dot(A)` purpose. Profile them against the current implementation and also check for correctness.
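A minimal sketch of the correctness-plus-timing check, under stated assumptions: the call signature of the current `batch_dot_product_3d3d` isn't given here, so a plain per-batch loop (`reference_ata`, `reference_atb`, both hypothetical names) stands in as the reference, and the code runs on the CPU with NumPy. For the real GPU comparison, swap in `cupy` and the existing kernels, and synchronize the device before reading the clock (for example via `cupyx.profiler.benchmark`), since CuPy launches kernels asynchronously.

```python
import time
import numpy as np

def reference_ata(A):
    """Hypothetical stand-in for the current kernel: loop over A[n].T.dot(A[n])."""
    return np.stack([a.T.dot(a) for a in A])

def reference_atb(A, b):
    """Hypothetical stand-in: loop over A[n].T.dot(b[n])."""
    return np.stack([a.T.dot(v) for a, v in zip(A, b)])

def einsum_ata(A):
    return np.einsum("...ji,...jk->...ik", A, A)

def einsum_atb(A, b):
    return np.einsum("...ji,...j->...i", A, b)

def best_time(fn, *args, repeat=5):
    """Best wall-clock time over `repeat` calls (CPU only; no GPU sync here)."""
    times = []
    for _ in range(repeat):
        t0 = time.perf_counter()
        fn(*args)
        times.append(time.perf_counter() - t0)
    return min(times)

# Arbitrary sizes: 500 batches of (200 x 10) matrices.
rng = np.random.default_rng(1)
A = rng.standard_normal((500, 200, 10))
b = rng.standard_normal((500, 200))

# Correctness first, then timing.
assert np.allclose(einsum_ata(A), reference_ata(A))
assert np.allclose(einsum_atb(A, b), reference_atb(A, b))

print("loop   A.T.A: %.2e s" % best_time(reference_ata, A))
print("einsum A.T.A: %.2e s" % best_time(einsum_ata, A))
```

Comparing with `np.allclose` rather than exact equality allows for the different floating-point summation order between implementations.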

Also consider moving functions like these into `redrock.utils` or a separate `redrock.linalg` (or similar) module instead of `zscan`.
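One possible shape for such a module, as a sketch only: the file `redrock/linalg.py` and the functions `batch_ata` / `batch_atb` are hypothetical names, not existing redrock code. Passing the array module as `xp` lets the same functions run on NumPy or CuPy arrays without separate CPU and GPU code paths.

```python
"""Hypothetical redrock/linalg.py: batched products shared by CPU and GPU code."""
import numpy as np

def batch_ata(A, xp=np):
    """Return A[n].T @ A[n] for every batch n.

    A has shape (..., nrows, ncols); the result has shape (..., ncols, ncols).
    Pass xp=cupy to run on the GPU (assumes cupy.einsum matches numpy.einsum).
    """
    return xp.einsum("...ji,...jk->...ik", A, A)

def batch_atb(A, b, xp=np):
    """Return A[n].T @ b[n] for every batch n.

    A has shape (..., nrows, ncols), b has shape (..., nrows);
    the result has shape (..., ncols).
    """
    return xp.einsum("...ji,...j->...i", A, b)
```

Keeping these in a small dedicated module would also give the correctness tests a natural home, separate from the redshift-scanning logic in `zscan`.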
