Add dot to xtensor #1475


Open · wants to merge 12 commits into labeled_tensors

Conversation


@AllenDowney commented Jun 14, 2025

Add dot method to xtensor

This PR adds a dot method to xtensor that matches the behavior of xarray's dot operation (a usage sketch follows the list below). The method supports:

  • Matrix-vector dot products
  • Matrix-matrix dot products
  • Contracting over specific dimensions using string or list of dimensions
  • Contracting over all dimensions using ellipsis (...)
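
For illustration, the intended usage looks roughly like this (a minimal sketch; the xtensor constructor signature is an assumption based on the labeled_tensors branch, and the keyword is later renamed from dims to dim in this PR):

    from pytensor.xtensor import xtensor

    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
    y = xtensor("y", dims=("b", "c"), shape=(3, 4))

    z1 = x.dot(y)            # contract over the shared dim "b" -> dims ("a", "c")
    z2 = x.dot(y, dims="b")  # same contraction, with the dimension named explicitly
    z3 = x.dot(y, dims=...)  # full contraction -> scalar (dims=(), shape=())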

Key implementation details:

  1. Added XDot operation with a sum_result flag to handle the ellipsis case
  2. When dims=..., the operation:
    • Contracts over all matching dimensions between inputs
    • Sets sum_result=True to indicate that remaining dimensions should be summed
    • Returns a scalar result (dims=(), shape=())
  3. The rewrite rule (lower_dot) handles the actual computation:
    • Uses tensordot to contract over specified dimensions
    • When sum_result=True, sums over all remaining axes to produce a scalar
    • Preserves dimension names in the output tensor
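
The ellipsis semantics can be pictured in plain numpy terms (an illustrative sketch of the intended behavior, not the actual lower_dot rewrite):

    import numpy as np

    x = np.arange(6).reshape(2, 3)   # dims ("a", "b")
    y = np.arange(12).reshape(3, 4)  # dims ("b", "c")

    # dims="b": contract the shared axis with tensordot -> dims ("a", "c").
    out = np.tensordot(x, y, axes=[[1], [0]])  # shape (2, 4)

    # dims=...: contract the shared dims, then sum the remaining axes.
    scalar = np.tensordot(x, y, axes=[[1], [0]]).sum()  # shape (), dims ()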

This implementation ensures that xtensor's dot operation behaves consistently with xarray's dot operation, particularly for the ellipsis case, which performs a full contraction resulting in a scalar.


📚 Documentation preview 📚: https://pytensor--1475.org.readthedocs.build/en/1475/

@AllenDowney (Author):

@ricardoV94 Here's a first pass. Getting the dims=... case to work took some hackery. Let me know what you think.

(I'll fix the lint errors)

@ricardoV94 (Member) left a comment:

Looks pretty good. I would just suggest removing the Sum behavior from XDot to avoid redundancy, and simplifying the lowering rules.

@AllenDowney (Author):

@ricardoV94 Is this what you had in mind, composing XDot and Sum in the helper function?

@OriolAbril (Member) left a comment:

I think what the PR has so far looks good, but we need to agree on the API. If we want to match xarray behaviour we are still somewhat far; if we want something closer to np.dot instead, it is mostly done. I haven't reviewed the lowering or tests in depth given the high-level discussion still pending, but at the very least the docstrings would need to be updated.

@AllenDowney commented Jun 15, 2025

I've changed the argument name to dim to match xarray.DataArray.dot.

I've also generalized the logic to handle some cases that were not handled correctly in the previous iteration.

But there's one case that I think cannot be lowered to tensordot:

    import numpy as np
    from xarray import DataArray

    x_test = DataArray(np.arange(120).reshape(2, 3, 4, 5), dims=("a", "b", "c", "d"))
    y_test = DataArray(np.arange(360).reshape(3, 4, 5, 6), dims=("b", "c", "d", "e"))
    # Contract over "b" and "d" only; "c" is shared but multiplied elementwise.
    expected = x_test.dot(y_test, dim=("b", "d"))

In this case there are three shared dimensions we could contract over ("b", "c", "d"), but we ask for only the first and third. The result from tensordot does not have the right shape in this case.
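
In plain numpy the problem looks like this (illustrative: "c" is shared but not contracted, so it must be matched elementwise, which tensordot cannot express):

    import numpy as np

    x = np.arange(120).reshape(2, 3, 4, 5)  # dims ("a", "b", "c", "d")
    y = np.arange(360).reshape(3, 4, 5, 6)  # dims ("b", "c", "d", "e")

    # tensordot treats the two "c" axes as independent: shape (2, 4, 4, 6).
    wrong = np.tensordot(x, y, axes=[[1, 3], [0, 2]])

    # einsum keeps "c" aligned elementwise: shape (2, 4, 6), dims ("a", "c", "e").
    right = np.einsum("abcd,bcde->ace", x, y)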

Two options from here:

  1. Declare that we don't handle cases like this, or

  2. Try to lower it to einsum

Thoughts? @OriolAbril @ricardoV94

@ricardoV94 commented Jun 15, 2025

Probably (and not surprisingly) @OriolAbril was right and we need einsum to express xarray dot. I didn't know about the case where it does elementwise multiplication but no reduction on shared dims.

@AllenDowney (Author):

Looks like it. On the other hand, tensordot can handle 99.999% of the cases it's going to see in the wild. What do you think of using it most of the time and, for the rare exceptions, either throwing NotImplemented or falling back to einsum?

@ricardoV94 (Member):

> Looks like it. On the other hand, tensordot can handle 99.999% of the cases it's going to see in the wild. What do you think of using it most of the time and, for the rare exceptions, either throwing NotImplemented or falling back to einsum?

I would go with einsum; it will call tensordot internally for the cases that require it. I don't remember if we have the API to not require silly string letters.
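
For what it's worth, the subscripts can be generated from the dimension names so that no caller ever sees the letters (a hypothetical helper, not the PR's actual lowering code):

    import numpy as np
    from string import ascii_letters

    def named_einsum(x, x_dims, y, y_dims, contract):
        # Assign one subscript letter per distinct dimension name.
        order = dict.fromkeys(x_dims + y_dims)
        letters = dict(zip(order, ascii_letters))
        out_dims = tuple(d for d in order if d not in contract)
        subscripts = "{},{}->{}".format(
            "".join(letters[d] for d in x_dims),
            "".join(letters[d] for d in y_dims),
            "".join(letters[d] for d in out_dims),
        )
        return np.einsum(subscripts, x, y), out_dims

    x = np.arange(120).reshape(2, 3, 4, 5)
    y = np.arange(360).reshape(3, 4, 5, 6)
    out, dims = named_einsum(x, ("a", "b", "c", "d"), y, ("b", "c", "d", "e"), {"b", "d"})
    # out.shape == (2, 4, 6); dims == ("a", "c", "e")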

@AllenDowney (Author):

@ricardoV94 Here's a version using einsum -- I think it looks good.

I tried a lot of test cases. We could pare that down once the dust settles.

@ricardoV94 (Member) left a comment:

Looks pretty good, just a small suggestion on where to place the dims checks.

I'll allow some time in case @OriolAbril wants to leave a review before we merge.

@AllenDowney (Author):

@ricardoV94 ready for another look

@ricardoV94 (Member) left a comment:

Good for me, will leave some time for @OriolAbril before merging.
