Skip to content

Commit dcd2576

Browse files
committed
Remove usage of numba_basic.global_numba_func
1 parent 7efd1c5 commit dcd2576

File tree

3 files changed

+71
-92
lines changed

3 files changed

+71
-92
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from pytensor.graph.fg import FunctionGraph
2626
from pytensor.graph.type import Type
2727
from pytensor.ifelse import IfElse
28+
from pytensor.link.numba.dispatch import basic as numba_basic
2829
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
2930
from pytensor.link.utils import (
3031
compile_function_src,
@@ -42,14 +43,6 @@
4243
from pytensor.tensor.type_other import MakeSlice, NoneConst
4344

4445

45-
def global_numba_func(func):
46-
"""Use to return global numba functions in numba_funcify_*.
47-
48-
This allows tests to remove the compilation using mock.
49-
"""
50-
return func
51-
52-
5346
def numba_njit(*args, fastmath=None, **kwargs):
5447
kwargs.setdefault("cache", config.numba__cache)
5548
kwargs.setdefault("no_cpython_wrapper", True)
@@ -402,24 +395,22 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
402395
return deepcopyop
403396

404397

405-
@numba_njit
406-
def makeslice(*x):
407-
return slice(*x)
408-
409-
410398
@numba_funcify.register(MakeSlice)
411399
def numba_funcify_MakeSlice(op, **kwargs):
412-
return global_numba_func(makeslice)
400+
@numba_basic.numba_njit
401+
def makeslice(*x):
402+
return slice(*x)
413403

414-
415-
@numba_njit
416-
def shape(x):
417-
return np.asarray(np.shape(x))
404+
return makeslice
418405

419406

420407
@numba_funcify.register(Shape)
421408
def numba_funcify_Shape(op, **kwargs):
422-
return global_numba_func(shape)
409+
@numba_basic.numba_njit
410+
def shape(x):
411+
return np.asarray(np.shape(x))
412+
413+
return shape
423414

424415

425416
@numba_funcify.register(Shape_i)

pytensor/link/numba/dispatch/scalar.py

Lines changed: 61 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -141,17 +141,16 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
141141
)(scalar_op_fn)
142142

143143

144-
@numba_basic.numba_njit
145-
def switch(condition, x, y):
146-
if condition:
147-
return x
148-
else:
149-
return y
150-
151-
152144
@numba_funcify.register(Switch)
153145
def numba_funcify_Switch(op, node, **kwargs):
154-
return numba_basic.global_numba_func(switch)
146+
@numba_basic.numba_njit
147+
def switch(condition, x, y):
148+
if condition:
149+
return x
150+
else:
151+
return y
152+
153+
return switch
155154

156155

157156
def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str):
@@ -197,34 +196,32 @@ def cast(x):
197196
return cast
198197

199198

200-
@numba_basic.numba_njit
201-
def identity(x):
202-
return x
203-
204-
205199
@numba_funcify.register(Identity)
206200
@numba_funcify.register(TypeCastingOp)
207201
def numba_funcify_type_casting(op, **kwargs):
208-
return numba_basic.global_numba_func(identity)
209-
210-
211-
@numba_basic.numba_njit
212-
def clip(_x, _min, _max):
213-
x = numba_basic.to_scalar(_x)
214-
_min_scalar = numba_basic.to_scalar(_min)
215-
_max_scalar = numba_basic.to_scalar(_max)
216-
217-
if x < _min_scalar:
218-
return _min_scalar
219-
elif x > _max_scalar:
220-
return _max_scalar
221-
else:
202+
@numba_basic.numba_njit
203+
def identity(x):
222204
return x
223205

206+
return identity
207+
224208

225209
@numba_funcify.register(Clip)
226210
def numba_funcify_Clip(op, **kwargs):
227-
return numba_basic.global_numba_func(clip)
211+
@numba_basic.numba_njit
212+
def clip(x, min_val, max_val):
213+
x = numba_basic.to_scalar(x)
214+
min_scalar = numba_basic.to_scalar(min_val)
215+
max_scalar = numba_basic.to_scalar(max_val)
216+
217+
if x < min_scalar:
218+
return min_scalar
219+
elif x > max_scalar:
220+
return max_scalar
221+
else:
222+
return x
223+
224+
return clip
228225

229226

