Hi,
Let us consider this simple use case:
import jax
import numpy as np
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import NUTS, HMC, MCMC
# Ground truth for the synthetic data set: the cubic coefficients a0..a3
# followed by the standard deviation of the additive Gaussian noise.
param_true = np.array([1.0, 0.0, 0.2, 0.5, 1.5])
sample_size = 5_000
sigma_e = param_true[4]  # true noise standard deviation

# Seeded (legacy) RandomState so the data set is reproducible.
# Draw order matters: the x-values are drawn before the noise.
random_num_generator = np.random.RandomState(0)
xi = random_num_generator.rand(sample_size) * 5 - 2.5  # uniform on [-2.5, 2.5)
e = random_num_generator.normal(0, sigma_e, sample_size)

# Responses: cubic polynomial in xi plus noise.
yi = (
    param_true[0]
    + param_true[1] * xi
    + param_true[2] * xi**2
    + param_true[3] * xi**3
    + e
)
def my_model(Xspls, Yspls):
    """Bayesian cubic regression model for NumPyro.

    Likelihood: ``Yspls ~ Normal(a0 + a1*x + a2*x**2 + a3*x**3, sigma)``
    with ``x = Xspls``.

    Priors:
        a0, a1, a2, a3 ~ Normal(0, 10)
        sigma          ~ Uniform(0, 10)

    Args:
        Xspls: predictor values.
        Yspls: observed responses, conditioned on through the ``obs`` site.
    """
    # Sample the four polynomial coefficients, in site order a0, a1, a2, a3.
    a0, a1, a2, a3 = (
        numpyro.sample(site, dist.Normal(0.0, 10))
        for site in ("a0", "a1", "a2", "a3")
    )
    sigma = numpyro.sample("sigma", dist.Uniform(low=0.0, high=10.0))

    x = Xspls
    mean = a0 + a1 * x + a2 * x**2 + a3 * x**3
    numpyro.sample("obs", dist.Normal(mean, sigma), obs=Yspls)
from numpyro.infer import init_to_uniform

# Start from this source of randomness. MCMC.run splits the key it receives
# into one key per chain, so each chain gets its own independent stream.
rng_key = jax.random.PRNGKey(0)
_, rng_key = jax.random.split(rng_key)

# How are the chains initialized?
# NUTS accepts an ``init_strategy``. When none is given it uses
# ``init_to_uniform``, which is passed explicitly here so the choice is
# visible. With this strategy, each chain (using its own key) starts at a
# random point drawn uniformly in (-2, 2) on the *unconstrained* space of every
# latent site (a0..a3 and sigma). For ``sigma``, that point is mapped back
# into its Uniform(0, 10) support. The priors themselves are NOT sampled.
# Alternatives:
#   * ``init_to_sample`` / ``init_to_median``: start from a prior draw, or
#     from the median of several prior draws;
#   * ``init_to_value(values={...})``: fixed starting values per site;
#   * ``mcmc.run(..., init_params=...)``: explicit starting values with a
#     leading axis of size ``num_chains``.
kernel = NUTS(my_model, init_strategy=init_to_uniform)
num_samples = 5_000

# NOTE: on a single CPU device the 20 chains run sequentially, and MCMC warns
# about this. To run them in parallel, call ``numpyro.set_host_device_count(20)``
# at the very start of the program, before JAX initializes its backend.
mcmc = MCMC(
    kernel,
    num_warmup=1_000,
    num_samples=num_samples,
    num_chains=20,
    progress_bar=False,
)
mcmc.run(rng_key, Xspls=xi, Yspls=yi)
mcmc.print_summary()
samples_1 = mcmc.get_samples()
I wonder how the 20 chains are initialized. Is the initialization inferred from the `my_model`
function, using the priors on the `a` parameters and on `sigma`? Or do I have to provide some
`init_params`?