Skip to content

Extend C implementation of psi to support negative inputs #1523

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 42 additions & 32 deletions pytensor/scalar/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,46 +385,56 @@ def c_support_code(self, **kwargs):
#define DEVICE
#endif

#ifndef ga_double
#define ga_double double
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif

#ifndef _PSIFUNCDEFINED
#define _PSIFUNCDEFINED
DEVICE double _psi(ga_double x) {

/*taken from
Bernardo, J. M. (1976). Algorithm AS 103:
Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
http://www.uv.es/~bernardo/1976AppStatist.pdf */
DEVICE double _psi(double x) {

/*taken from
Bernardo, J. M. (1976). Algorithm AS 103:
Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
http://www.uv.es/~bernardo/1976AppStatist.pdf
*/

double y, R, psi_ = 0;
double S = 1.0e-5;
double C = 8.5;
double S3 = 8.333333333e-2;
double S4 = 8.333333333e-3;
double S5 = 3.968253968e-3;
double D1 = -0.5772156649;

ga_double y, R, psi_ = 0;
ga_double S = 1.0e-5;
ga_double C = 8.5;
ga_double S3 = 8.333333333e-2;
ga_double S4 = 8.333333333e-3;
ga_double S5 = 3.968253968e-3;
ga_double D1 = -0.5772156649;

y = x;
if (x <= 0) {
// the digamma function approaches infinity from one side and -infinity from the other, around negative integers and zero
if (x == floor(x)) {
return INFINITY; // note that scipy returns -INF for 0 and NaN for negative integers
}

// Use reflection formula
double pi_x = M_PI * x;
double cot_pi_x = cos(pi_x) / sin(pi_x);
return _psi(1.0 - x) - M_PI * cot_pi_x;
}

if (y <= 0.0)
return psi_;
y = x;

if (y <= S)
return D1 - 1.0/y;
if (y <= S)
return D1 - 1.0/y;

while (y < C) {
psi_ = psi_ - 1.0 / y;
y = y + 1;
}
while (y < C) {
psi_ = psi_ - 1.0 / y;
y = y + 1;
}

R = 1.0 / y;
psi_ = psi_ + log(y) - .5 * R ;
R= R*R;
psi_ = psi_ - R * (S3 - R * (S4 - R * S5));
R = 1.0 / y;
psi_ = psi_ + log(y) - .5 * R ;
R= R*R;
psi_ = psi_ - R * (S3 - R * (S4 - R * S5));

return psi_;
return psi_;
}
#endif
"""
Expand All @@ -433,8 +443,8 @@ def c_code(self, node, name, inp, out, sub):
(x,) = inp
(z,) = out
if node.inputs[0].type in float_types:
return f"""{z} =
_psi({x});"""
dtype = "npy_" + node.outputs[0].dtype
return f"{z} = ({dtype}) _psi({x});"
raise NotImplementedError("only floating point is implemented")


Expand Down
19 changes: 19 additions & 0 deletions tests/scalar/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import pytest
import scipy
import scipy.special as sp

import pytensor.tensor as pt
Expand All @@ -19,6 +20,7 @@
gammal,
gammau,
hyp2f1,
psi,
)
from tests.link.test_link import make_function

Expand Down Expand Up @@ -149,3 +151,20 @@ def test_scalarloop_grad_mixed_dtypes(op, scalar_loop_grads):
(var.owner and isinstance(var.owner.op, ScalarLoop))
for var in ancestors(grad)
)


@pytest.mark.parametrize(
"linker",
["py", "cvm"],
)
def test_psi(linker):
x = float64("x")
out = psi(x)

fn = function([x], out, mode=Mode(linker=linker, optimizer="fast_run"))
fn.dprint()

x_test = np.float64(0.7)

np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test))
np.testing.assert_allclose(fn(-x_test), scipy.special.psi(-x_test))