Improve dot lift rewrites #1471
This PR was motivated by the partial jacobian computation example in JAX discussed in jax-ml/jax#5904 (comment).
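For reference, the motivating pattern looks roughly like this in PyTensor (a minimal sketch with a made-up expression and slice, not taken from the linked discussion):

```python
# Build a full (vectorized) Jacobian expression but only ask for a few rows of it.
# Ideally the [:5] slice gets lifted all the way to the inputs, so the unused rows
# are never computed.
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import jacobian

x = pt.vector("x")
y = pt.exp(pt.cos(x))                # any elementwise expression of x

J = jacobian(y, x, vectorize=True)   # vectorize=True, as in the benchmark below
J_partial = J[:5]                    # only the first 5 rows are wanted

f = pytensor.function([x], J_partial)
pytensor.dprint(f)                   # inspect how far the slice was lifted
```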
After #1228 it's actually easier to do this sort of optimization in PyTensor, since there's no Scan to worry about. We already have a bunch of rewrites to lift subtensor operations through elemwise and dot, but we did not have any to lift them through Blockwise (including the Blockwise dot, a.k.a. matmul). This PR addresses that.
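Concretely, this is the kind of identity the subtensor lifts exploit, checked here numerically with plain numpy:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
y = rng.normal(size=(3, 5))

# Indexing the result of a dot is equivalent to indexing the left operand first,
# which is much cheaper when only a few rows of the product are needed.
np.testing.assert_allclose((x @ y)[0], x[0] @ y)
np.testing.assert_allclose((x @ y)[1:3], x[1:3] @ y)
```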
Some notes on each commit:
Do constant_folding in Python mode. This is not strictly related to this PR, but I noticed a test was taking 10x longer than the others just because a simple constant-folding operation triggered in the rewrites caused the whole C cache to be loaded. This incurs a one-time penalty that's pretty large. For users not interested in the C backend at all, there's no reason to involve that machinery; a single Python evaluation should be fast enough anyway.
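As a rough illustration (not code from the PR), this is the kind of graph where constant folding kicks in; with this change the one-off evaluation of the constant subgraph runs through the Python implementation instead of pulling in the C cache:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.scalar("x")
const_part = pt.exp(pt.as_tensor(np.arange(3.0))).sum()  # depends on no inputs
out = x + const_part

f = pytensor.function([x], out)
pytensor.dprint(f)  # the constant subgraph is folded into a single constant
```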
Simplified `local_upcast_elemwise`. This rewrite was too complex and wasteful, in that it wrapped constants in symbolic expand_dims / alloc + cast. I now just do it in numpy directly, which reduces the number of rewrite iterations.

Bunch of improvements to rewrites, including lifting index operations on the batch dimensions of Blockwise, and expanding the dot subtensor lift (a rewrite that predates Blockwise) to handle the Blockwise case; a sketch of that case follows below. The others are self-explanatory.
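As an illustration of the Blockwise case (a sketch; the exact compiled graph depends on which rewrites fire):

```python
# Indexing a batched matmul: after the lift, only the batched operand should be
# indexed, i.e. matmul(a, b)[0] should compile to something like matmul(a[0], b).
import pytensor
import pytensor.tensor as pt

a = pt.tensor3("a")   # shape (batch, m, n)
b = pt.matrix("b")    # shape (n, p), shared across the batch
out = pt.matmul(a, b)[0]

f = pytensor.function([a, b], out)
pytensor.dprint(f)
```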
Canonicalize matvec, vecmat and vecdot internally to all use `matmul` (i.e., a Blockwise of the 2D-by-2D dot operation). This makes things simpler for our rewrites, because we only need to worry about one case.

The pre-existing `test_local_batched_matmul_to_core_matmul` rewrite was extended to better address batched matvec, vecmat and vecdot cases (batch dimensions are moved to the core dimensions). It now moves non-overlapping batch dimensions of both inputs to their core dimensions. It also tries to avoid reshapes (which are needed when combining multiple batch/core dimensions), so that the subtensor_lift rewrites mentioned above can work through them; a numerical sketch of these equivalences follows below.
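A numerical sketch (plain numpy, made-up shapes) of the equivalences these two changes rely on:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 3))
x = rng.normal(size=3)
y = rng.normal(size=4)

# matvec / vecmat / vecdot can all be written as a matmul with dummy length-1 dims:
np.testing.assert_allclose(A @ x, (A @ x[:, None])[:, 0])                   # matvec
np.testing.assert_allclose(y @ A, (y[None, :] @ A)[0])                      # vecmat
np.testing.assert_allclose(np.dot(x, x), (x[None, :] @ x[:, None])[0, 0])   # vecdot

# Moving non-overlapping batch dimensions into the core dimensions: a matvec whose
# batch dimensions live only on the matrix is just one bigger core matvec.
A_batched = rng.normal(size=(5, 4, 3))
np.testing.assert_allclose(
    A_batched @ x,                                   # batched (Blockwise) matvec
    (A_batched.reshape(20, 3) @ x).reshape(5, 4),    # single core matvec + reshape
)
```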
Benchmark results were added in the last commit. Note that `vectorize=True` goes from underperforming (28 ms) to overperforming (0.37 ms).
Vectorized jacobian code before:
and after:
📚 Documentation preview 📚: https://pytensor--1471.org.readthedocs.build/en/1471/