NestedSampler: is it the right way to proceed?

Dear experts,

I have a model which I use with the NUTS sampler. Schematically, I do the following:

  1. I define my model using numpyro.sample statements to define the priors and the likelihood.
  2. I use fix_cond_model = numpyro.handlers.condition(model, <parameter default values>) to generate some data with
tr = numpyro.handlers.trace(numpyro.handlers.seed(fix_cond_model, rng_key))
res = tr.get_trace()
data = res['data']['value']

Then I proceed to the MCMC run and finally get the samples:

cond_model = numpyro.handlers.condition(model, {'data': data})
nuts_kernel = numpyro.infer.NUTS(cond_model,
                                 max_tree_depth=2)  # max_tree_depth=10

rng_key = jax.random.PRNGKey(0)

# MCMC requires num_warmup and num_samples
mcmc = numpyro.infer.MCMC(nuts_kernel, num_warmup=<num_warmup>, num_samples=<num_samples>,
                          progress_bar=True)
mcmc.run(rng_key, extra_fields=('potential_energy',))
samples = mcmc.get_samples()

Now I wonder whether I can use the NestedSampler. I have tried:

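from numpyro.contrib.nested_sampling import NestedSampler

cond_model = numpyro.handlers.condition(model, {'data': data})
ns = NestedSampler(cond_model)
ns.run(jax.random.PRNGKey(2))  # required before get_samples; key chosen arbitrarily
samples = ns.get_samples(jax.random.PRNGKey(3), num_samples=10_000)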

But the sampling of the variables is clearly pathological.

I am certainly missing something. Any ideas are welcome. Thanks.

Your approach seems right for numpyro models. Perhaps the result is just in the nature of the nested sampler? You can try the diagnostics plotting method to get more information.
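
One possible reading of "the nature of the nested sampler" is this: nested sampling produces weighted samples, and, as far as I understand, equal-weight draws are then obtained by resampling with replacement. If only a few points carry most of the weight, the resampled set contains many repeated values and can look pathological. The sketch below uses plain NumPy and hypothetical log-weights (not actual jaxns output) to show how the Kish effective sample size exposes this.

```python
import numpy as np


def kish_ess(log_weights):
    """Kish effective sample size, (sum w)^2 / sum(w^2), from log-weights."""
    log_w = np.asarray(log_weights, dtype=float)
    w = np.exp(log_w - log_w.max())  # subtract the max for numerical stability
    return w.sum() ** 2 / np.sum(w ** 2)


rng = np.random.default_rng(0)
n = 1000
log_w_even = np.zeros(n)                # equal weights
log_w_peaked = rng.normal(0.0, 5.0, n)  # hypothetical, very uneven weights

ess_even = kish_ess(log_w_even)
ess_peaked = kish_ess(log_w_peaked)

# Resampling with replacement under peaked weights repeats a few points often.
p = np.exp(log_w_peaked - log_w_peaked.max())
p /= p.sum()
draws = rng.choice(n, size=10_000, replace=True, p=p)
n_unique = np.unique(draws).size

print(f"ESS (even weights):   {ess_even:.1f}")
print(f"ESS (peaked weights): {ess_peaked:.1f}")
print(f"distinct points among 10000 resampled draws: {n_unique}")
```

If the effective sample size is much smaller than the number of requested samples, asking for 10_000 draws mostly duplicates the same few points, and a longer or better-tuned run is probably needed.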

Hmm, I’m quite new to nested sampling, so if you have any advice (or tools) for diagnostics, that would be great. Thanks.

How about using print_summary and diagnostics? You might find some tutorials here. Also, you can try tuning parameters like num_live_points, depth, … I’m not familiar with nested sampling either. The NestedSampler class is just a wrapper around jaxns for use with numpyro models.

Thanks @fehiepsi, I will try and let you know.