@@ -141,17 +141,16 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
141
141
)(scalar_op_fn )
142
142
143
143
144
- @numba_basic .numba_njit
145
- def switch (condition , x , y ):
146
- if condition :
147
- return x
148
- else :
149
- return y
150
-
151
-
152
144
@numba_funcify .register (Switch )
153
145
def numba_funcify_Switch (op , node , ** kwargs ):
154
- return numba_basic .global_numba_func (switch )
146
+ @numba_basic .numba_njit
147
+ def switch (condition , x , y ):
148
+ if condition :
149
+ return x
150
+ else :
151
+ return y
152
+
153
+ return switch
155
154
156
155
157
156
def binary_to_nary_func (inputs : list [Variable ], binary_op_name : str , binary_op : str ):
@@ -197,34 +196,32 @@ def cast(x):
197
196
return cast
198
197
199
198
200
- @numba_basic .numba_njit
201
- def identity (x ):
202
- return x
203
-
204
-
205
199
@numba_funcify .register (Identity )
206
200
@numba_funcify .register (TypeCastingOp )
207
201
def numba_funcify_type_casting (op , ** kwargs ):
208
- return numba_basic .global_numba_func (identity )
209
-
210
-
211
- @numba_basic .numba_njit
212
- def clip (_x , _min , _max ):
213
- x = numba_basic .to_scalar (_x )
214
- _min_scalar = numba_basic .to_scalar (_min )
215
- _max_scalar = numba_basic .to_scalar (_max )
216
-
217
- if x < _min_scalar :
218
- return _min_scalar
219
- elif x > _max_scalar :
220
- return _max_scalar
221
- else :
202
+ @numba_basic .numba_njit
203
+ def identity (x ):
222
204
return x
223
205
206
+ return identity
207
+
224
208
225
209
@numba_funcify .register (Clip )
226
210
def numba_funcify_Clip (op , ** kwargs ):
227
- return numba_basic .global_numba_func (clip )
211
+ @numba_basic .numba_njit
212
+ def clip (x , min_val , max_val ):
213
+ x = numba_basic .to_scalar (x )
214
+ min_scalar = numba_basic .to_scalar (min_val )
215
+ max_scalar = numba_basic .to_scalar (max_val )
216
+
217
+ if x < min_scalar :
218
+ return min_scalar
219
+ elif x > max_scalar :
220
+ return max_scalar
221
+ else :
222
+ return x
223
+
224
+ return clip
228
225
229
226
230
227
@numba_funcify .register (Composite )
@@ -239,14 +236,13 @@ def numba_funcify_Composite(op, node, **kwargs):
239
236
return composite_fn
240
237
241
238
242
- @numba_basic .numba_njit
243
- def second (x , y ):
244
- return y
245
-
246
-
247
239
@numba_funcify .register (Second )
248
240
def numba_funcify_Second (op , node , ** kwargs ):
249
- return numba_basic .global_numba_func (second )
241
+ @numba_basic .numba_njit
242
+ def second (x , y ):
243
+ return y
244
+
245
+ return second
250
246
251
247
252
248
@numba_basic .numba_njit
@@ -258,60 +254,61 @@ def reciprocal(x):
258
254
259
255
@numba_funcify .register (Reciprocal )
260
256
def numba_funcify_Reciprocal (op , node , ** kwargs ):
261
- return numba_basic .global_numba_func (reciprocal )
262
-
257
+ @numba_basic .numba_njit
258
+ def reciprocal (x ):
259
+ # TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when
260
+ # `x` is an `int`
261
+ return 1 / x
263
262
264
- @numba_basic .numba_njit
265
- def sigmoid (x ):
266
- return 1 / (1 + np .exp (- x ))
263
+ return reciprocal
267
264
268
265
269
266
@numba_funcify .register (Sigmoid )
270
267
def numba_funcify_Sigmoid (op , node , ** kwargs ):
271
- return numba_basic .global_numba_func (sigmoid )
272
-
268
+ @numba_basic .numba_njit
269
+ def sigmoid (x ):
270
+ return 1 / (1 + np .exp (- x ))
273
271
274
- @numba_basic .numba_njit
275
- def gammaln (x ):
276
- return math .lgamma (x )
272
+ return sigmoid
277
273
278
274
279
275
@numba_funcify .register (GammaLn )
280
276
def numba_funcify_GammaLn (op , node , ** kwargs ):
281
- return numba_basic .global_numba_func (gammaln )
277
+ @numba_basic .numba_njit
278
+ def gammaln (x ):
279
+ return math .lgamma (x )
282
280
283
-
284
- @numba_basic .numba_njit
285
- def logp1mexp (x ):
286
- if x < np .log (0.5 ):
287
- return np .log1p (- np .exp (x ))
288
- else :
289
- return np .log (- np .expm1 (x ))
281
+ return gammaln
290
282
291
283
292
284
@numba_funcify .register (Log1mexp )
293
285
def numba_funcify_Log1mexp (op , node , ** kwargs ):
294
- return numba_basic .global_numba_func (logp1mexp )
295
-
286
+ @numba_basic .numba_njit
287
+ def logp1mexp (x ):
288
+ if x < np .log (0.5 ):
289
+ return np .log1p (- np .exp (x ))
290
+ else :
291
+ return np .log (- np .expm1 (x ))
296
292
297
- @numba_basic .numba_njit
298
- def erf (x ):
299
- return math .erf (x )
293
+ return logp1mexp
300
294
301
295
302
296
@numba_funcify .register (Erf )
303
297
def numba_funcify_Erf (op , ** kwargs ):
304
- return numba_basic .global_numba_func (erf )
305
-
298
+ @numba_basic .numba_njit
299
+ def erf (x ):
300
+ return math .erf (x )
306
301
307
- @numba_basic .numba_njit
308
- def erfc (x ):
309
- return math .erfc (x )
302
+ return erf
310
303
311
304
312
305
@numba_funcify .register (Erfc )
313
306
def numba_funcify_Erfc (op , ** kwargs ):
314
- return numba_basic .global_numba_func (erfc )
307
+ @numba_basic .numba_njit
308
+ def erfc (x ):
309
+ return math .erfc (x )
310
+
311
+ return erfc
315
312
316
313
317
314
@numba_funcify .register (Softplus )
0 commit comments