Dear experts,
I have a model which I sample with NUTS. Schematically, I:
- define my model with `numpyro.sample` statements for the priors and the likelihood;
- generate some mock data by conditioning the model on the parameters' default values and tracing it:

```python
import numpyro
from numpyro.handlers import seed

fix_cond_model = numpyro.handlers.condition(model, <parameters default values>)
tr = numpyro.handlers.trace(seed(fix_cond_model, rng_key))
res = tr.get_trace()
data = res['data']['value']
```
Then I run MCMC and finally get the samples:
```python
import jax
import numpyro

cond_model = numpyro.handlers.condition(model, {'data': data})
nuts_kernel = numpyro.infer.NUTS(cond_model,
                                 init_strategy=numpyro.infer.init_to_sample(),
                                 max_tree_depth=2)  # max_tree_depth=10
rng_key = jax.random.PRNGKey(0)
mcmc = numpyro.infer.MCMC(nuts_kernel,
                          num_warmup=1000,
                          num_samples=10000,
                          num_chains=1,
                          jit_model_args=True,
                          progress_bar=True)
mcmc.run(rng_key, extra_fields=('potential_energy',))
samples = mcmc.get_samples()
```
Now I wonder whether I can use the `NestedSampler` instead. I have tried:
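```python
import jax
import numpyro
from numpyro.contrib.nested_sampling import NestedSampler

cond_model = numpyro.handlers.condition(model, {'data': data})
ns = NestedSampler(cond_model)
ns.run(jax.random.PRNGKey(42))
samples = ns.get_samples(jax.random.PRNGKey(3), num_samples=10_000)
```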
However, the resulting samples of the variables are clearly pathological. I am certainly missing something; any ideas are welcome. Thanks!