Hi,
I am trying to assign people to “cultures” (clusters) using a stick-breaking process. When I do this, the R-hat for the discrete variable “culture_id” is NaN for many participants. Why does this happen, and how can I fix it?
def _mix_weights(beta):
    """Turn K-1 stick-breaking fractions into K mixture weights.

    Implements the standard stick-breaking construction
    ``w_k = beta_k * prod_{j<k} (1 - beta_j)`` for the first K-1 weights.
    The last weight takes the remaining stick, ``w_K = prod_{j<K} (1 - beta_j)``.
    The weights sum to one along the last axis.

    Args:
        beta: array whose last axis holds fractions in (0, 1), e.g. draws
            from ``Beta(1, alpha)``.

    Returns:
        Array like ``beta`` but with the last axis one element longer.
    """
    leftover = jnp.cumprod(1 - beta, axis=-1)
    # Pad only the last axis, so any leading batch dimensions are kept.
    batch_pad = [(0, 0)] * (jnp.ndim(beta) - 1)
    head = jnp.pad(beta, batch_pad + [(0, 1)], constant_values=1)
    tail = jnp.pad(leftover, batch_pad + [(1, 0)], constant_values=1)
    return head * tail


def model(features,
          rates,
          cultures,
          participants,
          stimuli):
    """Rating model with a stick-breaking prior over participant cultures.

    Each participant gets a discrete ``culture_id``. The culture selects the
    image-feature coefficients used to predict that participant's ratings. It
    also selects the per-(culture, participant) competence that sets the
    rating precision.

    Args:
        features: integer array indexed as ``features[:, 0]`` (participant
            index) and ``features[:, 1]`` (stimulus index, which appears to
            be 1-based because ``1`` is subtracted — TODO confirm).
        rates: observed ratings; ``logit`` is applied to them, so presumably
            they lie in (0, 1) — confirm against the data.
        cultures: truncation level K of the stick-breaking prior, i.e. the
            number of cultures.
        participants: number of participants.
        stimuli: currently unused; kept for interface compatibility.

    Notes:
        ``visual_face_vectors`` is read from module scope (not defined here).

        With ``DiscreteHMCGibbs``, ``culture_id`` is updated by Gibbs steps.
        If a participant's ``culture_id`` takes the same value in every
        posterior draw, its split-R-hat is 0/0 = NaN. That means the
        assignment never moved, either because the chain is stuck or because
        the posterior is fully concentrated; it is not a numerical error.
    """
    # Stick-breaking prior over culture weights.
    alpha = numpyro.sample('alpha', Gamma(1, 10))
    with numpyro.plate('weights', cultures - 1):
        v = numpyro.sample('v', Beta(1, alpha))
    # ``v`` is already in (0, 1). StickBreakingTransform expects unconstrained
    # reals, so applying it here would not give the stick-breaking weights.
    # Build the weights explicitly instead.
    culture_weights = _mix_weights(v)
    with numpyro.plate('culture_loop', participants):
        culture_id = numpyro.sample("culture_id", Categorical(culture_weights))

    a = numpyro.sample("competenceMean", Normal(0, 1))
    # Used as the Normal *scale* (standard deviation) below, despite the
    # site name "competenceVariance".
    b = numpyro.sample("competenceVariance", Gamma(1, 1))
    bias_prior = numpyro.sample("biasPrior", Gamma(1, 1))
    scale_prior = numpyro.sample("scalePrior", Gamma(1, 1))
    with numpyro.plate('c_compet', participants):
        bias = numpyro.sample("bias", Normal(0, bias_prior))
        scaling = numpyro.sample("scaling", Normal(0, scale_prior))
        # Nested plate -> competence has shape (cultures, participants),
        # matching the 2-D lookup competence[culture, participant] below.
        with numpyro.plate('competence_plate', cultures):
            competence = numpyro.sample("competence", Normal(a, b))

    # Sample coefficients that project image features into latent space.
    # The coefficient width must equal the face-vector width for the dot
    # product, so derive it instead of hard-coding 512.
    latent_dim = visual_face_vectors.shape[-1]
    with numpyro.plate('latent_image_coefficients', latent_dim):
        image_coef_sigprior = numpyro.sample("image_coef_sigprior", Gamma(1, 0.1))
        # Nested plate -> shape (cultures, latent_dim): one coefficient
        # vector per culture.
        with numpyro.plate('latent_image_coefficients_plate', cultures):
            image_feature_coefficient = numpyro.sample(
                "image_feature_coefficient",
                Normal(0, 1 / image_coef_sigprior),
            )

    participant_idx = features[:, 0]
    stimulus_idx = features[:, 1] - 1
    with numpyro.plate("data_loop", features.shape[0]):
        culture_assignment = culture_id[participant_idx]
        face_vectors = visual_face_vectors[stimulus_idx]
        coefs = image_feature_coefficient[culture_assignment]
        participant_scale = jnp.exp(scaling[participant_idx])
        # Row-wise dot product of each observation's face vector with its
        # culture's coefficients (same result as vmap over jnp.dot).
        mu = participant_scale * jnp.sum(face_vectors * coefs, axis=-1) + bias[participant_idx]
        # Derive precision from the participant's competence in their culture.
        precision = jnp.exp(competence[culture_assignment, participant_idx]) * participant_scale
        # Sample the rating.
        numpyro.sample("rating", Normal(mu, 1 / precision), obs=logit(rates))
