Chains randomisation?

Let us consider this simple use case:

import jax
import numpy as np

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import NUTS, HMC, MCMC

param_true = np.array([1.0, 0.0, 0.2, 0.5, 1.5])
sample_size = 5_000
sigma_e = param_true[4]          # true value of parameter error sigma
random_num_generator = np.random.RandomState(0)
xi = 5*random_num_generator.rand(sample_size)-2.5
e = random_num_generator.normal(0, sigma_e, sample_size)
yi = param_true[0] + param_true[1] * xi + param_true[2] * xi**2 + param_true[3] * xi**3 + e

def my_model(Xspls,Yspls):
    a0 = numpyro.sample('a0', dist.Normal(0.,10))
    a1 = numpyro.sample('a1', dist.Normal(0.,10))
    a2 = numpyro.sample('a2', dist.Normal(0.,10))
    a3 = numpyro.sample('a3', dist.Normal(0.,10))
    sigma = numpyro.sample('sigma', dist.Uniform(low=0.,high=10.))
    mu = a0 + a1*Xspls + a2*Xspls**2 + a3*Xspls**3
    numpyro.sample('obs', dist.Normal(mu, sigma), obs=Yspls)

# Start from this source of randomness. We will split keys for subsequent operations.
rng_key = jax.random.PRNGKey(0)
_, rng_key = jax.random.split(rng_key)

# Run NUTS.
kernel = NUTS(my_model)
num_samples = 5_000
mcmc = MCMC(kernel, num_warmup=1_000, num_samples=num_samples,
            num_chains=20, progress_bar=False)
# The data are passed to the model through run(), not the MCMC constructor.
mcmc.run(rng_key, Xspls=xi, Yspls=yi)
samples_1 = mcmc.get_samples()

I wonder how the 20 chains are started. Is the initialization inferred from the my_model function, using the priors on the a parameters and on sigma? Or do I have to provide some init_params?

See the initialization strategies in the docs. The default is init_to_uniform.

See here for example usage of the different initialization strategies.

Thanks @martinjankowiak. I may need a bit more help, please. If I use

rng_key = jax.random.PRNGKey(0)
tmp = numpyro.infer.util.find_valid_initial_params(rng_key, my_model, model_args=(xi, yi))

then tmp[0][0] is

{'a0': DeviceArray(-1.6553521, dtype=float32),
 'a1': DeviceArray(0.16568518, dtype=float32),
 'a2': DeviceArray(0.24631977, dtype=float32),
 'a3': DeviceArray(0.7033801, dtype=float32),
 'sigma': DeviceArray(-0.9534355, dtype=float32)}

which I guess are the initial values under the default initialisation strategy. But:

  1. The initial_params should have a leading dimension equal to the number of chains to run. How should I proceed?
  2. In practice, how do I change the init_strategy argument of find_valid_initial_params to get a set of parameters drawn from my priors? Should I set init_strategy=init_to_sample?


Yes, init_to_sample will draw a set of random samples from the priors for each chain. I’m not sure why you need find_valid_initial_params, though. The second reference in @martinjankowiak's comment shows how to specify init_strategy for NUTS.

FYI, the docs of find_valid_initial_params say that you can pass a batch of rng_keys to get a batch of initial params.

Oh, thanks @fehiepsi and @martinjankowiak. At first I had not realized that I could pass the strategy directly as an argument to NUTS. OK.