diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index f6e62ae2f8..6d4a45bf30 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -402,24 +402,22 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs): return deepcopyop -@numba_njit -def makeslice(*x): - return slice(*x) - - @numba_funcify.register(MakeSlice) def numba_funcify_MakeSlice(op, **kwargs): - return global_numba_func(makeslice) - + @numba_njit + def makeslice(*x): + return slice(*x) -@numba_njit -def shape(x): - return np.asarray(np.shape(x)) + return makeslice @numba_funcify.register(Shape) def numba_funcify_Shape(op, **kwargs): - return global_numba_func(shape) + @numba_njit + def shape(x): + return np.asarray(np.shape(x)) + + return shape @numba_funcify.register(Shape_i) diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index 7a8917d13e..ada4e8cc36 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -141,17 +141,16 @@ def {scalar_op_fn_name}({', '.join(input_names)}): )(scalar_op_fn) -@numba_basic.numba_njit -def switch(condition, x, y): - if condition: - return x - else: - return y - - @numba_funcify.register(Switch) def numba_funcify_Switch(op, node, **kwargs): - return numba_basic.global_numba_func(switch) + @numba_basic.numba_njit + def switch(condition, x, y): + if condition: + return x + else: + return y + + return switch def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str): @@ -197,34 +196,32 @@ def cast(x): return cast -@numba_basic.numba_njit -def identity(x): - return x - - @numba_funcify.register(Identity) @numba_funcify.register(TypeCastingOp) def numba_funcify_type_casting(op, **kwargs): - return numba_basic.global_numba_func(identity) - - -@numba_basic.numba_njit -def clip(_x, _min, _max): - x = numba_basic.to_scalar(_x) - _min_scalar = numba_basic.to_scalar(_min) - _max_scalar = numba_basic.to_scalar(_max) - - if x < _min_scalar: - return _min_scalar - elif x > _max_scalar: - return _max_scalar - else: + @numba_basic.numba_njit + def identity(x): return x + return identity + @numba_funcify.register(Clip) def numba_funcify_Clip(op, **kwargs): - return numba_basic.global_numba_func(clip) + @numba_basic.numba_njit + def clip(x, min_val, max_val): + x = numba_basic.to_scalar(x) + min_scalar = numba_basic.to_scalar(min_val) + max_scalar = numba_basic.to_scalar(max_val) + + if x < min_scalar: + return min_scalar + elif x > max_scalar: + return max_scalar + else: + return x + + return clip @numba_funcify.register(Composite) @@ -239,79 +236,72 @@ def numba_funcify_Composite(op, node, **kwargs): return composite_fn -@numba_basic.numba_njit -def second(x, y): - return y - - @numba_funcify.register(Second) def numba_funcify_Second(op, node, **kwargs): - return numba_basic.global_numba_func(second) - + @numba_basic.numba_njit + def second(x, y): + return y -@numba_basic.numba_njit -def reciprocal(x): - # TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when - # `x` is an `int` - return 1 / x + return second @numba_funcify.register(Reciprocal) def numba_funcify_Reciprocal(op, node, **kwargs): - return numba_basic.global_numba_func(reciprocal) - + @numba_basic.numba_njit + def reciprocal(x): + # TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when + # `x` is an `int` + return 1 / x -@numba_basic.numba_njit -def sigmoid(x): - return 1 / (1 + np.exp(-x)) + return reciprocal @numba_funcify.register(Sigmoid) def numba_funcify_Sigmoid(op, node, **kwargs): - return numba_basic.global_numba_func(sigmoid) - + @numba_basic.numba_njit + def sigmoid(x): + return 1 / (1 + np.exp(-x)) -@numba_basic.numba_njit -def gammaln(x): - return math.lgamma(x) + return sigmoid @numba_funcify.register(GammaLn) def numba_funcify_GammaLn(op, node, **kwargs): - return numba_basic.global_numba_func(gammaln) - + @numba_basic.numba_njit + def gammaln(x): + return math.lgamma(x) -@numba_basic.numba_njit -def logp1mexp(x): - if x < np.log(0.5): - return np.log1p(-np.exp(x)) - else: - return np.log(-np.expm1(x)) + return gammaln @numba_funcify.register(Log1mexp) def numba_funcify_Log1mexp(op, node, **kwargs): - return numba_basic.global_numba_func(logp1mexp) - + @numba_basic.numba_njit + def logp1mexp(x): + if x < np.log(0.5): + return np.log1p(-np.exp(x)) + else: + return np.log(-np.expm1(x)) -@numba_basic.numba_njit -def erf(x): - return math.erf(x) + return logp1mexp @numba_funcify.register(Erf) def numba_funcify_Erf(op, **kwargs): - return numba_basic.global_numba_func(erf) - + @numba_basic.numba_njit + def erf(x): + return math.erf(x) -@numba_basic.numba_njit -def erfc(x): - return math.erfc(x) + return erf @numba_funcify.register(Erfc) def numba_funcify_Erfc(op, **kwargs): - return numba_basic.global_numba_func(erfc) + @numba_basic.numba_njit + def erfc(x): + return math.erfc(x) + + return erfc @numba_funcify.register(Softplus)