-
Notifications
You must be signed in to change notification settings - Fork 137
Fix CheckAndRaise
Op C implementation
#1521
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,15 +2,13 @@ | |
|
||
from textwrap import indent | ||
|
||
import numpy as np | ||
|
||
from pytensor.gradient import DisconnectedType | ||
from pytensor.graph.basic import Apply, Variable | ||
from pytensor.graph.basic import Apply, Constant, Variable | ||
from pytensor.graph.replace import _vectorize_node | ||
from pytensor.link.c.op import COp | ||
from pytensor.link.c.params_type import ParamsType | ||
from pytensor.link.c.type import Generic | ||
from pytensor.scalar.basic import ScalarType | ||
from pytensor.scalar.basic import ScalarType, as_scalar | ||
from pytensor.tensor.type import DenseTensorType | ||
|
||
|
||
|
@@ -56,18 +54,6 @@ def __str__(self): | |
msg = self.msg | ||
return f"{name}{{raises={exc_name}, msg='{msg}'}}" | ||
|
||
def __eq__(self, other): | ||
if type(self) is not type(other): | ||
return False | ||
|
||
if self.msg == other.msg and self.exc_type == other.exc_type: | ||
return True | ||
|
||
return False | ||
|
||
def __hash__(self): | ||
return hash((self.msg, self.exc_type)) | ||
|
||
def make_node(self, value: Variable, *conds: Variable): | ||
""" | ||
|
||
|
@@ -84,12 +70,10 @@ def make_node(self, value: Variable, *conds: Variable): | |
if not isinstance(value, Variable): | ||
value = pt.as_tensor_variable(value) | ||
|
||
conds = [ | ||
pt.as_tensor_variable(c) if not isinstance(c, Variable) else c | ||
for c in conds | ||
] | ||
|
||
assert all(c.type.ndim == 0 for c in conds) | ||
conds = [as_scalar(c) for c in conds] | ||
for i, cond in enumerate(conds): | ||
if cond.dtype != "bool": | ||
conds[i] = cond.astype("bool") | ||
|
||
return Apply( | ||
self, | ||
|
@@ -101,7 +85,7 @@ def perform(self, node, inputs, outputs): | |
(out,) = outputs | ||
val, *conds = inputs | ||
out[0] = val | ||
if not np.all(conds): | ||
if not all(conds): | ||
raise self.exc_type(self.msg) | ||
|
||
def grad(self, input, output_gradients): | ||
|
@@ -117,38 +101,20 @@ def c_code(self, node, name, inames, onames, props): | |
) | ||
value_name, *cond_names = inames | ||
out_name = onames[0] | ||
check = [] | ||
fail_code = props["fail"] | ||
param_struct_name = props["params"] | ||
msg = self.msg.replace('"', '\\"').replace("\n", "\\n") | ||
|
||
for idx, cond_name in enumerate(cond_names): | ||
if isinstance(node.inputs[0].type, DenseTensorType): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Reviewer comment:] This check was the source of the bug. It should have been checking the type of the condition, not the first input. |
||
check.append( | ||
f""" | ||
if(PyObject_IsTrue((PyObject *){cond_name}) == 0) {{ | ||
PyObject * exc_type = {param_struct_name}->exc_type; | ||
Py_INCREF(exc_type); | ||
PyErr_SetString(exc_type, "{msg}"); | ||
Py_XDECREF(exc_type); | ||
{indent(fail_code, " " * 4)} | ||
}} | ||
""" | ||
) | ||
else: | ||
check.append( | ||
f""" | ||
if({cond_name} == 0) {{ | ||
PyObject * exc_type = {param_struct_name}->exc_type; | ||
Py_INCREF(exc_type); | ||
PyErr_SetString(exc_type, "{msg}"); | ||
Py_XDECREF(exc_type); | ||
{indent(fail_code, " " * 4)} | ||
}} | ||
""" | ||
) | ||
|
||
check = "\n".join(check) | ||
all_conds = " && ".join(cond_names) | ||
check = f""" | ||
if(!({all_conds})) {{ | ||
PyObject * exc_type = {param_struct_name}->exc_type; | ||
Py_INCREF(exc_type); | ||
PyErr_SetString(exc_type, "{msg}"); | ||
Py_XDECREF(exc_type); | ||
{indent(fail_code, " " * 4)} | ||
}} | ||
""" | ||
|
||
if isinstance(node.inputs[0].type, DenseTensorType): | ||
res = f""" | ||
|
@@ -162,14 +128,19 @@ def c_code(self, node, name, inames, onames, props): | |
{check} | ||
{out_name} = {value_name}; | ||
""" | ||
return res | ||
|
||
return "\n".join((check, res)) | ||
|
||
def c_code_cache_version(self): | ||
return (1, 1) | ||
return (2,) | ||
|
||
def infer_shape(self, fgraph, node, input_shapes): | ||
return [input_shapes[0]] | ||
|
||
def do_constant_folding(self, fgraph, node): | ||
# Only constant-fold if the Assert does not fail | ||
return all((isinstance(c, Constant) and bool(c.data)) for c in node.inputs[1:]) | ||
|
||
|
||
class Assert(CheckAndRaise): | ||
"""Implements assertion in a computational graph. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -487,8 +487,8 @@ def test_local_remove_useless_1(self): | |
|
||
def test_local_remove_useless_2(self): | ||
"""Remove `CheckAndRaise` conditions that are always true.""" | ||
x = scalar() | ||
y = scalar() | ||
x = scalar("x") | ||
y = ps.bool("y") | ||
fg = FunctionGraph(outputs=[assert_op(x, y, 1)], clone=False) | ||
fg_res = rewrite_graph(fg, include=["canonicalize", "specialize"]) | ||
topo = fg_res.toposort() | ||
|
@@ -497,8 +497,8 @@ def test_local_remove_useless_2(self): | |
|
||
def test_local_remove_useless_3(self): | ||
"""Don't remove `CheckAndRaise` conditions that are always false.""" | ||
x = scalar() | ||
y = scalar() | ||
x = scalar("x") | ||
y = ps.bool("y") | ||
fg = FunctionGraph(outputs=[assert_op(x, y, 0)], clone=False) | ||
fg_res = rewrite_graph(fg, include=["canonicalize", "specialize"]) | ||
topo = fg_res.toposort() | ||
|
@@ -1559,7 +1559,7 @@ def test_local_merge_alloc(): | |
output = pt.alloc(pt.alloc(m, y, 1, 1), x, y2, z, w) | ||
f = function([m, x, y, y2, z, w], output, mode=rewrite_mode) | ||
topo = f.maker.fgraph.toposort() | ||
assert len(topo) == 3 | ||
assert len(topo) == 4 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Reviewer comment, truncated in extraction:] The new Op count is the […] |
||
assert isinstance(topo[-2].op, Assert) | ||
assert isinstance(topo[-1].op, Alloc) | ||
o = f(0.0, 1, 2, 2, 3, 4) | ||
|
@@ -1616,7 +1616,7 @@ def test_local_useless_alloc(): | |
useless_alloc.rewrite(g) | ||
|
||
topo = g.toposort() | ||
assert len(topo) == 3 | ||
assert len(topo) == 4 | ||
assert isinstance(topo[-2].op, Assert) | ||
assert isinstance(topo[-1].op, Alloc) | ||
|
||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Reviewer comment:] This is handled by the props automatically.