Errors in jaxns model when installing via PyCharm

Hi All

I’ve recently moved to a new laptop / python installation and am trying to run the Gaussian Shells example for JAXNS as a test case. I’m working in PyCharm and using its package installation GUI.

To summarize my issue:

  • JAXNS itself is working fine (the dual moons example runs perfectly well)
  • The Gaussian Shells example for the numpyro integration is not working

Something seems to be going wrong in the installation process. In the numpyro.contrib.nested_sampling code I can clearly see the import statement "from jaxns.public import DefaultNestedSampler", but my installation complains about a bad import of jaxns.DefaultNestedSampler, i.e. not from the public module. The former import works when tested in isolation. I’m also getting other errors that are inconsistent with the up-to-date source code on GitHub, e.g. the ImportError message string is slightly different. I can confirm that I am running numpyro version 0.15.3 and jaxns version 2.6.3.
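One way to narrow this down is to check which copies of the packages the PyCharm interpreter actually resolves, since the installed file on disk may differ from the file on GitHub. The following is a minimal sketch using only the standard library; describe_package is a hypothetical helper name, not part of numpyro or jaxns, and it assumes both are installed as ordinary top-level packages.

```python
import importlib.metadata
import importlib.util
import sys


def describe_package(name):
    """Return the installed version and on-disk location of a top-level package.

    Hypothetical helper for illustration. Either field is None if the
    package metadata or the module itself cannot be found.
    """
    try:
        # Version recorded in the installed distribution's metadata.
        version = importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError:
        version = None
    # Locate the module without importing it.
    spec = importlib.util.find_spec(name)
    location = spec.origin if spec is not None else None
    return {"name": name, "version": version, "location": location}


if __name__ == "__main__":
    # The interpreter path shows which environment PyCharm is really using.
    print("interpreter:", sys.executable)
    for pkg in ("numpyro", "jaxns"):
        print(describe_package(pkg))
```

If the reported location points somewhere unexpected (a different environment or an older site-packages directory), opening numpyro/contrib/nested_sampling.py under that location should show whether the installed file contains the import that the error message refers to.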

Does anyone have any idea why pip-installing would produce such inconsistencies with the content on GitHub?

Could you install from dev?

git clone https://github.com/pyro-ppl/numpyro.git
cd numpyro
# install jax/jaxlib first for CUDA support
pip install -e ".[dev]"  # contains additional dependencies for NumPyro development; quoted so shells like zsh do not expand the brackets

The jaxns API changes over time, so it’s tricky to keep numpyro in sync with it. We have tried to update the integration occasionally to avoid breakage. One solution is to release a new numpyro version after upgrading jaxns, which is not desirable. :frowning: