
AttributeError in GP modeling with jaxns integration #812

Open
fcotizelati opened this issue Mar 27, 2024 · 4 comments

Comments

@fcotizelati

While trying to perform GP modeling with the GPResult.sample function, which uses jaxns for nested sampling, I encountered the following AttributeError:

AttributeError: 'ArrayImpl' object has no attribute 'next_sample_idx'

It seems to me that the state returned by the nested sampling run executed by exact_ns is a plain JAX array (ArrayImpl) and therefore lacks the expected attribute next_sample_idx. Could that be the cause?

Here are the versions of Python and the involved libraries I am using:

  • Python version: 3.11.8
  • stingray version: 2.0.0
  • jax version: 0.4.23
  • jaxns version: 2.4.12
  • tensorflow_probability version: 0.22.1
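Version lists like the one above are easiest to keep accurate when they are generated rather than typed. Below is a minimal sketch that uses only the standard library (importlib.metadata, Python 3.8+). The package list simply mirrors this report, the names are assumed to be the PyPI distribution names, and `collect_versions` is a hypothetical helper, not part of Stingray or jaxns.

```python
import platform
from importlib import metadata

# PyPI distribution names of the packages mentioned in this report
# (an example list; adjust as needed).
PACKAGES = ["stingray", "jax", "jaxns", "tensorflow-probability"]


def collect_versions(packages):
    """Return a dict mapping 'python' and each package name to its version.

    Packages that are not installed map to None.
    """
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Not installed in the current environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in collect_versions(PACKAGES).items():
        print(f"{name}: {version or 'not installed'}")
```

Pasting this output into an issue avoids mismatches between the versions you think are installed and the ones the interpreter actually imports.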

Thanks for the help!

@matteobachetti
Member

Hi @fcotizelati, thanks for reporting the issue.
@Gaurav17Joshi, @dhuppenkothen: any ideas?

@matteobachetti
Member

@fcotizelati, sorry for letting this slip. Did you find a fix by any chance?

@fcotizelati
Author

fcotizelati commented Feb 24, 2025

Hello,

No problem! The short answer is that unfortunately I could not fix the issue.

I tried following the steps in the tutorial from scratch and ran into a new issue with the get_prior function in gpmodelling.py. Although I have jaxns version 2.6.7 installed and can import its modules normally, running prior_model = get_prior(params_list, prior_dict) gives the following error:

"ImportError: Jaxns not installed. Cannot make jaxns specific prior."

It seems this happens because jaxns v2.6.7 no longer provides ExactNestedSampler, so the guarded import in gpmodelling.py fails as a whole and Stingray then reports jaxns as not installed. There is, however, a NestedSampler class, and I was able to work around the issue locally by changing the import at lines 23-29 in gpmodelling.py to:

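try:
    # jaxns 2.6.7 does not provide ExactNestedSampler, so import
    # NestedSampler under the old name that gpmodelling.py expects
    from jaxns import NestedSampler as ExactNestedSampler, TerminationCondition, Prior, Model
    from jaxns.utils import resample
    can_sample = True
except ImportError:
    # Any ImportError in this block disables sampling support
    can_sample = False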
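The workaround above hard-codes the new name. A slightly more tolerant pattern tries each candidate name in turn and keeps "package not installed" separate from "package installed but name missing", so the error would not claim jaxns is absent when it is actually present. This is a minimal sketch: `import_first_available` is a hypothetical helper, not part of Stingray or jaxns, and nothing in this thread establishes that NestedSampler is a drop-in replacement for ExactNestedSampler (the sample() call described below still fails).

```python
import importlib
import importlib.util


def import_first_available(module_name, *candidate_names):
    """Return the first attribute of `module_name` found among `candidate_names`.

    Raises ImportError with a message that distinguishes a missing package
    from a package that is installed but lacks every candidate name.
    """
    if importlib.util.find_spec(module_name) is None:
        raise ImportError(f"{module_name} is not installed.")
    module = importlib.import_module(module_name)
    for name in candidate_names:
        if hasattr(module, name):
            return getattr(module, name)
    raise ImportError(
        f"{module_name} is installed but provides none of: "
        + ", ".join(candidate_names)
    )


# Hypothetical usage for the case in this thread: prefer the old jaxns
# name and fall back to the newer one.
try:
    ExactNestedSampler = import_first_available(
        "jaxns", "ExactNestedSampler", "NestedSampler"
    )
    can_sample = True
except ImportError:
    can_sample = False
```

Keeping the two failure modes apart makes version-compatibility problems like this one visible directly in the error message.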

While this change allowed get_prior to work correctly, the state returned by the sampler is a JAX array that lacks the expected attributes. When I run gpresult.sample(prior_model = prior_model, likelihood_model = log_likelihood_model), I get the following error:

AttributeError: 'jaxlib.xla_extension.ArrayImpl' object has no attribute 'num_samples'.

This seems to be an attribute issue similar to the one I reported in my original message, and I'm not sure how to fix it.
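Both tracebacks show code that expects a sampler-state object (with attributes such as next_sample_idx or num_samples) receiving a bare JAX array instead, which is consistent with the sampler call returning its results in a different form than gpmodelling.py expects. Without confirming the current jaxns return signature, one hedged way to make such a mismatch fail with a clearer message is to validate the object before using it. `require_attributes` is a hypothetical helper, the attribute names come from the two error messages in this thread, and `FakeState` is only a stand-in for the real jaxns state type.

```python
from collections import namedtuple


def require_attributes(obj, *names):
    """Return `obj` unchanged, or raise TypeError listing the missing attributes."""
    missing = [name for name in names if not hasattr(obj, name)]
    if missing:
        raise TypeError(
            f"Expected a sampler state with attribute(s) {', '.join(missing)}, "
            f"but got an object of type {type(obj).__name__}. The installed "
            "jaxns version may return its results in a different form."
        )
    return obj


# Stand-in for a sampler state; the real jaxns state type is not reproduced here.
FakeState = namedtuple("FakeState", ["num_samples", "next_sample_idx"])

state = require_attributes(
    FakeState(num_samples=10, next_sample_idx=3), "num_samples", "next_sample_idx"
)
```

Passing a plain array (or any other object without those attributes) raises a TypeError that names the missing attributes and the received type, rather than an AttributeError deep inside the sampling code.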

Here are the versions of Python and the involved libraries I am using:

  • Python version: 3.11.11
  • Stingray version: 2.2.6
  • Jax version: 0.5.0
  • Jaxns version: 2.6.7

@matteobachetti
Member

Hi @fcotizelati, thanks! I hope these new issues will be solved in #832 (see comment #832 (review)).
