Hi @fehiepsi, I'm working on an MCMC model in NumPyro, and I want to set initial values for my priors, for example:
# Enable 64-bit precision once, at program start-up and before any JAX arrays
# are created. JAX documents that this flag must be set at startup; flipping it
# from inside the model (which NUTS traces) is not a supported way to set it.
jax.config.update("jax_enable_x64", True)


def model(M, N, ntob, ntmis, T, ID, ya_ob, tp, xs, xa):
    """Declare the priors of the model.

    Only the prior section is shown. ``delta``, ``x`` and the arguments other
    than ``M`` and ``N`` are not used here; presumably they feed a likelihood
    that this snippet omits (confirm).

    The sample sites declared below are the only names ``init_to_value`` can
    target: b1s, b1a, alpha_l_s, alpha_l_a, rho_l_s, rho_l_a, alpha_m_s,
    alpha_m_a, rho_m_s, rho_m_a, sig_amt, L_Omega, delta_mu, sig_s, sig_a,
    z_s, z_a. Initial values must lie inside each site's support. For example,
    the TruncatedNormal(low=0) and HalfCauchy sites need values > 0. L_Omega
    needs a valid 2x2 lower Cholesky factor of a correlation matrix.
    delta_mu needs an array of shape (2,).

    Args:
        M: length of the cumulative index vector ``x``.
        N: number of entries in each of ``z_s`` and ``z_a``.
    """
    delta = jnp.float64(1e-9)
    x = jnp.cumsum(jnp.ones(M))

    # Priors for linear model coefficients
    b1s = numpyro.sample("b1s", dist.Normal(0, 5))
    b1a = numpyro.sample("b1a", dist.Normal(0, 5))

    # Positive (zero-truncated) amplitude / length-scale priors
    alpha_l_s = numpyro.sample("alpha_l_s", dist.TruncatedNormal(loc=0, scale=5, low=0))
    alpha_l_a = numpyro.sample("alpha_l_a", dist.TruncatedNormal(loc=0, scale=5, low=0))
    rho_l_s = numpyro.sample("rho_l_s", dist.TruncatedNormal(loc=0, scale=5, low=0))
    rho_l_a = numpyro.sample("rho_l_a", dist.TruncatedNormal(loc=0, scale=5, low=0))
    alpha_m_s = numpyro.sample("alpha_m_s", dist.TruncatedNormal(loc=0, scale=5, low=0))
    alpha_m_a = numpyro.sample("alpha_m_a", dist.TruncatedNormal(loc=0, scale=40, low=0))
    rho_m_s = numpyro.sample("rho_m_s", dist.TruncatedNormal(loc=0, scale=5, low=0))
    rho_m_a = numpyro.sample("rho_m_a", dist.TruncatedNormal(loc=0, scale=5, low=0))

    sig_amt = numpyro.sample("sig_amt", dist.HalfCauchy(scale=5))
    # Lower Cholesky factor of a 2x2 correlation matrix.
    L_Omega = numpyro.sample("L_Omega", dist.LKJCholesky(2, concentration=3))
    # One site with shape (2,); initialise it with a length-2 array.
    delta_mu = numpyro.sample("delta_mu", dist.Normal(0, 5), sample_shape=(2,))
    sig_s = numpyro.sample("sig_s", dist.HalfCauchy(scale=5))
    sig_a = numpyro.sample("sig_a", dist.HalfCauchy(scale=5))
    z_s = numpyro.sample("z_s", dist.Normal(loc=0, scale=1), sample_shape=(N,))
    z_a = numpyro.sample("z_a", dist.Normal(loc=0, scale=1), sample_shape=(N,))
and I use
# init_to_value matches dict entries to the model's sample sites *by name*.
# Values are given in the constrained (support) space. Keys that are not
# sample-site names are never looked up, so they have no effect. Sites absent
# from the dict fall back to init_to_uniform. For the model shown above, those
# are z_s and z_a.
#
# NOTE(review): the model has no 'delta_s', 'delta_a' or 'cor_e' sites, so the
# original entries under those keys were silently unused. The model instead
# has 'delta_mu' (shape (2,)) and 'L_Omega' (2x2 lower Cholesky factor of a
# correlation matrix). The mapping below assumes:
#   - delta_mu == [delta_s, delta_a];
#   - cor_e is the off-diagonal correlation.
# Confirm both against the likelihood.
init_cor = cor_e_true + 0.4
if not -1.0 < init_cor < 1.0:
    # Outside (-1, 1) the Cholesky factor below would contain NaN.
    raise ValueError(f"initial correlation must lie in (-1, 1), got {init_cor}")

initial_values = {
    "b1s": jnp.float64(b1s_true + 2),
    "b1a": jnp.float64(b1a_true - 2),
    "sig_s": jnp.float64(sig_s_true + 0.5),
    "sig_a": jnp.float64(sig_a_true + 0.5),
    "delta_mu": jnp.array([delta_s_true + 1, delta_a_true + 2]),
    "sig_amt": jnp.float64(sig_amt_true + 0.2),
    # Lower Cholesky factor L of [[1, r], [r, 1]], so that L @ L.T has unit
    # diagonal and off-diagonal r.
    "L_Omega": jnp.array([[1.0, 0.0], [init_cor, jnp.sqrt(1.0 - init_cor**2)]]),
    "alpha_l_s": jnp.float64(alpha_l_s_true),
    "rho_l_s": jnp.float64(rho_l_s_true),
    "alpha_l_a": jnp.float64(alpha_l_a_true),
    "rho_l_a": jnp.float64(rho_l_a_true),
    "alpha_m_s": jnp.float64(alpha_m_s_true),
    "rho_m_s": jnp.float64(rho_m_s_true),
    "alpha_m_a": jnp.float64(alpha_m_a_true + 0.1),
    "rho_m_a": jnp.float64(rho_m_a_true),
}
strategy = init_to_value(values=initial_values)

# NOTE(review): the keys above are the sample sites of the function named
# `model` in this snippet. Confirm that SOGP is that same model; otherwise its
# site names will not match and the values will not be applied.
nuts_kernel = NUTS(
    SOGP,
    adapt_step_size=True,
    max_tree_depth=10,
    target_accept_prob=0.80,
    init_strategy=strategy,
)
# NOTE(review): the initial point is not stored as a sample. Each draw
# collected by MCMC is the state *after* a NUTS transition, so even with
# num_warmup=0 the first draw has already moved away from these values. To
# inspect the initial point itself, look at the state returned by
# nuts_kernel.init(...). Its `z` field is in the unconstrained space.
# Verify this against the installed NumPyro version.
as the initialization strategy for the priors. However, this does not seem to work: when I get samples from the MCMC run, the samples never start from the given initial values. Even when I set `num_warmup = 0`, the samples still do not start from the given initial values. Am I doing anything wrong?