Skip to content

Remove uses of numba_basic.global_numba_func #1535

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,24 +402,22 @@
return deepcopyop


@numba_njit
def makeslice(*x):
return slice(*x)


@numba_funcify.register(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs):
return global_numba_func(makeslice)

@numba_njit
def makeslice(*x):
return slice(*x)

Check warning on line 409 in pytensor/link/numba/dispatch/basic.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/basic.py#L407-L409

Added lines #L407 - L409 were not covered by tests

@numba_njit
def shape(x):
return np.asarray(np.shape(x))
return makeslice

Check warning on line 411 in pytensor/link/numba/dispatch/basic.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/basic.py#L411

Added line #L411 was not covered by tests


@numba_funcify.register(Shape)
def numba_funcify_Shape(op, **kwargs):
return global_numba_func(shape)
@numba_njit
def shape(x):
return np.asarray(np.shape(x))

return shape


@numba_funcify.register(Shape_i)
Expand Down
130 changes: 60 additions & 70 deletions pytensor/link/numba/dispatch/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,17 +141,16 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
)(scalar_op_fn)


@numba_basic.numba_njit
def switch(condition, x, y):
if condition:
return x
else:
return y


@numba_funcify.register(Switch)
def numba_funcify_Switch(op, node, **kwargs):
return numba_basic.global_numba_func(switch)
@numba_basic.numba_njit
def switch(condition, x, y):
if condition:
return x
else:
return y

return switch


def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str):
Expand Down Expand Up @@ -197,34 +196,32 @@ def cast(x):
return cast


@numba_basic.numba_njit
def identity(x):
return x


@numba_funcify.register(Identity)
@numba_funcify.register(TypeCastingOp)
def numba_funcify_type_casting(op, **kwargs):
return numba_basic.global_numba_func(identity)


@numba_basic.numba_njit
def clip(_x, _min, _max):
x = numba_basic.to_scalar(_x)
_min_scalar = numba_basic.to_scalar(_min)
_max_scalar = numba_basic.to_scalar(_max)

if x < _min_scalar:
return _min_scalar
elif x > _max_scalar:
return _max_scalar
else:
@numba_basic.numba_njit
def identity(x):
return x

return identity


@numba_funcify.register(Clip)
def numba_funcify_Clip(op, **kwargs):
return numba_basic.global_numba_func(clip)
@numba_basic.numba_njit
def clip(x, min_val, max_val):
x = numba_basic.to_scalar(x)
min_scalar = numba_basic.to_scalar(min_val)
max_scalar = numba_basic.to_scalar(max_val)

if x < min_scalar:
return min_scalar
elif x > max_scalar:
return max_scalar
else:
return x

return clip


@numba_funcify.register(Composite)
Expand All @@ -239,79 +236,72 @@ def numba_funcify_Composite(op, node, **kwargs):
return composite_fn


@numba_basic.numba_njit
def second(x, y):
return y


@numba_funcify.register(Second)
def numba_funcify_Second(op, node, **kwargs):
return numba_basic.global_numba_func(second)

@numba_basic.numba_njit
def second(x, y):
return y

@numba_basic.numba_njit
def reciprocal(x):
# TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when
# `x` is an `int`
return 1 / x
return second


@numba_funcify.register(Reciprocal)
def numba_funcify_Reciprocal(op, node, **kwargs):
return numba_basic.global_numba_func(reciprocal)

@numba_basic.numba_njit
def reciprocal(x):
# TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when
# `x` is an `int`
return 1 / x

@numba_basic.numba_njit
def sigmoid(x):
return 1 / (1 + np.exp(-x))
return reciprocal


@numba_funcify.register(Sigmoid)
def numba_funcify_Sigmoid(op, node, **kwargs):
return numba_basic.global_numba_func(sigmoid)

@numba_basic.numba_njit
def sigmoid(x):
return 1 / (1 + np.exp(-x))

@numba_basic.numba_njit
def gammaln(x):
return math.lgamma(x)
return sigmoid


@numba_funcify.register(GammaLn)
def numba_funcify_GammaLn(op, node, **kwargs):
return numba_basic.global_numba_func(gammaln)

@numba_basic.numba_njit
def gammaln(x):
return math.lgamma(x)

@numba_basic.numba_njit
def logp1mexp(x):
if x < np.log(0.5):
return np.log1p(-np.exp(x))
else:
return np.log(-np.expm1(x))
return gammaln


@numba_funcify.register(Log1mexp)
def numba_funcify_Log1mexp(op, node, **kwargs):
return numba_basic.global_numba_func(logp1mexp)

@numba_basic.numba_njit
def logp1mexp(x):
if x < np.log(0.5):
return np.log1p(-np.exp(x))
else:
return np.log(-np.expm1(x))

@numba_basic.numba_njit
def erf(x):
return math.erf(x)
return logp1mexp


@numba_funcify.register(Erf)
def numba_funcify_Erf(op, **kwargs):
return numba_basic.global_numba_func(erf)

@numba_basic.numba_njit
def erf(x):
return math.erf(x)

@numba_basic.numba_njit
def erfc(x):
return math.erfc(x)
return erf


@numba_funcify.register(Erfc)
def numba_funcify_Erfc(op, **kwargs):
return numba_basic.global_numba_func(erfc)
@numba_basic.numba_njit
def erfc(x):
return math.erfc(x)

return erfc


@numba_funcify.register(Softplus)
Expand Down