Dirichlet Process Gaussian Mixture Model - TypeError for Small Concentration Parameter

I’m trying to use NumPyro to infer cluster assignments in a Dirichlet Process Gaussian Mixture model. My code is pretty short and simple:

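    import jax.numpy as jnp
    import numpy as np
    import numpyro
    from jax import random


    # Hypothetical wrapper so the snippet runs on its own; the
    # original post shows only the body.
    def fit_dpgmm(observations, alpha, num_samples,
                  gaussian_mean_prior_cov_scaling, gaussian_cov_scaling,
                  sampling_max_num_clusters=None):
        num_obs, obs_dim = observations.shape

        if sampling_max_num_clusters is None:
            # multiply by 2 for safety
            sampling_max_num_clusters = 2 * int(alpha * np.log(1 + num_obs / alpha))

        def mix_weights(beta):
            # stick-breaking: K-1 Beta draws -> K mixture weights
            beta1m_cumprod = jnp.cumprod(1 - beta, axis=-1)
            term1 = jnp.pad(beta, (0, 1), mode='constant', constant_values=1.)
            term2 = jnp.pad(beta1m_cumprod, (1, 0), mode='constant', constant_values=1.)
            return jnp.multiply(term1, term2)

        def model(obs):
            with numpyro.plate('beta_plate', sampling_max_num_clusters - 1):
                beta = numpyro.sample('beta', numpyro.distributions.Beta(1, alpha))

            with numpyro.plate('mean_plate', sampling_max_num_clusters):
                mean = numpyro.sample(
                    'mean',
                    numpyro.distributions.MultivariateNormal(
                        jnp.zeros(obs_dim),
                        gaussian_mean_prior_cov_scaling * jnp.eye(obs_dim)))

            with numpyro.plate('data', num_obs):
                # cluster assignment per observation, then its likelihood
                z = numpyro.sample('z', numpyro.distributions.Categorical(mix_weights(beta)))
                numpyro.sample(
                    'obs',
                    numpyro.distributions.MultivariateNormal(
                        mean[z], gaussian_cov_scaling * jnp.eye(obs_dim)),
                    obs=obs)

        hmc_kernel = numpyro.infer.NUTS(model)
        kernel = numpyro.infer.DiscreteHMCGibbs(inner_kernel=hmc_kernel)
        mcmc = numpyro.infer.MCMC(kernel, num_warmup=100, num_samples=num_samples, progress_bar=True)
        mcmc.run(random.PRNGKey(0), obs=observations)
        # mcmc.print_summary()
        return mcmc.get_samples()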
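For reference, the `mix_weights` helper is the usual stick-breaking construction: K - 1 Beta draws become K weights, with the last weight taking whatever stick is left over. A quick standalone check (the input values are arbitrary, chosen only for illustration) shows why `beta_plate` has size `sampling_max_num_clusters - 1` while `mean_plate` has size `sampling_max_num_clusters`:

```python
import jax.numpy as jnp


def mix_weights(beta):
    # Same construction as in the post:
    # weight_k = beta_k * prod_{j<k} (1 - beta_j), and the final
    # weight is the remaining stick prod_{j} (1 - beta_j).
    beta1m_cumprod = jnp.cumprod(1 - beta, axis=-1)
    term1 = jnp.pad(beta, (0, 1), mode='constant', constant_values=1.)
    term2 = jnp.pad(beta1m_cumprod, (1, 0), mode='constant', constant_values=1.)
    return jnp.multiply(term1, term2)


beta = jnp.array([0.5, 0.5, 0.5])  # K - 1 = 3 stick-breaking fractions
weights = mix_weights(beta)        # K = 4 weights
print(weights)        # [0.5 0.25 0.125 0.125]
print(weights.sum())  # 1.0
```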

When I try running this with alpha=0.21, I have no problem. When I try setting alpha=0.11 or alpha=0.01, I get the following error:

    TypeError: iota shape must have every element be nonnegative, got (-1,).

What does this mean and how do I fix it?
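Here is the default cluster-count heuristic from the snippet evaluated on its own. The post doesn't give the dataset size, so `num_obs = 500` below is a hypothetical value; it happens to reproduce the reported pattern (0.21 works, 0.11 and 0.01 fail) because `int()` truncates anything below 1.0 to 0:

```python
import numpy as np

num_obs = 500  # hypothetical; the real dataset size isn't stated


def default_max_clusters(alpha, num_obs):
    # Heuristic from the post: alpha * log(1 + n / alpha) approximates
    # the expected number of occupied clusters under a DP prior,
    # doubled for safety. int() truncates toward zero.
    return 2 * int(alpha * np.log(1 + num_obs / alpha))


for alpha in (0.21, 0.11, 0.01):
    k = default_max_clusters(alpha, num_obs)
    # 'beta_plate' in the model has size k - 1
    print(alpha, k, k - 1)
# 0.21 2 1
# 0.11 0 -1
# 0.01 0 -1
```

The `-1` in the last column is the `(-1,)` shape in the error message.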


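Ah! I quickly found my error. My number of mixture components (`sampling_max_num_clusters`) was being set to 0, because `int(alpha * np.log(1 + num_obs / alpha))` truncates to 0 when alpha is small. That gave `beta_plate` a size of -1, which the model couldn't handle. The error `TypeError: iota shape must have every element be nonnegative, got (-1,).` was pretty unhelpful, though!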
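One way to avoid this, assuming you want to keep the same heuristic, is to put a floor under it. `min_clusters` is a hypothetical parameter, not part of the original code; a floor of 2 keeps `beta_plate` (size K - 1) non-empty:

```python
import numpy as np


def default_max_clusters(alpha, num_obs, min_clusters=2):
    # Same heuristic as the post, floored so that a small alpha
    # can no longer truncate the cluster count to 0.
    heuristic = 2 * int(alpha * np.log(1 + num_obs / alpha))
    return max(min_clusters, heuristic)


print(default_max_clusters(0.01, 500))  # 2 instead of 0
```

Using `math.ceil` instead of `int` would also work, since the heuristic is positive for any `alpha > 0` and `num_obs > 0`.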


Agreed that the error message is misleading! I guess you meant that we should raise a better error message if size = 0 here? If so, could you open a GitHub issue for this? 🙂
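Until the library reports this more clearly, a user-side check along these lines fails early with a readable message. `check_num_clusters` is a hypothetical helper, not part of NumPyro, and not necessarily what the linked issue led to:

```python
def check_num_clusters(num_clusters):
    # Fail before building the model, instead of letting a negative
    # plate size reach JAX and surface as the "iota shape" TypeError.
    if num_clusters < 1:
        raise ValueError(
            f"need at least 1 mixture component, got {num_clusters}; "
            f"the stick-breaking plate would have size {num_clusters - 1}")
    return num_clusters


check_num_clusters(3)    # fine, returns 3
# check_num_clusters(0)  # raises ValueError with the message above
```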

Issue posted here: https://github.com/pyro-ppl/numpyro/issues/953
