Add dot to xtensor #1475
base: labeled_tensors
Conversation
@ricardoV94 Here's a first pass. Getting the dims=... case to work took some hackery. Let me know what you think. (I'll fix the lint errors)
Looks pretty good. I would just suggest removing the Sum behavior from XDot to avoid redundancy, and simplifying the lowering rules.
@ricardoV94 Is this what you had in mind, composing XDot and Sum in the helper function?
I think what the PR has so far looks good, but we need to agree on the API. If we want to match xarray behaviour we are still somewhat far; if we want something closer to np.dot instead, it is mostly done. I haven't reviewed the lowering or tests in depth given the high-level discussion still pending, but at the very least the docstrings would need to be updated.
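To make the xarray-vs-np.dot question concrete, here is a small numpy sketch (the arrays and dim names are illustrative, not taken from the PR):

```python
import numpy as np

# Two operands that share BOTH dims, ("i", "k") -> shape (2, 3).
a = np.arange(6).reshape(2, 3)
b = np.arange(6).reshape(2, 3)

# xarray-style dot contracts every shared dim by name, so the
# result here is a full reduction to a scalar: sum(a * b).
xr_style = np.einsum("ik,ik->", a, b)

# np.dot instead contracts one positional axis pair (last axis of
# the first operand, second-to-last of the second), which is not
# even defined for these (2, 3) x (2, 3) shapes.
assert xr_style == (a * b).sum()
```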
I've changed the argument name to I've also generalized the logic to handle some cases that were not handled correctly in the previous iteration. But there's one case that I think cannot be lowered to
In this case there are three dimensions we could perform Two options from here:
Thoughts? @OriolAbril @ricardoV94
Probably (and not surprisingly) @OriolAbril was right and we need einsum to express xarray dot. I didn't know about the case where it does elementwise multiplication but no reduction on shared dims.
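The case described above can be sketched with einsum (dim names made up for illustration): when only one of two shared dims is reduced, the other becomes a batched, elementwise dim, which a plain tensordot cannot express.

```python
import numpy as np

# Both operands carry dims ("a", "b") -> shape (2, 3).
x = np.arange(6).reshape(2, 3)
y = np.arange(6).reshape(2, 3)

# Reduce only over "b": the shared dim "a" is kept in the output,
# so x and y are multiplied elementwise along it and only "b" is
# summed.  In einsum notation that is "ab,ab->a".
out = np.einsum("ab,ab->a", x, y)

# Equivalent to multiplying elementwise, then summing one axis.
assert np.array_equal(out, (x * y).sum(axis=1))
```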
Looks like it. On the other hand,
I would go with einsum, it will call tensordot internally for the cases that require it. I don't remember if we have the API to not require silly string letters |
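On the "silly string letters" point: np.einsum also accepts an interleaved integer-label form (operand, axis labels, operand, axis labels, ..., output labels), so a lowering could be built without generating subscript strings. A quick sketch:

```python
import numpy as np

a = np.arange(6).reshape(2, 3)
b = np.arange(12).reshape(3, 4)

# String-subscript form.
s1 = np.einsum("ik,kj->ij", a, b)

# Integer-label ("sublist") form: axes are labelled with ints
# instead of letters; the trailing list gives the output axes.
s2 = np.einsum(a, [0, 1], b, [1, 2], [0, 2])

assert np.array_equal(s1, s2)
assert np.array_equal(s2, a @ b)
```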
@ricardoV94 Here's a version using I tried a lot of test cases. We could pare that down once the dust settles.
Looks pretty good, just a small suggestion on where to place the dims checks.
I'll give it some time in case @OriolAbril wants to leave a review before we merge.
@ricardoV94 ready for another look
Good for me, will leave some time for @OriolAbril before merging.
Add dot method to xtensor
This PR adds a dot method to xtensor that matches xarray's dot operation behavior. The method supports:
Key implementation details:
This implementation ensures that xtensor's dot operation behaves consistently with xarray's dot operation, particularly for the ellipsis case which performs a full contraction resulting in a scalar.
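As an illustration of the ellipsis (dim=...) semantics described above (a plain numpy stand-in, not the PR's implementation): a full contraction sums over every dimension of the product, leaving a 0-d scalar.

```python
import numpy as np

a = np.arange(6).reshape(2, 3)   # dims ("i", "j")
b = np.arange(12).reshape(3, 4)  # dims ("j", "k")

# dim=... in xarray's dot means "contract every dimension":
# the shared dim "j" is contracted as usual, and the remaining
# output dims "i" and "k" are summed away too, yielding a scalar.
full = np.einsum("ij,jk->", a, b)

assert full == np.dot(a, b).sum()
```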
📚 Documentation preview 📚: https://pytensor--1475.org.readthedocs.build/en/1475/