I’m trying to use NumPyro to infer cluster assignments in a Dirichlet Process Gaussian Mixture model. My code is pretty short and simple:
num_obs, obs_dim = observations.shape
if sampling_max_num_clusters is None:
# multiply by 2 for safety
sampling_max_num_clusters = 2 * int(alpha * np.log(1 + num_obs / alpha))
def mix_weights(beta):
beta1m_cumprod = jnp.cumprod(1 - beta, axis=-1)
term1 = jnp.pad(beta, (0, 1), mode='constant', constant_values=1.)
term2 = jnp.pad(beta1m_cumprod, (1, 0), mode='constant', constant_values=1.)
return jnp.multiply(term1, term2)
def model(obs):
with numpyro.plate('beta_plate', sampling_max_num_clusters - 1):
beta = numpyro.sample(
'beta',
numpyro.distributions.Beta(1, alpha))
with numpyro.plate('mean_plate', sampling_max_num_clusters):
mean = numpyro.sample(
'mean',
numpyro.distributions.MultivariateNormal(
jnp.zeros(obs_dim),
gaussian_mean_prior_cov_scaling * jnp.eye(obs_dim)))
with numpyro.plate('data', num_obs):
z = numpyro.sample(
'z',
numpyro.distributions.Categorical(mix_weights(beta=beta)).mask(False))
numpyro.sample(
'obs',
numpyro.distributions.MultivariateNormal(
mean[z],
gaussian_cov_scaling * jnp.eye(obs_dim)),
obs=obs)
hmc_kernel = numpyro.infer.NUTS(model)
kernel = numpyro.infer.DiscreteHMCGibbs(inner_kernel=hmc_kernel)
mcmc = numpyro.infer.MCMC(kernel, num_warmup=100, num_samples=num_samples, progress_bar=True)
mcmc.run(random.PRNGKey(0), obs=observations)
# mcmc.print_summary()
samples = mcmc.get_samples()
When I try running this with alpha=0.21
, I have no problem. When I try setting alpha=0.11
or alpha=0.01
, I get the following error:
TypeError: iota shape must have every element be nonnegative, got (-1,).
What does this mean and how do I fix it?