230227
@numba_funcify.register(Composite)
@@ -239,14 +236,13 @@ def numba_funcify_Composite(op, node, **kwargs):
239236
return composite_fn
240237

241238

242-
@numba_basic.numba_njit
243-
def second(x, y):
244-
return y
245-
246-
247239
@numba_funcify.register(Second)
248240
def numba_funcify_Second(op, node, **kwargs):
249-
return numba_basic.global_numba_func(second)
241+
@numba_basic.numba_njit
242+
def second(x, y):
243+
return y
244+
245+
return second
250246

251247

252248
@numba_basic.numba_njit
@@ -258,60 +254,61 @@ def reciprocal(x):
258254

259255
@numba_funcify.register(Reciprocal)
260256
def numba_funcify_Reciprocal(op, node, **kwargs):
261-
return numba_basic.global_numba_func(reciprocal)
262-
257+
@numba_basic.numba_njit
258+
def reciprocal(x):
259+
# TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when
260+
# `x` is an `int`
261+
return 1 / x
263262

264-
@numba_basic.numba_njit
265-
def sigmoid(x):
266-
return 1 / (1 + np.exp(-x))
263+
return reciprocal
267264

268265

269266
@numba_funcify.register(Sigmoid)
270267
def numba_funcify_Sigmoid(op, node, **kwargs):
271-
return numba_basic.global_numba_func(sigmoid)
272-
268+
@numba_basic.numba_njit
269+
def sigmoid(x):
270+
return 1 / (1 + np.exp(-x))
273271

274-
@numba_basic.numba_njit
275-
def gammaln(x):
276-
return math.lgamma(x)
272+
return sigmoid
277273

278274

279275
@numba_funcify.register(GammaLn)
280276
def numba_funcify_GammaLn(op, node, **kwargs):
281-
return numba_basic.global_numba_func(gammaln)
277+
@numba_basic.numba_njit
278+
def gammaln(x):
279+
return math.lgamma(x)
282280

283-
284-
@numba_basic.numba_njit
285-
def logp1mexp(x):
286-
if x < np.log(0.5):
287-
return np.log1p(-np.exp(x))
288-
else:
289-
return np.log(-np.expm1(x))
281+
return gammaln
290282

291283

292284
@numba_funcify.register(Log1mexp)
293285
def numba_funcify_Log1mexp(op, node, **kwargs):
294-
return numba_basic.global_numba_func(logp1mexp)
295-
286+
@numba_basic.numba_njit
287+
def logp1mexp(x):
288+
if x < np.log(0.5):
289+
return np.log1p(-np.exp(x))
290+
else:
291+
return np.log(-np.expm1(x))
296292

297-
@numba_basic.numba_njit
298-
def erf(x):
299-
return math.erf(x)
293+
return logp1mexp
300294

301295

302296
@numba_funcify.register(Erf)
303297
def numba_funcify_Erf(op, **kwargs):
304-
return numba_basic.global_numba_func(erf)
305-
298+
@numba_basic.numba_njit
299+
def erf(x):
300+
return math.erf(x)
306301

307-
@numba_basic.numba_njit
308-
def erfc(x):
309-
return math.erfc(x)
302+
return erf
310303

311304

312305
@numba_funcify.register(Erfc)
313306
def numba_funcify_Erfc(op, **kwargs):
314-
return numba_basic.global_numba_func(erfc)
307+
@numba_basic.numba_njit
308+
def erfc(x):
309+
return math.erfc(x)
310+
311+
return erfc
315312

316313

317314
@numba_funcify.register(Softplus)

tests/link/numba/test_basic.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -175,18 +175,9 @@ def inner_vec(*args):
175175
else:
176176
return wrap
177177

178-
def py_global_numba_func(func):
179-
if hasattr(func, "py_func"):
180-
return func.py_func
181-
return func
182-
183178
mocks = [
184179
mock.patch("numba.njit", njit_noop),
185180
mock.patch("numba.vectorize", vectorize_noop),
186-
mock.patch(
187-
"pytensor.link.numba.dispatch.basic.global_numba_func",
188-
py_global_numba_func,
189-
),
190181
mock.patch(
191182
"pytensor.link.numba.dispatch.basic.tuple_setitem", py_tuple_setitem
192183
),

0 commit comments

Comments
 (0)