Skip to content

Commit

Permalink
Merge pull request #152 from Hazboun6/master
Browse files Browse the repository at this point in the history
Added test for OS phase shift
  • Loading branch information
Hazboun6 authored Nov 1, 2021
2 parents d19d26b + 247ab57 commit c481a3d
Showing 1 changed file with 40 additions and 1 deletion.
41 changes: 40 additions & 1 deletion tests/test_os.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
import numpy as np
import pytest

from enterprise_extensions import models
from enterprise.signals import signal_base, gp_signals, parameter, utils
from enterprise_extensions import models, blocks, model_utils
from enterprise_extensions.frequentist import optimal_statistic as optstat

testdir = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -71,3 +72,41 @@ def test_os(nodmx_psrs, pta_model2a):
chain[ii, :] = np.array(entry)
OS.compute_noise_marginalized_os(chain, param_names=OS.pta.param_names, N=10)
OS.compute_noise_maximized_os(chain, param_names=OS.pta.param_names)


@pytest.mark.filterwarnings('ignore::DeprecationWarning')
@pytest.fixture
def pta_pshift(dmx_psrs, caplog):
Tspan = model_utils.get_tspan(dmx_psrs)
tm = gp_signals.TimingModel()
wn = blocks.white_noise_block(inc_ecorr=True)
rn = blocks.red_noise_block(Tspan=Tspan)
pseed = parameter.Uniform(0, 10000)('gw_pseed')
gw_log10_A = parameter.Uniform(-18, -14)('gw_log10_A')
gw_gamma = parameter.Constant(13./3)('gw_gamma')
gw_pl = utils.powerlaw(log10_A=gw_log10_A, gamma=gw_gamma)
gw_pshift = gp_signals.FourierBasisGP(spectrum=gw_pl,
components=5,
Tspan=Tspan,
name='gw',
pshift=True,
pseed=pseed)
model = tm + wn + rn + gw_pshift
pta_pshift = signal_base.PTA([model(p) for p in dmx_psrs])
pta_pshift.set_default_params(noise_dict)
return pta_pshift


@pytest.mark.filterwarnings('ignore::DeprecationWarning')
def test_os_pseed(dmx_psrs, pta_pshift):
OS = optstat.OptimalStatistic(psrs=dmx_psrs, pta=pta_pshift)
params = {pnm: p.sample() for pnm, p in zip(pta_pshift.param_names,
pta_pshift.params)}
params.update({'gw_pseed': 1})
_, _, _, A1, rho1 = OS.compute_os(params=params)
params.update({'gw_pseed': 2})
_, _, _, A2, rho2 = OS.compute_os(params=params)
print(A1, A2)
print(rho1, rho2)
assert A1!=A2
assert rho1!=rho2

0 comments on commit c481a3d

Please sign in to comment.