
Fix CheckAndRaise Op C implementation #1521


Merged: 4 commits, Jul 9, 2025
12 changes: 8 additions & 4 deletions pytensor/link/jax/dispatch/basic.py
@@ -10,10 +10,11 @@
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from pytensor.configdefaults import config
from pytensor.graph import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import Assert, CheckAndRaise
from pytensor.raise_op import CheckAndRaise


if config.floatX == "float64":
@@ -73,11 +74,14 @@
return ifelse


@jax_funcify.register(Assert)
@jax_funcify.register(CheckAndRaise)
def jax_funcify_CheckAndRaise(op, **kwargs):
def jax_funcify_CheckAndRaise(op, node, **kwargs):
conds = node.inputs[1:]
if any(isinstance(cond, Constant) and not bool(cond.data) for cond in conds):
raise op.exc_type(op.msg)


warnings.warn(
f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as JAX tracing would remove it.""",
f"""Skipping {op} Op (assertion: {op.msg}) as JAX tracing would remove it.""",
stacklevel=2,
)

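Illustrative note (not part of the diff): in plain Python, the decision the updated JAX dispatcher makes looks roughly like the sketch below. The function name and the constant_cond_values argument are hypothetical; only the raise-eagerly-on-constant-False and warn-then-skip behaviour mirrors the change above.

import warnings

def dispatch_check_and_raise(exc_type, msg, constant_cond_values):
    # Python values of condition inputs that are Constants; symbolic
    # conditions simply do not appear in this list.
    if any(not bool(v) for v in constant_cond_values):
        # A condition already known to be False fails at dispatch time
        # instead of being silently dropped by JAX tracing.
        raise exc_type(msg)

    warnings.warn(
        f"Skipping CheckAndRaise Op (assertion: {msg}) as JAX tracing would remove it.",
        stacklevel=2,
    )

    def check_and_raise(x, *conditions):
        return x  # identity at runtime; the check cannot survive JAX tracing

    return check_and_raise

# Example: a constant-False condition raises immediately.
try:
    dispatch_check_and_raise(ValueError, "shape mismatch", [False])
except ValueError as err:
    print(err)  # -> shape mismatch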
11 changes: 10 additions & 1 deletion pytensor/link/pytorch/dispatch/basic.py
@@ -22,6 +22,7 @@
Eye,
Join,
MakeVector,
ScalarFromTensor,
Split,
TensorFromScalar,
)
@@ -79,14 +80,22 @@
return type_cast


@pytorch_funcify.register(ScalarFromTensor)
def pytorch_funcify_ScalarFromTensor(op, node, **kwargs):
def scalar_from_tensor(x):
return x[()]


return scalar_from_tensor


@pytorch_funcify.register(CheckAndRaise)
def pytorch_funcify_CheckAndRaise(op, **kwargs):
error = op.exc_type
msg = op.msg

def assert_fn(x, *conditions):
for cond in conditions:
if not cond.item():
if not cond:
raise error(msg)
return x

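Illustrative note (assumes torch is installed; helper names are made up for this sketch): the two dispatch pieces above boil down to indexing with an empty tuple to get a scalar view, and truth-testing each condition directly so plain bools and 0-d tensors both work without calling .item().

import torch

def scalar_from_tensor(x):
    # An empty-tuple index returns a 0-d view usable as a scalar.
    return x[()]

def assert_fn(x, *conditions, error=AssertionError, msg="assertion failed"):
    for cond in conditions:
        if not cond:  # bool(True) and bool(torch.tensor(True)) both work here
            raise error(msg)
    return x

t = torch.tensor([1.0, 2.0])
assert_fn(t, torch.tensor(True), True)        # passes and returns t
print(scalar_from_tensor(torch.tensor(3.0)))  # tensor(3.)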
77 changes: 24 additions & 53 deletions pytensor/raise_op.py
@@ -2,15 +2,13 @@

from textwrap import indent

import numpy as np

from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.replace import _vectorize_node
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
from pytensor.link.c.type import Generic
from pytensor.scalar.basic import ScalarType
from pytensor.scalar.basic import ScalarType, as_scalar
from pytensor.tensor.type import DenseTensorType


@@ -56,18 +54,6 @@ def __str__(self):
msg = self.msg
return f"{name}{{raises={exc_name}, msg='{msg}'}}"

def __eq__(self, other):
if type(self) is not type(other):
return False

if self.msg == other.msg and self.exc_type == other.exc_type:
return True

return False

def __hash__(self):
return hash((self.msg, self.exc_type))

Comment on lines -59 to -70 (Member Author): This is handled by the props automatically.
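A minimal check of that claim, assuming CheckAndRaise declares __props__ = ("msg", "exc_type") as in the current source; PyTensor derives __eq__ and __hash__ from __props__, so the hand-written methods were redundant.

from pytensor.raise_op import CheckAndRaise

a = CheckAndRaise(ValueError, "boom")
b = CheckAndRaise(ValueError, "boom")
c = CheckAndRaise(TypeError, "boom")

# Equality and hashing follow the declared props (msg, exc_type) automatically.
assert a == b and hash(a) == hash(b)
assert a != c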

def make_node(self, value: Variable, *conds: Variable):
"""

@@ -84,12 +70,10 @@ def make_node(self, value: Variable, *conds: Variable):
if not isinstance(value, Variable):
value = pt.as_tensor_variable(value)

conds = [
pt.as_tensor_variable(c) if not isinstance(c, Variable) else c
for c in conds
]

assert all(c.type.ndim == 0 for c in conds)
conds = [as_scalar(c) for c in conds]
for i, cond in enumerate(conds):
if cond.dtype != "bool":
conds[i] = cond.astype("bool")

return Apply(
self,
@@ -101,7 +85,7 @@ def perform(self, node, inputs, outputs):
(out,) = outputs
val, *conds = inputs
out[0] = val
if not np.all(conds):
if not all(conds):
raise self.exc_type(self.msg)

def grad(self, input, output_gradients):
@@ -117,38 +101,20 @@ def c_code(self, node, name, inames, onames, props):
)
value_name, *cond_names = inames
out_name = onames[0]
check = []
fail_code = props["fail"]
param_struct_name = props["params"]
msg = self.msg.replace('"', '\\"').replace("\n", "\\n")

for idx, cond_name in enumerate(cond_names):
if isinstance(node.inputs[0].type, DenseTensorType):
Member Author: This check was the source of the bug. It should have been checking for the type of the condition, not the first input.
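An illustrative consequence (variable names hypothetical): with the make_node change below, the conditions are bool ScalarType inputs while the value can still be a dense tensor, so the emitted guard has to be keyed off the condition's type rather than the value's.

import pytensor.tensor as pt
from pytensor.raise_op import CheckAndRaise

x = pt.vector("x")                         # DenseTensorType value input
cond = pt.all(x > 0)                       # 0-d boolean condition
out = CheckAndRaise(ValueError, "all entries must be positive")(x, cond)
print([inp.type for inp in out.owner.inputs])  # value and condition types differ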

check.append(
f"""
if(PyObject_IsTrue((PyObject *){cond_name}) == 0) {{
PyObject * exc_type = {param_struct_name}->exc_type;
Py_INCREF(exc_type);
PyErr_SetString(exc_type, "{msg}");
Py_XDECREF(exc_type);
{indent(fail_code, " " * 4)}
}}
"""
)
else:
check.append(
f"""
if({cond_name} == 0) {{
PyObject * exc_type = {param_struct_name}->exc_type;
Py_INCREF(exc_type);
PyErr_SetString(exc_type, "{msg}");
Py_XDECREF(exc_type);
{indent(fail_code, " " * 4)}
}}
"""
)

check = "\n".join(check)
all_conds = " && ".join(cond_names)
check = f"""
if(!({all_conds})) {{
PyObject * exc_type = {param_struct_name}->exc_type;
Py_INCREF(exc_type);
PyErr_SetString(exc_type, "{msg}");
Py_XDECREF(exc_type);
{indent(fail_code, " " * 4)}
}}
"""

if isinstance(node.inputs[0].type, DenseTensorType):
res = f"""
@@ -162,14 +128,19 @@ def c_code(self, node, name, inames, onames, props):
{check}
{out_name} = {value_name};
"""
return res

return "\n".join((check, res))

def c_code_cache_version(self):
return (1, 1)
return (2,)

def infer_shape(self, fgraph, node, input_shapes):
return [input_shapes[0]]

def do_constant_folding(self, fgraph, node):
# Only constant-fold if the Assert does not fail
return all((isinstance(c, Constant) and bool(c.data)) for c in node.inputs[1:])


class Assert(CheckAndRaise):
"""Implements assertion in a computational graph.
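Hedged end-to-end sketch of the behaviour these raise_op.py changes aim for (values and names illustrative): an always-true condition can be rewritten away, while an always-false condition is kept in the graph and still raises instead of being constant-folded.

import pytensor
import pytensor.tensor as pt
from pytensor.raise_op import CheckAndRaise

x = pt.scalar("x")
check = CheckAndRaise(ValueError, "condition failed")

ok = check(x, 1)   # constant-true condition: the check can be removed by rewrites
f = pytensor.function([x], ok)
assert f(3.0) == 3.0

bad = check(x, 0)  # constant-false condition: must not be folded away
g = pytensor.function([x], bad)
try:
    g(3.0)
except ValueError as err:
    print(err)     # -> condition failed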
20 changes: 11 additions & 9 deletions pytensor/tensor/rewriting/basic.py
@@ -732,20 +732,15 @@ def is_an_upcast(type1, type2):

@register_useless
@register_specialize
@node_rewriter(None)
@node_rewriter([CheckAndRaise])
def local_remove_useless_assert(fgraph, node):
if not isinstance(node.op, CheckAndRaise):
return False

new_conds = []
n_conds = len(node.inputs[1:])
for c in node.inputs[1:]:
try:
const = get_scalar_constant_value(c)

if 0 != const.ndim or const == 0:
# Should we raise an error here? How to be sure it
# is not caught?
if not const:
new_conds.append(c)
except NotScalarConstantError:
new_conds.append(c)
@@ -1106,8 +1101,15 @@ def unconditional_constant_folding(fgraph, node):
storage_map[o] = [None]
compute_map[o] = [False]

thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
required = thunk()
try:
thunk = node.op.make_thunk(
node, storage_map, compute_map, no_recycling=[], impl="py"
)
required = thunk()
except NotImplementedError:
# Not all Ops have a python implementation
thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
required = thunk()

# A node whose inputs are all provided should always return successfully
assert not required
12 changes: 6 additions & 6 deletions tests/tensor/rewriting/test_basic.py
@@ -487,8 +487,8 @@ def test_local_remove_useless_1(self):

def test_local_remove_useless_2(self):
"""Remove `CheckAndRaise` conditions that are always true."""
x = scalar()
y = scalar()
x = scalar("x")
y = ps.bool("y")
fg = FunctionGraph(outputs=[assert_op(x, y, 1)], clone=False)
fg_res = rewrite_graph(fg, include=["canonicalize", "specialize"])
topo = fg_res.toposort()
@@ -497,8 +497,8 @@ def test_local_remove_useless_3(self):

def test_local_remove_useless_3(self):
"""Don't remove `CheckAndRaise` conditions that are always false."""
x = scalar()
y = scalar()
x = scalar("x")
y = ps.bool("y")
fg = FunctionGraph(outputs=[assert_op(x, y, 0)], clone=False)
fg_res = rewrite_graph(fg, include=["canonicalize", "specialize"])
topo = fg_res.toposort()
@@ -1559,7 +1559,7 @@ def test_local_merge_alloc():
output = pt.alloc(pt.alloc(m, y, 1, 1), x, y2, z, w)
f = function([m, x, y, y2, z, w], output, mode=rewrite_mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 3
assert len(topo) == 4
Member Author: The new Op count is the ScalarFromTensor.
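A hedged illustration of where that extra node comes from (names illustrative): conditions are now lowered to ScalarType bools, so a ScalarFromTensor sits between the tensor-level comparison and the Assert.

import pytensor.tensor as pt
from pytensor.raise_op import assert_op

x = pt.scalar("x")
out = assert_op(x, pt.ge(x, 0))
print(out.owner.inputs[1].owner.op)  # -> ScalarFromTensor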

assert isinstance(topo[-2].op, Assert)
assert isinstance(topo[-1].op, Alloc)
o = f(0.0, 1, 2, 2, 3, 4)
@@ -1616,7 +1616,7 @@ def test_local_useless_alloc():
useless_alloc.rewrite(g)

topo = g.toposort()
assert len(topo) == 3
assert len(topo) == 4
assert isinstance(topo[-2].op, Assert)
assert isinstance(topo[-1].op, Alloc)

2 changes: 1 addition & 1 deletion tests/tensor/rewriting/test_elemwise.py
@@ -932,7 +932,7 @@ def large_fuseable_graph(self, n):
),
(fx,),
(fxv,),
4,
5,
(np.zeros_like(fxv),),
("float32",),
),
19 changes: 14 additions & 5 deletions tests/tensor/test_extra_ops.py
@@ -8,6 +8,7 @@
from pytensor import tensor as pt
from pytensor.compile.mode import Mode
from pytensor.configdefaults import config
from pytensor.graph import rewrite_graph
from pytensor.graph.basic import Constant, applys_between, equal_computations
from pytensor.npy_2_compat import old_np_unique
from pytensor.raise_op import Assert
@@ -1252,11 +1253,17 @@ def test_broadcast_shape_symbolic_one_symbolic():
]

res_shape = broadcast_shape(*index_shapes, arrays_are_shapes=True)

from pytensor.graph.rewriting.utils import rewrite_graph

res_shape = rewrite_graph(res_shape)
assert res_shape[0].data == 1
assert res_shape[1].data == 1
with pytest.raises(AssertionError, match="Could not broadcast dimensions"):
# broadcast_shape doesn't treat int_div as a constant 1
res_shape[2].eval()

res_shape = broadcast_shape(
*index_shapes, arrays_are_shapes=True, allow_runtime_broadcast=True
)
res_shape = rewrite_graph(res_shape)
assert res_shape[0].data == 1
assert res_shape[1].data == 1
assert res_shape[2].data == 3
@@ -1294,7 +1301,9 @@ def test_broadcast_arrays():
["linspace", "logspace", "geomspace"],
ids=["linspace", "logspace", "geomspace"],
)
@pytest.mark.parametrize("dtype", [None, "int", "float"], ids=[None, "int", "float"])
@pytest.mark.parametrize(
"dtype", [None, "int64", "floatX"], ids=[None, "int64", "floatX"]
)
@pytest.mark.parametrize(
"start, stop, num_samples, endpoint, axis",
[
@@ -1310,7 +1319,7 @@ def test_broadcast_arrays():
def test_space_ops(op, dtype, start, stop, num_samples, endpoint, axis):
pt_func = getattr(pt, op)
np_func = getattr(np, op)
dtype = dtype + config.floatX[-2:] if dtype is not None else dtype
dtype = dtype if dtype != "floatX" else config.floatX
z = pt_func(start, stop, num_samples, endpoint=endpoint, axis=axis, dtype=dtype)

numpy_res = np_func(
Expand Down
31 changes: 21 additions & 10 deletions tests/tensor/test_math.py
@@ -1412,30 +1412,41 @@ def _grad_list(self):
"uint8",
"uint16",
"uint32",
pytest.param("uint64", marks=pytest.mark.xfail(reason="Fails due to #770")),
pytest.param(
"uint64",
marks=pytest.mark.xfail(
condition=config.mode != "FAST_COMPILE", reason="Fails due to #770"
),
),
),
)
def test_uint(self, dtype):
itype = np.iinfo(dtype)
data = np.array([itype.min + 3, itype.min, itype.max - 5, itype.max], dtype)
n = as_tensor_variable(data)
data = np.array(
[itype.min + 3, itype.min, itype.max - 5, itype.max], dtype=dtype
)
n = vector("n", shape=(None,), dtype=dtype)

assert min(n).dtype == dtype
i_min = eval_outputs(min(n))
min_out = min(n)
assert min_out.dtype == dtype
i_min = function([n], min_out)(data)
assert i_min == itype.min

assert max(n).dtype == dtype
i_max = eval_outputs(max(n))
max_out = max(n)
assert max_out.dtype == dtype
i_max = function([n], max_out)(data)
assert i_max == itype.max

@pytest.mark.xfail(reason="Fails due to #770")
@pytest.mark.xfail(
condition=config.mode != "FAST_COMPILE", reason="Fails due to #770"
)
def test_uint64_special_value(self):
"""Example from issue #770"""
dtype = "uint64"
data = np.array([0, 9223372036854775], dtype=dtype)
n = as_tensor_variable(data)
n = vector("n", shape=(None,), dtype=dtype)

i_max = eval_outputs(max(n))
i_max = function([n], max(n))(data)
assert i_max == data.max()

def test_bool(self):