-
Notifications
You must be signed in to change notification settings - Fork 137
Fix CheckAndRaise
Op C implementation
#1521
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
fail_code = props["fail"] | ||
param_struct_name = props["params"] | ||
msg = self.msg.replace('"', '\\"').replace("\n", "\\n") | ||
|
||
for idx, cond_name in enumerate(cond_names): | ||
if isinstance(node.inputs[0].type, DenseTensorType): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This check was the source of the bug. It should have been checking for the type of the condition, not the first input
CheckAndRaise
Op C implementation
f9b7731
to
a3eb900
Compare
For performance, the Op now always converts the inputs to boolean scalars. Also do not constant-fold if it would raise.
08e2e1a
to
543230a
Compare
def __eq__(self, other): | ||
if type(self) is not type(other): | ||
return False | ||
|
||
if self.msg == other.msg and self.exc_type == other.exc_type: | ||
return True | ||
|
||
return False | ||
|
||
def __hash__(self): | ||
return hash((self.msg, self.exc_type)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is handled by the props automatically
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR fixes several issues in the CheckAndRaise
Op by simplifying its implementation, improving constant folding behavior, and updating JAX dispatch to correctly raise on failing asserts. Tests are updated to reflect these changes.
- Cast all assert conditions to boolean scalars and streamline C code generation.
- Use the Python implementation for constant folding when available, and do not fold failing asserts.
- Modify the JAX funcify rule to immediately raise on known-false constants.
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated no comments.
Show a summary per file
File | Description |
---|---|
pytensor/raise_op.py | Simplify CheckAndRaise node creation, unify scalar conversion, refine C code & cache version, add do_constant_folding |
pytensor/tensor/rewriting/basic.py | Add a Python-mode fallback for constant folding, refine useless-assert removal |
pytensor/link/jax/dispatch/basic.py | Update JAX dispatch to raise for constant-false asserts and adjust warning text |
pytensor/link/pytorch/dispatch/basic.py | Register ScalarFromTensor and adjust truth check for CheckAndRaise |
tests/ | Refresh tests for raise op, broadcasting, math ops, and rewriting passes |
Comments suppressed due to low confidence (2)
pytensor/tensor/rewriting/basic.py:125
- In the
c_code
method forCheckAndRaise
, the non-DenseTensorType
path falls through toreturn join(check, res)
, butres
is only defined inside theDenseTensorType
branch. You need to define the scalar-res branch properly or guard this return with anelse
that initializesres
.
)
pytensor/link/jax/dispatch/basic.py:77
- The
Assert
Op (a subclass ofCheckAndRaise
) is no longer explicitly registered for JAX funcify, soAssert
nodes won’t be handled by this rule. Re-add@jax_funcify.register(Assert)
or ensure subclass dispatch covers it.
@jax_funcify.register(CheckAndRaise)
@@ -1559,7 +1559,7 @@ def test_local_merge_alloc(): | |||
output = pt.alloc(pt.alloc(m, y, 1, 1), x, y2, z, w) | |||
f = function([m, x, y, y2, z, w], output, mode=rewrite_mode) | |||
topo = f.maker.fgraph.toposort() | |||
assert len(topo) == 3 | |||
assert len(topo) == 4 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new Op count is the ScalarFromTensor
543230a
to
9c79f9a
Compare
Codecov ReportAttention: Patch coverage is
❌ Your patch status has failed because the patch coverage (90.90%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## main #1521 +/- ##
==========================================
- Coverage 82.04% 82.04% -0.01%
==========================================
Files 231 230 -1
Lines 52365 52347 -18
Branches 9217 9210 -7
==========================================
- Hits 42963 42947 -16
- Misses 7094 7098 +4
+ Partials 2308 2302 -6
🚀 New features to boost your workflow:
|
This showed up mysteriously in #1471
I wanted to make constant_fold use python mode (no reason to compile C stuff for a single eval), and some tests started failing. After digging, found out the C-implementation was broken when the conditions were tensors.
I simplified the Op but converting all inputs to boolean scalars. I also decided not to constant_fold Asserts that would fail (remember that before this PR they would wrongly constant_fold without raising), so users can still decide what to do with them after graph rewrite.
I changed the JAX dispatch to raise the error if the Assert would be known to have failed (now that it is not constant-folded)
This PR also changes the constant_fold to use Python (the reason I found the bug in the first place)
📚 Documentation preview 📚: https://pytensor--1521.org.readthedocs.build/en/1521/