diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index 86029e626f..d08759a978 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -385,46 +385,56 @@ def c_support_code(self, **kwargs): #define DEVICE #endif - #ifndef ga_double - #define ga_double double + #ifndef M_PI + #define M_PI 3.14159265358979323846 #endif #ifndef _PSIFUNCDEFINED #define _PSIFUNCDEFINED - DEVICE double _psi(ga_double x) { - - /*taken from - Bernardo, J. M. (1976). Algorithm AS 103: - Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317. - http://www.uv.es/~bernardo/1976AppStatist.pdf */ + DEVICE double _psi(double x) { + + /*taken from + Bernardo, J. M. (1976). Algorithm AS 103: + Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317. + http://www.uv.es/~bernardo/1976AppStatist.pdf + */ + + double y, R, psi_ = 0; + double S = 1.0e-5; + double C = 8.5; + double S3 = 8.333333333e-2; + double S4 = 8.333333333e-3; + double S5 = 3.968253968e-3; + double D1 = -0.5772156649; - ga_double y, R, psi_ = 0; - ga_double S = 1.0e-5; - ga_double C = 8.5; - ga_double S3 = 8.333333333e-2; - ga_double S4 = 8.333333333e-3; - ga_double S5 = 3.968253968e-3; - ga_double D1 = -0.5772156649; - - y = x; + if (x <= 0) { + // the digamma function approaches infinity from one side and -infinity from the other, around negative integers and zero + if (x == floor(x)) { + return INFINITY; // note that scipy returns -INF for 0 and NaN for negative integers + } + + // Use reflection formula + double pi_x = M_PI * x; + double cot_pi_x = cos(pi_x) / sin(pi_x); + return _psi(1.0 - x) - M_PI * cot_pi_x; + } - if (y <= 0.0) - return psi_; + y = x; - if (y <= S) - return D1 - 1.0/y; + if (y <= S) + return D1 - 1.0/y; - while (y < C) { - psi_ = psi_ - 1.0 / y; - y = y + 1; - } + while (y < C) { + psi_ = psi_ - 1.0 / y; + y = y + 1; + } - R = 1.0 / y; - psi_ = psi_ + log(y) - .5 * R ; - R= R*R; - psi_ = psi_ - R * (S3 - R * (S4 - R * S5)); + R = 1.0 / y; + psi_ = psi_ + log(y) - .5 * R ; + R= R*R; + psi_ = psi_ - R * (S3 - R * (S4 - R * S5)); - return psi_; + return psi_; } #endif """ @@ -433,8 +443,8 @@ def c_code(self, node, name, inp, out, sub): (x,) = inp (z,) = out if node.inputs[0].type in float_types: - return f"""{z} = - _psi({x});""" + dtype = "npy_" + node.outputs[0].dtype + return f"{z} = ({dtype}) _psi({x});" raise NotImplementedError("only floating point is implemented") diff --git a/tests/scalar/test_math.py b/tests/scalar/test_math.py index f4a9f2d414..da116ab887 100644 --- a/tests/scalar/test_math.py +++ b/tests/scalar/test_math.py @@ -2,6 +2,7 @@ import numpy as np import pytest +import scipy import scipy.special as sp import pytensor.tensor as pt @@ -19,6 +20,7 @@ gammal, gammau, hyp2f1, + psi, ) from tests.link.test_link import make_function @@ -149,3 +151,20 @@ def test_scalarloop_grad_mixed_dtypes(op, scalar_loop_grads): (var.owner and isinstance(var.owner.op, ScalarLoop)) for var in ancestors(grad) ) + + +@pytest.mark.parametrize( + "linker", + ["py", "cvm"], +) +def test_psi(linker): + x = float64("x") + out = psi(x) + + fn = function([x], out, mode=Mode(linker=linker, optimizer="fast_run")) + fn.dprint() + + x_test = np.float64(0.7) + + np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test)) + np.testing.assert_allclose(fn(-x_test), scipy.special.psi(-x_test))