Hi @Hourglass, we currently allow passing initial values via `mcmc.run(..., init_params=init_params)`, which works for multiple chains, but those parameters must lie in the unconstrained domain. You can convert values from the constrained domain to the unconstrained domain with something like
```python
init_params = jax.vmap(lambda p: unconstrain_fn(model, model_args, model_kwargs, p))(init_values)
```
However, `unconstrain_fn` is not available yet. Please express your interest in issue #1554 of pyro-ppl/numpyro on GitHub, "Inverse bijector transformation (from constrained to unconstrained space)".
Maybe we can add documentation to the `MCMC` class to illustrate this usage?