def main(start_time, *model_args, num_chains=1, modified=False, **model_kwargs):
    """Fit ``model`` with DiscreteHMCGibbs, report diagnostics and save results.

    Args:
        start_time: a ``time.time()`` value taken by the caller. It is used
            only to print the elapsed wall-clock time.
        *model_args: positional arguments forwarded to ``model``.
        num_chains: number of MCMC chains (default 1, as before). With more
            than one chain, R-hat compares chains instead of only the two
            halves of a single chain.
        modified: forwarded to ``DiscreteHMCGibbs`` (default False, as
            before). When True, each Gibbs update proposes a value different
            from the current one. This can help a discrete site that otherwise
            never changes.
        **model_kwargs: keyword arguments forwarded to ``model``.

    Side effects:
        Prints the MCMC summary and the elapsed time. Writes the diagnostics
        dict to ``dominantUpdated.pickle`` and the ``culture_id`` draws to
        ``dominantUpdated.npy``.
    """
    # NUTS updates the continuous sites; DiscreteHMCGibbs adds Gibbs updates
    # for the discrete ``culture_id`` site.
    kernel = DiscreteHMCGibbs(
        NUTS(
            model,
            init_strategy=numpyro.infer.init_to_sample,
            step_size=1e-3,
            max_tree_depth=8,
        ),
        modified=modified,
    )
    # ``jit_model_args=True`` only avoids recompilation when ``mcmc.run`` is
    # called again with new data of the same shape. A single run does not
    # need it.
    mcmc = MCMC(
        kernel,
        num_warmup=300,
        num_samples=500,
        chain_method='vectorized',
        num_chains=num_chains,
    )
    # NOTE(review): ``rng_key`` is read from module scope; it is not defined
    # in this snippet.
    mcmc.run(rng_key, *model_args, **model_kwargs)
    # Fetch the draws once and reuse them for the diagnostics and the .npy dump.
    samples = mcmc.get_samples(group_by_chain=True)
    mcmc.print_summary()
    # Why some culture_id entries show r_hat = NaN:
    # - Split-R-hat is sqrt(pooled variance / within-chain variance).
    # - When a participant's culture_id has the same value in every draw,
    #   both variances are 0, and 0/0 = NaN.
    # - So NaN flags an assignment that never moved. It is not a numerical bug.
    # - Cluster labels can also permute between chains (label switching), so
    #   R-hat on raw labels is hard to interpret even when it is finite.
    diagnos = numpyro.diagnostics.summary(samples)
    print("--- %s seconds ---" % (time.time() - start_time))
    with open('dominantUpdated.pickle', 'wb') as handle:
        pickle.dump(diagnos, handle, protocol=pickle.HIGHEST_PROTOCOL)
    np.save('dominantUpdated.npy', samples['culture_id'])