Skip to content

Commit

Permalink
Merge pull request #143 from paulthebaker/ptb-hyper-info
Browse files Browse the repository at this point in the history
add `HyperModel.summary`
  • Loading branch information
Hazboun6 authored Oct 4, 2021
2 parents cca7d23 + 70c89fd commit 9ee66f2
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 3 deletions.
3 changes: 3 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
=======
History
=======
2.3.2 (2021-10-04)
Fix bug in HyperModel when using save_runtime_info.

2.3.1 (2021-09-30)
Fix bugs associated with recent function additions. Added linting and mild PEP8
rules. Also removed older Python functionality which is no longer supported.
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Remember to cite it as:
title = {enterprise_extensions},
year = {2021},
url = {https://github.com/nanograv/enterprise_extensions},
note = {v2.3.1}
note = {v2.3.2}
}
```

Expand Down
2 changes: 1 addition & 1 deletion enterprise_extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.3.1"
__version__ = "2.3.2"
19 changes: 19 additions & 0 deletions enterprise_extensions/hypermodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,3 +385,22 @@ def get_process_timeseries(self, psr, chain, burn, comp='DM',
ret = wave

return ret

def summary(self, to_stdout=False):
"""generate summary string for HyperModel, including all PTAs
:param to_stdout: [bool]
print summary to `stdout` instead of returning it
:return: [string]
"""

summary = ""
for ii, pta in self.models.items():
summary += "model " + str(ii) + "\n"
summary += "=" * 9 + "\n\n"
summary += pta.summary()
summary += "=" * 90 + "\n\n"
if to_stdout:
print(summary)
else:
return summary
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 2.3.1
current_version = 2.3.2
commit = True
tag = True

Expand Down
66 changes: 66 additions & 0 deletions tests/test_hypermodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Tests for `enterprise_extensions` package."""

import json
import logging
import os
import pickle

import pytest

from enterprise_extensions import models, hypermodel

testdir = os.path.dirname(os.path.abspath(__file__))
datadir = os.path.join(testdir, 'data')
outdir = os.path.join(testdir, 'test_out')

psr_names = ['J0613-0200', 'J1713+0747', 'J1909-3744']

with open(datadir+'/ng11yr_noise.json', 'r') as fin:
noise_dict = json.load(fin)


@pytest.fixture
def dmx_psrs(caplog):
"""Sample pytest fixture.
See more at: http://doc.pytest.org/en/latest/fixture.html
"""
caplog.set_level(logging.CRITICAL)
psrs = []
for p in psr_names:
with open(datadir+'/{0}_ng9yr_dmx_DE436_epsr.pkl'.format(p), 'rb') as fin:
psrs.append(pickle.load(fin))

return psrs


@pytest.mark.filterwarnings('ignore::DeprecationWarning')
def test_hypermodel(dmx_psrs, caplog):
m2a = models.model_2a(dmx_psrs, noisedict=noise_dict)
m3a = models.model_3a(dmx_psrs, noisedict=noise_dict)
ptas = {0: m2a, 1: m3a}
hm = hypermodel.HyperModel(ptas)
assert hasattr(hm, 'get_lnlikelihood')
assert 'gw_log10_A' in hm.param_names
assert 'nmodel' in hm.param_names


@pytest.mark.filterwarnings('ignore::DeprecationWarning')
def test_hyper_sampler(dmx_psrs, caplog):
m2a = models.model_2a(dmx_psrs, noisedict=noise_dict)
m3a = models.model_3a(dmx_psrs, noisedict=noise_dict)
ptas = {0: m2a, 1: m3a}
hm = hypermodel.HyperModel(ptas)
samp = hm.setup_sampler(outdir=outdir, human='tester')
assert hasattr(samp, "sample")
paramfile = os.path.join(outdir, "pars.txt")
assert os.path.isfile(paramfile)
with open(paramfile, "r") as f:
params = [line.rstrip('\n') for line in f]
for ptapar, filepar in zip(hm.param_names, params):
assert ptapar == filepar
x0 = hm.initial_sample()
assert len(x0) == len(hm.param_names)

0 comments on commit 9ee66f2

Please sign in to comment.