Skip to content

Commit

Permalink
Use math.lgamma to allow numba caching
Browse files Browse the repository at this point in the history
  • Loading branch information
hyanwong committed Jan 20, 2025
1 parent eb26ec1 commit b03e27b
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 101 deletions.
14 changes: 7 additions & 7 deletions tests/exact_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
import scipy
from scipy.special import betaln
from scipy.special import gammaln
from math import lgamma


def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
Expand All @@ -33,7 +33,7 @@ def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
s2 = s1 * (a + 1) * (b + 1) / (c + 1)
d1 = s1 * exp(f1 - f0)
d2 = s2 * exp(f2 - f0)
logl = f0 + betaln(y_ij + 1, a) + gammaln(b) - b * log(t)
logl = f0 + betaln(y_ij + 1, a) + lgamma(b) - b * log(t)
mn_j = d1 / t
sq_j = d2 / t**2
va_j = sq_j - mn_j**2
Expand All @@ -56,7 +56,7 @@ def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij):
b = s + 1
z = t_j * r
if t_j == 0.0:
logl = gammaln(s) - s * log(r)
logl = lgamma(s) - s * log(r)
mn_i = s / r
va_i = s / r**2
return logl, mn_i, va_i
Expand All @@ -65,7 +65,7 @@ def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij):
f2 = float(mpmath.log(mpmath.hyperu(a + 2, b + 2, z)))
d0 = -a * exp(f1 - f0)
d1 = -(a + 1) * exp(f2 - f1)
logl = f0 - b_i * t_j + (b - 1) * log(t_j) + gammaln(a)
logl = f0 - b_i * t_j + (b - 1) * log(t_j) + lgamma(a)
mn_i = t_j * (1 - d0)
va_i = t_j**2 * d0 * (d1 - d0)
return logl, mn_i, va_i
Expand Down Expand Up @@ -112,7 +112,7 @@ def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
s2 = s1 * (a + 1) * (b + 1) / (c + 1)
d1 = s1 * exp(f1 - f0)
d2 = s2 * exp(f2 - f0)
logl = f0 + betaln(a_j, a_i) + gammaln(b) - b * log(t)
logl = f0 + betaln(a_j, a_i) + lgamma(b) - b * log(t)
mn_j = d1 / t
sq_j = d2 / t**2
va_j = sq_j - mn_j**2
Expand All @@ -130,7 +130,7 @@ def twin_moments(a_i, b_i, y_ij, mu_ij):
"""
s = a_i + y_ij
r = b_i + 2 * mu_ij
logl = log(2) * y_ij + gammaln(s) - log(r) * s
logl = log(2) * y_ij + lgamma(s) - log(r) * s
mn_i = s / r
va_i = s / r**2
return logl, mn_i, va_i
Expand All @@ -151,7 +151,7 @@ def sideways_moments(t_i, a_j, b_j, y_ij, mu_ij):
f2 = float(mpmath.log(mpmath.hyperu(a + 2, b + 2, z)))
d0 = -a * exp(f1 - f0)
d1 = -(a + 1) * exp(f2 - f1)
logl = f0 - mu_ij * t_i + (b - 1) * log(t_i) + gammaln(a)
logl = f0 - mu_ij * t_i + (b - 1) * log(t_i) + lgamma(a)
mn_j = -t_i * d0
va_j = t_i**2 * d0 * (d1 - d0)
return logl, mn_j, va_j
Expand Down
7 changes: 5 additions & 2 deletions tests/test_hypergeo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
Test cases for numba-fied hypergeometric functions
"""

from math import lgamma

import mpmath
import numdifftools as nd
import numpy as np
Expand All @@ -38,8 +40,8 @@ class TestPolygamma:
Test numba-fied gamma functions
"""

def test_gammaln(self, x):
assert np.isclose(hypergeo._gammaln(x), float(mpmath.re(mpmath.loggamma(x))))
def test_lgamma(self, x):
assert np.isclose(lgamma(x), float(mpmath.re(mpmath.loggamma(x))))

def test_digamma(self, x):
assert np.isclose(hypergeo._digamma(x), float(mpmath.psi(0, x)))
Expand Down Expand Up @@ -120,6 +122,7 @@ def _2f1_validate(a_i, b_i, a_j, b_j, y, mu, offset=1.0):
val = mpmath.re(mpmath.hyp2f1(A, B, C, z, maxterms=1e7))
return val / offset

@pytest.mark.skip(reason="_hyp2f1_unity now an inner function for numba")
def test_2f1(self, pars):
a_i, b_i, a_j, b_j, y, mu = pars
A = a_j
Expand Down
3 changes: 2 additions & 1 deletion tsdate/accelerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from numba import jit

# By default we disable the numba cache. See
# By default we disable the numba cache. See e.g.
# https://github.com/sgkit-dev/sgkit/blob/main/sgkit/accelerate.py
_DISABLE_CACHE = os.environ.get("TSDATE_DISABLE_NUMBA_CACHE", "1")

try:
Expand Down
56 changes: 28 additions & 28 deletions tsdate/approx.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def _valid_hyp2f1(a, b, c, z):
# --- various EP updates --- #


@numba.njit(_unituple(_f, 5)(_f, _f, _f, _f, _f, _f))
@numba_jit(_unituple(_f, 5)(_f, _f, _f, _f, _f, _f))
def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
r"""
log p(t_i, t_j) := \
Expand Down Expand Up @@ -277,7 +277,7 @@ def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
d1 = s1 * exp(f1 - f0)
d2 = s2 * exp(f2 - f0)

