Hi,
I get the following error message from the code. I am not sure why it is occurring.
DATA = np.array([
[0, 0, 1, 5.5],
[0, 1, 2, 4.5],
[0, 2, 3, 3.5],
[1, 0, 1, 3],
[1, 1, 2, 6],
[1, 2, 3, 1],
[2, 0, 1, 6],
[2, 1, 2, 5],
[2, 2, 3, 4],
[3, 0, 1, 1],
[3, 1, 2, 9],
[3, 3, 2, 10]
])
DATA = jax.numpy.asarray(DATA)
K = len(list(set([int(row[0]) for row in DATA])))
N = len(list(set([int(row[1]) for row in DATA])))
R = len(DATA)
print(K)
print(N)
print(R)
def model():
a = numpyro.sample(“competenceMean”, Exponential(0.4))
b = numpyro.sample(“competenceVariance”, Exponential(0.4))
with numpyro.plate('compet', K):
competence = numpyro.sample("competence", Gamma(a, b))
C = 3
culture_id = numpyro.sample("culture_id", Categorical(1/C * jnp.ones(C).astype(int)))
print("culture_id:")
print(culture_id)
with numpyro.plate('consens',N*C):
consensus = numpyro.sample("consensus", Beta(8, 8))
consensus = jnp.reshape(consensus, (N, C))
with numpyro.plate("data_loop", R) as i:
mu = consensus[DATA[i, 1].astype(int), culture_id[DATA[i, 0].astype(int)]]
print(mu)
rating = numpyro.sample("rating", Normal(mu, 0.01), obs=DATA[i, 3].astype(float)/10)
kernel = NUTS(model, init_strategy=numpyro.infer.init_to_sample)
mcmc = MCMC(kernel, num_warmup=100, num_samples=100,chain_method=‘parallel’,num_chains=1, jit_model_args=True)
mcmc.run(random.PRNGKey(0))
samples = mcmc.get_samples(group_by_chain=True)
mcmc.print_summary()
diagnos = numpyro.diagnostics.summary(mcmc.get_samples(group_by_chain=True))
raise ValueError(“Output mismatch: {} vs {}”.format(x.output, output))
ValueError: Output mismatch: Bint[12] vs Bint[3]