Hi, I have been trying to estimate the thetas in the eight-schools model using a different centeredness for each theta, but I am getting an assertion error. The code is as follows:
def eight_schools_noncentered(J, sigma, y=None):
    mu = numpyro.sample('mu', dist.Normal(2, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    # with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=lambd, shape_params=(J,))}):
    theta = numpyro.sample('theta', dist.Normal(mu, tau), sample_shape=(J,))
    numpyro.sample('obs', dist.Normal(theta, sigma), obs=y, sample_shape=(J,))
from numpyro.handlers import reparam
reparam_model = reparam(eight_schools_noncentered, config={"theta": LocScaleReparam(centered=best_c_jax, shape_params=(J,))})
Please help me figure out the error.
i don’t know what’s going on, but if you want actionable advice you should provide a complete runnable script. i’d start by removing the sample_shape args. none of the example models in numpyro follow this pattern. i’d revert to the usage documented in the repo:
def eight_schools(J, sigma, y=None):
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        theta = numpyro.sample('theta', dist.Normal(mu, tau))
        numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)
Hi, sorry for the confusion. By passing a centeredness array, I want to give each theta its own centeredness using LocScaleReparam in the eight-schools model. For example, with a scalar centeredness we can do the following:
def eight_schools_noncentered(J, sigma, lambd, y=None):
    mu = numpyro.sample('mu', dist.Normal(2, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=lambd)}):
            theta = numpyro.sample('theta', dist.Normal(mu, tau))
        numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)
nuts_kernel = NUTS(eight_schools_noncentered)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, J, sigma, lambd=0.1, y=y, extra_fields=('potential_energy',))
mcmc.print_summary(exclude_deterministic=False)
I want to do the same with a centeredness array. I tried passing an array in place of 0.1, but I got an assertion error. Can you please help me modify the code?
I found this way to pass the centeredness as an array:
def eight_schools_noncentered(J, sigma, y=None):
    mu = numpyro.sample('mu', dist.Normal(2, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    theta = numpyro.sample('theta', dist.Normal(jnp.full(J, mu), tau))
    numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)
trial_c = jnp.full(J, 0)
from numpyro.handlers import reparam
reparam_model = reparam(eight_schools_noncentered, config={"theta": LocScaleReparam(centered=jnp.array(trial_c))})
nuts_kernel = NUTS(reparam_model)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))
mcmc.print_summary(exclude_deterministic=False)
I am not sure whether this is doing it correctly, though. Can you please verify?
yes, you can use an array for the centered parameter.
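One way to sanity-check what an array-valued centered does is to write the transform out by hand. The sketch below is plain Python, not NumPyro code. It assumes the partial decentering form theta_i = mu + tau**(1 - c_i) * (z_i - c_i * mu), with z_i drawn from Normal(c_i * mu, tau**c_i), which is how I understand LocScaleReparam to work. The function name partially_decenter and all the numbers are made up for illustration.

```python
import math


def partially_decenter(mu, tau, centered, decentered):
    """Map decentered draws z_i back to theta_i, one centeredness c_i per element.

    Assumed transform: theta_i = mu + tau**(1 - c_i) * (z_i - c_i * mu).
    """
    return [
        mu + tau ** (1.0 - c) * (z - c * mu)
        for c, z in zip(centered, decentered)
    ]


mu, tau = 2.0, 3.0
centered = [0.0, 0.5, 1.0]    # hypothetical per-school centeredness values
decentered = [0.4, 1.5, 2.7]  # hypothetical draws of the decentered variable

theta = partially_decenter(mu, tau, centered, decentered)

# c = 0 is fully non-centered: theta = mu + tau * z
# c = 1 is fully centered:     theta = z
# c = 0.5 sits in between:     theta = mu + sqrt(tau) * (z - 0.5 * mu)
print(theta)
```

Under that assumption, each school gets its own interpolation between the centered and non-centered forms. A quick check on a real run is that the summary shows theta_decentered as the sampled site, with theta recovered from it.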