logl = f0 + hypergeo._betaln(y_ij + 1, a) + hypergeo._gammaln(b) - b * log(t)
logl = f0 + hypergeo._betaln(y_ij + 1, a) + lgamma(b) - b * log(t)

mn_j = d1 / t
sq_j = d2 / t**2
Expand All @@ -290,7 +290,7 @@ def moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
return logl, mn_i, va_i, mn_j, va_j


@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f))
@numba_jit(_unituple(_f, 3)(_f, _f, _f, _f, _f))
def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij):
r"""
log p(t_i) := \
Expand All @@ -309,7 +309,7 @@ def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij):
return nan, nan, nan

if t_j == 0.0:
logl = hypergeo._gammaln(s) - s * log(r)
logl = lgamma(s) - s * log(r)
mn_i = s / r
va_i = s / r**2
return logl, mn_i, va_i
Expand All @@ -325,14 +325,14 @@ def rootward_moments(t_j, a_i, b_i, y_ij, mu_ij):
f0, d0 = hyperu(a + 0, b + 0, z)
f1, d1 = hyperu(a + 1, b + 1, z)

logl = f0 - b_i * t_j + (b - 1) * log(t_j) + hypergeo._gammaln(a)
logl = f0 - b_i * t_j + (b - 1) * log(t_j) + lgamma(a)
mn_i = t_j * (1 - d0)
va_i = t_j**2 * d0 * (d1 - d0)

return logl, mn_i, va_i


@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f))
@numba_jit(_unituple(_f, 3)(_f, _f, _f, _f, _f))
def leafward_moments(t_i, a_j, b_j, y_ij, mu_ij):
r"""
log p(t_j) := \
Expand Down Expand Up @@ -366,7 +366,7 @@ def leafward_moments(t_i, a_j, b_j, y_ij, mu_ij):
return logl, mn_j, va_j


@numba.njit(_unituple(_f, 5)(_f, _f, _f, _f, _f, _f))
@numba_jit(_unituple(_f, 5)(_f, _f, _f, _f, _f, _f))
def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
r"""
log p(t_i, t_j) := \
Expand Down Expand Up @@ -395,7 +395,7 @@ def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
d1 = s1 * exp(f1 - f0)
d2 = s2 * exp(f2 - f0)

logl = f0 + hypergeo._betaln(a_j, a_i) + hypergeo._gammaln(b) - b * log(t)
logl = f0 + hypergeo._betaln(a_j, a_i) + lgamma(b) - b * log(t)

mn_j = d1 / t
sq_j = d2 / t**2
Expand All @@ -408,7 +408,7 @@ def unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
return logl, mn_i, va_i, mn_j, va_j


@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f))
@numba_jit(_unituple(_f, 3)(_f, _f, _f, _f))
def twin_moments(a_i, b_i, y_ij, mu_ij):
r"""
log p(t_i) := \
Expand All @@ -419,13 +419,13 @@ def twin_moments(a_i, b_i, y_ij, mu_ij):
"""
s = a_i + y_ij
r = b_i + 2 * mu_ij
logl = log(2) * y_ij + hypergeo._gammaln(s) - log(r) * s
logl = log(2) * y_ij + lgamma(s) - log(r) * s
mn_i = s / r
va_i = s / r**2
return logl, mn_i, va_i


@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f))
@numba_jit(_unituple(_f, 3)(_f, _f, _f, _f, _f))
def sideways_moments(t_i, a_j, b_j, y_ij, mu_ij):
r"""
log p(t_j) := \
Expand All @@ -448,14 +448,14 @@ def sideways_moments(t_i, a_j, b_j, y_ij, mu_ij):
f0, d0 = hyperu(a + 0, b + 0, z)
f1, d1 = hyperu(a + 1, b + 1, z)

logl = f0 - mu_ij * t_i + (b - 1) * log(t_i) + hypergeo._gammaln(a)
logl = f0 - mu_ij * t_i + (b - 1) * log(t_i) + lgamma(a)
mn_j = -t_i * d0
va_j = t_i**2 * d0 * (d1 - d0)

return logl, mn_j, va_j


@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f, _f))
@numba_jit(_unituple(_f, 2)(_f, _f, _f, _f, _f, _f))
def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
r"""
log p(t_m, t_i, t_j) = \
Expand Down Expand Up @@ -497,7 +497,7 @@ def mutation_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
return mn_m, va_m


@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f))
@numba_jit(_unituple(_f, 2)(_f, _f, _f, _f, _f))
def mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij):
r"""
log p(t_m, t_i) := \
Expand All @@ -516,7 +516,7 @@ def mutation_rootward_moments(t_j, a_i, b_i, y_ij, mu_ij):
return mn_m, va_m


