Discrete Gibbs sampler returns many NaN rhat


I am trying to assign people to “cultures” (clusters) via a stick-breaking process. In this process, the rhat for the discrete variable “culture_id” is NaN for many participants. Why? How can I fix it?

def model(features):
    """NumPyro model assigning participants to cultures and modelling ratings.

    Each participant gets a discrete ``culture_id``; each observed rating is
    modelled as Normal on the logit scale, with a mean built from image
    features and a precision built from the participant's competence.

    Relies on names that are not defined in this function and are presumably
    module-level: ``cultures`` and ``participants`` (plate sizes),
    ``visual_face_vectors`` (image feature matrix) and ``rates`` (observed
    ratings).

    Args:
        features: 2-D integer array with one row per observed rating.
            Column 0 is used directly as a participant index; column 1,
            minus one, is used as a row index into ``visual_face_vectors``.
    """
    # Stick-breaking prior
    alpha = numpyro.sample('alpha', Gamma(1, 10))
    with numpyro.plate('weights', cultures - 1):
        v = numpyro.sample('v', Beta(1, alpha))

    # NOTE(review): StickBreakingTransform is documented as a map from
    # unconstrained reals to the simplex, whereas `v` already lies in (0, 1).
    # The result is still a valid probability vector, but confirm that this
    # is the intended construction of the culture weights.
    sb_trans = StickBreakingTransform()
    with numpyro.plate('culture_loop', participants):
        # Fixed: the closing parenthesis of numpyro.sample(...) was missing.
        culture_id = numpyro.sample("culture_id", Categorical(sb_trans(v)))

    a = numpyro.sample("competenceMean", Normal(0, 1))
    # Despite the site name, `b` is passed to Normal as its second argument.
    b = numpyro.sample("competenceVariance", Gamma(1, 1))
    bias_prior = numpyro.sample("biasPrior", Gamma(1, 1))
    scale_prior = numpyro.sample("scalePrior", Gamma(1, 1))
    with numpyro.plate('c_compet', participants):
        bias = numpyro.sample("bias", Normal(0, bias_prior))
        scaling = numpyro.sample("scaling", Normal(0, scale_prior))
        with numpyro.plate('competence_plate', cultures):
            # Indexed below as competence[culture, participant].
            competence = numpyro.sample("competence", Normal(a, b))

    # Sample coefficients to project image features into latent space.
    with numpyro.plate('latent_image_coefficients', 512):
        image_coef_sigprior = numpyro.sample("image_coef_sigprior", Gamma(1, 0.1))
        with numpyro.plate('latent_image_coefficients_plate', cultures):
            image_feature_coefficient = numpyro.sample(
                "image_feature_coefficient",
                Normal(0, 1 / image_coef_sigprior))

    # Look up each rating's culture through its participant index. The plate
    # wraps no sample site; it is kept as in the original.
    with numpyro.plate("data_loop", features.shape[0]):
        culture_assignment = culture_id[features[:, 0]]

    # Per-rating mean: exp(scale) * <image vector, culture coefficients> + bias.
    # Fixed: the vmap call was never closed and passed three arrays to a
    # four-parameter lambda. The missing `vis_coefs` argument is taken to be
    # the coefficient row of each rating's culture, since
    # `image_feature_coefficient` was sampled but otherwise unused.
    # NOTE(review): confirm this is the intended argument.
    mu = vmap(lambda vis_vec, vis_coefs, scale, b:
              jnp.exp(scale) * jnp.dot(vis_vec, vis_coefs) + b)(
        visual_face_vectors[features[:, 1] - 1],
        image_feature_coefficient[culture_assignment],
        scaling[features[:, 0]],
        bias[features[:, 0]],
    )

    # Derive precision from participant competence
    precision = jnp.exp(competence[culture_assignment, features[:, 0]]) * jnp.exp(scaling[features[:, 0]])

    # Sample the rating.
    numpyro.sample("rating", Normal(mu, 1 / precision), obs=logit(rates))

def main(start_time, *model_args, **model_kwargs):
    """Run DiscreteHMCGibbs (wrapping NUTS) on ``model`` and save the results.

    Writes the diagnostics summary to ``dominantUpdated.pickle`` and the
    ``culture_id`` samples to ``dominantUpdated.npy`` in the working directory.

    Args:
        start_time: Reference timestamp comparable with ``time.time()``; used
            only to print the elapsed time.
        *model_args: Positional arguments forwarded to ``mcmc.run``.
        **model_kwargs: Keyword arguments forwarded to ``mcmc.run``.
    """
    nuts = NUTS(model, init_strategy=numpyro.infer.init_to_sample, step_size=1e-3, max_tree_depth=8)
    kernel = DiscreteHMCGibbs(nuts)
    mcmc = MCMC(kernel, num_warmup=300, num_samples=500, chain_method='vectorized', num_chains=1)  # why jit args?
    # NOTE(review): `rng_key` is not defined in this function; presumably a
    # module-level PRNG key. Confirm it exists before calling.
    mcmc.run(rng_key, *model_args, **model_kwargs)
    samples = mcmc.get_samples(group_by_chain=True)

    # Reuse the samples fetched above instead of calling get_samples again.
    # NOTE(review): a site whose draws never change has zero variance, which
    # presumably makes its r_hat evaluate to NaN. Check whether `culture_id`
    # is static for the affected participants.
    diagnos = numpyro.diagnostics.summary(samples)
    print("--- %s seconds ---" % (time.time() - start_time))
    with open('dominantUpdated.pickle', 'wb') as handle:
        pickle.dump(diagnos, handle, protocol=pickle.HIGHEST_PROTOCOL)

    np.save('dominantUpdated.npy', samples['culture_id'])

I’m not sure if Rhat is meaningful for discrete variables. You can ignore it I guess.

The assignments seem to be static, though: the assignments in the first sample are the same as those in the last one. Is that normal as well?

Sampling discrete variables is a hard problem, so I guess what you observed is normal. There is probably a better sampler for your problem, but I’m not sure. So far I have only had luck with some simple models.