Varied centeredness for the thetas in the eight-schools model

Hi, I have been trying to estimate the thetas in the eight-schools model using a different centeredness for each theta. However, I keep receiving an assertion error. The code is as follows:

```python
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import reparam
from numpyro.infer.reparam import LocScaleReparam

def eight_schools_noncentered(J, sigma, y=None):
    mu = numpyro.sample('mu', dist.Normal(2, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    # with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=lambd, shape_params=(J,))}):
    theta = numpyro.sample('theta', dist.Normal(mu, tau), sample_shape=(J,))
    numpyro.sample('obs', dist.Normal(theta, sigma), obs=y, sample_shape=(J,))

# best_c_jax: array with one centeredness value per theta (defined elsewhere)
reparam_model = reparam(eight_schools_noncentered, config={"theta": LocScaleReparam(centered=best_c_jax, shape_params=(J,))})
```

Please help me figure out the error.

I don’t know what’s going on, but if you want actionable advice, you should provide a complete, runnable script. I’d start by removing the sample_shape args; none of the example models in NumPyro follow this pattern. I’d revert to the usage documented in the repo:

```python
import numpyro
import numpyro.distributions as dist

def eight_schools(J, sigma, y=None):
    mu = numpyro.sample('mu', dist.Normal(0, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    # the plate declares J conditionally independent schools
    with numpyro.plate('J', J):
        theta = numpyro.sample('theta', dist.Normal(mu, tau))
        numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)
```

Hi, sorry for the confusion. By passing a centeredness array, I want to give each theta its own centeredness using LocScaleReparam in the eight-schools model. For example, with a scalar centeredness we can do the following:

```python
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from numpyro.infer.reparam import LocScaleReparam

def eight_schools_noncentered(J, sigma, lambd, y=None):
    mu = numpyro.sample('mu', dist.Normal(2, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    with numpyro.plate('J', J):
        # lambd: scalar centeredness (0 = fully non-centered, 1 = centered)
        with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=lambd)}):
            theta = numpyro.sample('theta', dist.Normal(mu, tau))
        numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)

nuts_kernel = NUTS(eight_schools_noncentered)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, J, sigma, lambd=0.1, y=y, extra_fields=('potential_energy',))
mcmc.print_summary(exclude_deterministic=False)
```

I want to do the same with a centeredness array (one value per theta). I tried passing an array in place of 0.1, but I got an assertion error. Could you please help me modify the code?
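For reference, here is the partial decentering that LocScaleReparam performs, as I read it from the NumPyro source, written per school $j$ with its own centeredness $c_j \in [0, 1]$. This is a restatement under that reading, not an official spec; $\mathcal{N}(m, s)$ uses mean and standard deviation, as dist.Normal(loc, scale) does:

$$
\tilde\theta_j \sim \mathcal{N}\!\left(c_j\,\mu,\ \tau^{c_j}\right), \qquad
\theta_j = \mu + \tau^{1-c_j}\left(\tilde\theta_j - c_j\,\mu\right), \qquad j = 1,\dots,J.
$$

Since $\tilde\theta_j - c_j\mu \sim \mathcal{N}(0, \tau^{c_j})$, multiplying by $\tau^{1-c_j}$ gives $\mathcal{N}(0, \tau)$, so $\theta_j \sim \mathcal{N}(\mu, \tau)$ for every choice of $c_j$. Setting $c_j = 1$ gives $\theta_j = \tilde\theta_j$ (centered), and $c_j = 0$ gives $\theta_j = \mu + \tau\,\tilde\theta_j$ with $\tilde\theta_j \sim \mathcal{N}(0, 1)$ (non-centered). Only the geometry seen by the sampler changes, not the model.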

I found this way to pass the centeredness as an array:

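```python
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import reparam
from numpyro.infer import MCMC, NUTS
from numpyro.infer.reparam import LocScaleReparam

def eight_schools_noncentered(J, sigma, y=None):
    mu = numpyro.sample('mu', dist.Normal(2, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    # loc has shape (J,), so an array-valued centeredness is applied elementwise
    theta = numpyro.sample('theta', dist.Normal(jnp.full(J, mu), tau))
    numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)

# one centeredness value per school; all zeros = fully non-centered
trial_c = jnp.full(J, 0.0)
reparam_model = reparam(eight_schools_noncentered, config={"theta": LocScaleReparam(centered=trial_c)})
nuts_kernel = NUTS(reparam_model)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))
mcmc.print_summary(exclude_deterministic=False)
```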

I am not sure whether this is doing the right thing, though. Could you please verify it?
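One way to sanity-check the array version without running NUTS is a small NumPy sketch of the same per-element transform. The helpers decentered_params and recenter are hypothetical names, and the formulas assume the partial-decentering scheme LocScaleReparam uses (c = 0 fully non-centered, c = 1 centered). With all-zero centeredness, as in trial_c, every theta becomes mu + tau * z with z standard normal, which is the usual non-centered model. For any mix of values, each theta should still be marginally Normal(mu, tau):

```python
import numpy as np

def decentered_params(loc, scale, centered):
    """Hypothetical helper: parameters of the auxiliary ("decentered") Normal.

    Assumes the partial decentering z ~ Normal(c * loc, scale ** c), elementwise in c.
    """
    centered = np.asarray(centered, dtype=float)
    return loc * centered, scale ** centered

def recenter(z, loc, scale, centered):
    """Hypothetical helper: map decentered z back to theta ~ Normal(loc, scale)."""
    centered = np.asarray(centered, dtype=float)
    return loc + scale ** (1.0 - centered) * (z - centered * loc)

# One centeredness value per school (J = 8), from fully non-centered to centered.
J = 8
c = np.linspace(0.0, 1.0, J)
mu, tau = 2.0, 3.0

rng = np.random.default_rng(0)
n = 200_000
dec_loc, dec_scale = decentered_params(mu, tau, c)  # each of shape (J,)
z = rng.normal(dec_loc, dec_scale, size=(n, J))     # broadcasts over schools
theta = recenter(z, mu, tau, c)

print(theta.mean(axis=0))  # each entry should be close to mu = 2.0
print(theta.std(axis=0))   # each entry should be close to tau = 3.0
```

If every column matches mu and tau, the per-school centeredness leaves the prior on theta unchanged and only alters the parameterization the sampler works in.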