@numba.njit(_unituple(_f, 2)(_f, _f, _f, _f, _f))
@numba_jit(_unituple(_f, 2)(_f, _f, _f, _f, _f))
def mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij):
r"""
log p(t_m, t_j) := \
Expand All @@ -535,7 +535,7 @@ def mutation_leafward_moments(t_i, a_j, b_j, y_ij, mu_ij):
return mn_m, va_m


@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f, _f))
@numba_jit(_unituple(_f, 3)(_f, _f, _f, _f, _f, _f))
def mutation_unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
r"""
log p(t_m, t_i, t_j) := \
Expand Down Expand Up @@ -588,7 +588,7 @@ def mutation_unphased_moments(a_i, b_i, a_j, b_j, y_ij, mu_ij):
return pr_m, mn_m, va_m


@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f))
@numba_jit(_unituple(_f, 3)(_f, _f, _f, _f))
def mutation_twin_moments(a_i, b_i, y_ij, mu_ij):
r"""
log p(t_m, t_i) := \
Expand All @@ -607,7 +607,7 @@ def mutation_twin_moments(a_i, b_i, y_ij, mu_ij):
return pr_m, mn_m, va_m


@numba.njit(_unituple(_f, 3)(_f, _f, _f, _f, _f))
@numba_jit(_unituple(_f, 3)(_f, _f, _f, _f, _f))
def mutation_sideways_moments(t_i, a_j, b_j, y_ij, mu_ij):
r"""
log p(t_m, t_j) := \
Expand Down Expand Up @@ -695,7 +695,7 @@ def mutation_block_moments(t_i, t_j):
# --- wrappers around updates --- #


@numba.njit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r))
@numba_jit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r))
def gamma_projection(pars_i, pars_j, pars_ij):
r"""
log p(t_i, t_j) := \
Expand All @@ -722,7 +722,7 @@ def gamma_projection(pars_i, pars_j, pars_ij):
return logl, np.array(proj_i), np.array(proj_j)


@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
@numba_jit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
def leafward_projection(t_i, pars_j, pars_ij):
r"""
log p(t_j) := \
Expand All @@ -745,7 +745,7 @@ def leafward_projection(t_i, pars_j, pars_ij):
return logl, np.array(proj_j)


@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
@numba_jit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
def rootward_projection(t_j, pars_i, pars_ij):
r"""
log p(t_i) := \
Expand All @@ -768,7 +768,7 @@ def rootward_projection(t_j, pars_i, pars_ij):
return logl, np.array(proj_i)


@numba.njit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r))
@numba_jit(_tuple((_f, _f1r, _f1r))(_f1r, _f1r, _f1r))
def unphased_projection(pars_i, pars_j, pars_ij):
r"""
log p(t_i, t_j) := \
Expand All @@ -795,7 +795,7 @@ def unphased_projection(pars_i, pars_j, pars_ij):
return logl, np.array(proj_i), np.array(proj_j)


@numba.njit(_tuple((_f, _f1r))(_f1r, _f1r))
@numba_jit(_tuple((_f, _f1r))(_f1r, _f1r))
def twin_projection(pars_i, pars_ij):
r"""
log p(t_i) := \
Expand All @@ -818,7 +818,7 @@ def twin_projection(pars_i, pars_ij):
return logl, np.array(proj_i)


@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
@numba_jit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
def sideways_projection(t_i, pars_j, pars_ij):
r"""
log p(t_j) := \
Expand All @@ -841,7 +841,7 @@ def sideways_projection(t_i, pars_j, pars_ij):
return logl, np.array(proj_j)


@numba.njit(_tuple((_f, _f1r))(_f1r, _f1r, _f1r))
@numba_jit(_tuple((_f, _f1r))(_f1r, _f1r, _f1r))
def mutation_gamma_projection(pars_i, pars_j, pars_ij):
r"""
log p(t_m, t_i, t_j) = \
Expand All @@ -868,7 +868,7 @@ def mutation_gamma_projection(pars_i, pars_j, pars_ij):
return 1.0, np.array(proj_m)


@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
@numba_jit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
def mutation_leafward_projection(t_i, pars_j, pars_ij):
r"""
log p(t_m, t_j) := \
Expand All @@ -892,7 +892,7 @@ def mutation_leafward_projection(t_i, pars_j, pars_ij):
return 1.0, np.array(proj_m)


@numba.njit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
@numba_jit(_tuple((_f, _f1r))(_f, _f1r, _f1r))
def mutation_rootward_projection(t_j, pars_i, pars_ij):
r"""
log p(t_m, t_i) := \
Expand Down Expand Up @@ -934,7 +934,7 @@ def mutation_edge_projection(t_i, t_j):
return 1.0, np.array(proj_m)


@numba.njit(_tuple((_f, _f1r))(_f1r, _f1r, _f1r))
@numba_jit(_tuple((_f, _f1r))(_f1r, _f1r, _f1r))
def mutation_unphased_projection(pars_i, pars_j, pars_ij):
r"""
log p(t_m, t_i, t_j) := \
Expand Down
Loading

0 comments on commit b03e27b

Please sign in to comment.