Initialization problem sampling from discrete distribution

Hi,

I get the following error message from the code. I am not sure why it is occurring.

DATA = np.array([
[0, 0, 1, 5.5],
[0, 1, 2, 4.5],
[0, 2, 3, 3.5],
[1, 0, 1, 3],
[1, 1, 2, 6],
[1, 2, 3, 1],
[2, 0, 1, 6],
[2, 1, 2, 5],
[2, 2, 3, 4],
[3, 0, 1, 1],
[3, 1, 2, 9],
[3, 3, 2, 10]
])

DATA = jax.numpy.asarray(DATA)
K = len(list(set([int(row[0]) for row in DATA])))
N = len(list(set([int(row[1]) for row in DATA])))
R = len(DATA)
print(K)
print(N)
print(R)

def model():
a = numpyro.sample(“competenceMean”, Exponential(0.4))
b = numpyro.sample(“competenceVariance”, Exponential(0.4))

with numpyro.plate('compet', K):
    competence = numpyro.sample("competence", Gamma(a, b))

    C = 3
    culture_id = numpyro.sample("culture_id", Categorical(1/C * jnp.ones(C).astype(int)))
    print("culture_id:")
    print(culture_id)

with numpyro.plate('consens',N*C):
    consensus = numpyro.sample("consensus", Beta(8, 8))
    consensus = jnp.reshape(consensus, (N, C))

with numpyro.plate("data_loop", R) as i:
    mu = consensus[DATA[i, 1].astype(int), culture_id[DATA[i, 0].astype(int)]]
    print(mu)
    rating = numpyro.sample("rating", Normal(mu, 0.01), obs=DATA[i, 3].astype(float)/10)

kernel = NUTS(model, init_strategy=numpyro.infer.init_to_sample)
mcmc = MCMC(kernel, num_warmup=100, num_samples=100,chain_method=‘parallel’,num_chains=1, jit_model_args=True)
mcmc.run(random.PRNGKey(0))
samples = mcmc.get_samples(group_by_chain=True)
mcmc.print_summary()
diagnos = numpyro.diagnostics.summary(mcmc.get_samples(group_by_chain=True))

raise ValueError(“Output mismatch: {} vs {}”.format(x.output, output))
ValueError: Output mismatch: Bint[12] vs Bint[3]

Are you using enumeration? It seems that your model violates this assumption. It is better to use DiscreteHMCGibbs I think.

Yep. DiscereteHMCGibbs solved the issue. Thank you