Skip to content

Fix CheckAndRaise Op C implementation #1521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 9, 2025

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Jul 4, 2025

This showed up mysteriously in #1471

I wanted to make constant_fold use python mode (no reason to compile C stuff for a single eval), and some tests started failing. After digging, found out the C-implementation was broken when the conditions were tensors.

I simplified the Op but converting all inputs to boolean scalars. I also decided not to constant_fold Asserts that would fail (remember that before this PR they would wrongly constant_fold without raising), so users can still decide what to do with them after graph rewrite.

I changed the JAX dispatch to raise the error if the Assert would be known to have failed (now that it is not constant-folded)

This PR also changes the constant_fold to use Python (the reason I found the bug in the first place)


📚 Documentation preview 📚: https://pytensor--1521.org.readthedocs.build/en/1521/

fail_code = props["fail"]
param_struct_name = props["params"]
msg = self.msg.replace('"', '\\"').replace("\n", "\\n")

for idx, cond_name in enumerate(cond_names):
if isinstance(node.inputs[0].type, DenseTensorType):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check was the source of the bug. It should have been checking for the type of the condition, not the first input

@ricardoV94 ricardoV94 changed the title Fix check and raise op C implementation Fix CheckAndRaise Op C implementation Jul 4, 2025
@ricardoV94 ricardoV94 removed the request for review from jessegrabowski July 4, 2025 11:56
@ricardoV94 ricardoV94 force-pushed the fix_check_and_raise_op branch 3 times, most recently from f9b7731 to a3eb900 Compare July 5, 2025 05:55
For performance, the Op now always converts the inputs to boolean scalars.

Also do not constant-fold if it would raise.
@ricardoV94 ricardoV94 force-pushed the fix_check_and_raise_op branch 2 times, most recently from 08e2e1a to 543230a Compare July 8, 2025 08:59
Comment on lines -59 to -70
def __eq__(self, other):
if type(self) is not type(other):
return False

if self.msg == other.msg and self.exc_type == other.exc_type:
return True

return False

def __hash__(self):
return hash((self.msg, self.exc_type))

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is handled by the props automatically

Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR fixes several issues in the CheckAndRaise Op by simplifying its implementation, improving constant folding behavior, and updating JAX dispatch to correctly raise on failing asserts. Tests are updated to reflect these changes.

  • Cast all assert conditions to boolean scalars and streamline C code generation.
  • Use the Python implementation for constant folding when available, and do not fold failing asserts.
  • Modify the JAX funcify rule to immediately raise on known-false constants.

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated no comments.

Show a summary per file
File Description
pytensor/raise_op.py Simplify CheckAndRaise node creation, unify scalar conversion, refine C code & cache version, add do_constant_folding
pytensor/tensor/rewriting/basic.py Add a Python-mode fallback for constant folding, refine useless-assert removal
pytensor/link/jax/dispatch/basic.py Update JAX dispatch to raise for constant-false asserts and adjust warning text
pytensor/link/pytorch/dispatch/basic.py Register ScalarFromTensor and adjust truth check for CheckAndRaise
tests/ Refresh tests for raise op, broadcasting, math ops, and rewriting passes
Comments suppressed due to low confidence (2)

pytensor/tensor/rewriting/basic.py:125

  • In the c_code method for CheckAndRaise, the non-DenseTensorType path falls through to return join(check, res), but res is only defined inside the DenseTensorType branch. You need to define the scalar-res branch properly or guard this return with an else that initializes res.
        )

pytensor/link/jax/dispatch/basic.py:77

  • The Assert Op (a subclass of CheckAndRaise) is no longer explicitly registered for JAX funcify, so Assert nodes won’t be handled by this rule. Re-add @jax_funcify.register(Assert) or ensure subclass dispatch covers it.
@jax_funcify.register(CheckAndRaise)

@@ -1559,7 +1559,7 @@ def test_local_merge_alloc():
output = pt.alloc(pt.alloc(m, y, 1, 1), x, y2, z, w)
f = function([m, x, y, y2, z, w], output, mode=rewrite_mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 3
assert len(topo) == 4
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new Op count is the ScalarFromTensor

@ricardoV94 ricardoV94 force-pushed the fix_check_and_raise_op branch from 543230a to 9c79f9a Compare July 8, 2025 09:40
Copy link

codecov bot commented Jul 8, 2025

Codecov Report

Attention: Patch coverage is 90.90909% with 3 lines in your changes missing coverage. Please review.

Project coverage is 82.04%. Comparing base (47a15c6) to head (9c79f9a).
Report is 15 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/link/jax/dispatch/basic.py 66.66% 1 Missing and 1 partial ⚠️
pytensor/link/pytorch/dispatch/basic.py 83.33% 1 Missing ⚠️

❌ Your patch status has failed because the patch coverage (90.90%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1521      +/-   ##
==========================================
- Coverage   82.04%   82.04%   -0.01%     
==========================================
  Files         231      230       -1     
  Lines       52365    52347      -18     
  Branches     9217     9210       -7     
==========================================
- Hits        42963    42947      -16     
- Misses       7094     7098       +4     
+ Partials     2308     2302       -6     
Files with missing lines Coverage Δ
pytensor/raise_op.py 96.73% <100.00%> (-0.30%) ⬇️
pytensor/tensor/rewriting/basic.py 95.85% <100.00%> (+0.01%) ⬆️
pytensor/link/pytorch/dispatch/basic.py 85.91% <83.33%> (-0.22%) ⬇️
pytensor/link/jax/dispatch/basic.py 82.71% <66.66%> (-1.90%) ⬇️

... and 14 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@ricardoV94 ricardoV94 merged commit 6770f46 into pymc-devs:main Jul 9, 2025
72 of 73 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants