Make convolve mode symbolic to avoid unnecessary large convolution in gradient graph #1522
+278
−276
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Instead of statically parametrizing the convolution type at the Op level, it now uses a scalar boolean that can be set symbolically. This was fine except for JAX, where we can't branch symbolically like that.
As long as the mode is constant (which is the case when the original convolve was full, or the core shapes are statically known) this should be fine, except for one hiccup. In the dispatch of Blockwise we couldn't see the outer inputs.
I added some functionality when we create the dummy core node to propagate inputs if these don't have batch dimensions, which means they won't change over iterations and so making compile or rewrite decisions based on this should be safe. This can also be used for infer_shape for instance, which could help with lowering some Ops to numba
These changes would also allow us to compile a constant convolve mode in C/Numba, but benchmarks didn't show any gains so I didn't bother doing that. In any case, future implementations of Blockwise for certain Ops can make use of that information.
Relevant benchmark tests:
Note the worst case scenario when we were doing a full convolution for the smaller input in the gradient of a valid convolution.
📚 Documentation preview 📚: https://pytensor--1522.org.readthedocs.build/en/1522/