Multivariate Gaussian Mixture runs but does not work

Hi everyone,

I am new to Pyro and I wanted to implement a multivariate Gaussian mixture model.
The model runs without any errors, but when applied to simple, well-separated data, it fails to infer the true parameters.

I used the following model in NumPyro:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(K, dim, data=None):
    cluster_proba = numpyro.sample('cluster_proba', dist.Dirichlet(0.5 * jnp.ones(K)))
    with numpyro.plate('components', K):
        sigma = numpyro.sample('sigma', dist.HalfCauchy(scale=10))
        # locs (the K x dim component means) is defined elsewhere; its sample statement is not shown in the post
    with numpyro.plate('data', len(data)):
        assignment = numpyro.sample('assignment', dist.Categorical(cluster_proba), infer={"enumerate": "parallel"})
        numpyro.sample('obs', dist.MultivariateNormal(locs[assignment, :][1], sigma[assignment][1] * jnp.eye(dim)), obs=data)
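For reference, this is the density the model above is meant to express. With infer={"enumerate": "parallel"}, the discrete assignment is summed out rather than sampled, so each observation x_n should get a mixture likelihood. Here pi stands for cluster_proba, mu_k for locs[k], sigma_k for sigma[k] and d for dim; we assume sigma acts as a per-component variance, since the model passes sigma[...] * jnp.eye(dim) as the covariance matrix:

```latex
p(x_n \mid \pi, \mu, \sigma) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}\!\left(x_n \mid \mu_k,\ \sigma_k I_d\right)
```

The mixing weights pi can only be learned if the k-th term really uses the k-th component's mean and variance.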

I used NUTS to do the inference:

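import jax
from numpyro.infer import MCMC, NUTS

rng_key = jax.random.PRNGKey(0)

num_warmup, num_samples = 1000, 5000

kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(rng_key, data=data, K=3, dim=3)
posterior_samples = mcmc.get_samples()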

Even though the model runs, it does not seem to converge to the correct values when applied to well-separated (simulated) data. In particular, the weights are not updated and remain at 1/3 (with three components).
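One way to end up with weights stuck at exactly 1/3: suppose the component density inside the mixture sum does not actually depend on k, for example because the indexing always selects the same component j. Then the mixing weights factor out and sum to one:

```latex
\sum_{k=1}^{K} \pi_k \, \mathcal{N}\!\left(x_n \mid \mu_j,\ \sigma_j I_d\right)
  = \mathcal{N}\!\left(x_n \mid \mu_j,\ \sigma_j I_d\right) \sum_{k=1}^{K} \pi_k
  = \mathcal{N}\!\left(x_n \mid \mu_j,\ \sigma_j I_d\right)
```

The data then carry no information about pi, so its posterior is just the Dirichlet(0.5, ..., 0.5) prior, whose mean is 0.5 / (0.5 K) = 1/K, i.e. 1/3 for K = 3.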

I simulated 3D data coming from 3 components using the code below:

import numpy as np

n = 2500  # Total number of samples
k = 3     # Number of clusters
dim = 3   # Number of dimensions
p_real = np.array([0.1, 0.5, 0.4])  # Probability of choosing each cluster

mu0 = [10, 10, 10]
mu1 = [-5, -3, -3]
mu2 = [1, 1, 1]
mus = np.array([mu0, mu1, mu2])       # (k, dim) cluster means
sigmas = np.array([1.0, 1.0, 1.0])    # ASSUMED per-cluster variances; the post does not give them

clusters = np.random.choice(k, size=1, p=p_real)
data = np.random.multivariate_normal(mus[clusters[0]], sigmas[clusters[0]] * np.eye(dim), (1))
for i in range(1, n):
    clusters = np.random.choice(k, size=1, p=p_real)
    data_point = np.random.multivariate_normal(mus[clusters[0]], sigmas[clusters[0]] * np.eye(dim), (1))
    data = np.concatenate((data, data_point), axis=0)
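The loop above draws one point at a time and grows the array with np.concatenate, which copies the data on every iteration. A vectorized sketch of the same generative process draws all cluster labels at once and then adds scaled Gaussian noise to the selected means. It reuses the means and weights from the post; the unit variances are an assumption, since the post does not give them.

```python
import numpy as np

rng = np.random.default_rng(0)

n, k, dim = 2500, 3, 3
p_real = np.array([0.1, 0.5, 0.4])          # mixing weights from the post
mus = np.array([[10, 10, 10],
                [-5, -3, -3],
                [1, 1, 1]], dtype=float)     # (k, dim) means from the post
sigmas = np.ones(k)                          # ASSUMED per-cluster variances

# Draw every cluster label at once: shape (n,)
clusters = rng.choice(k, size=n, p=p_real)

# Covariance sigma * I means a standard deviation of sqrt(sigma) per coordinate.
noise = rng.standard_normal((n, dim))
data = mus[clusters] + np.sqrt(sigmas[clusters])[:, None] * noise  # (n, dim)
```

Keeping clusters around as an (n,) array also gives the true labels, which is handy for checking what the model recovers.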


This gives the results below, which are inconsistent with the data provided:

If anyone can tell me what I am doing wrong, that would be much appreciated.


Hello, your indexing is probably wrong. Instead of blithely sticking tensors into distributions, I suggest using assert/print statements to understand the consequences of the indexing operations you're using. Something like:

my_locs = locs[assignment,:][1]
print("my_locs.shape", my_locs.shape)
expected_shape = (K, len(data), dim)  # or whatever it is
assert my_locs.shape == expected_shape
numpyro.sample('obs', dist.MultivariateNormal(my_locs, ...))
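Following that advice, the sketch below runs the model's indexing on stand-in arrays outside of NumPyro. It assumes that, under parallel enumeration with a single plate, assignment arrives with shape (K, 1), i.e. all K component indices on a new dimension to the left of the data plate; locs and sigma are made-up values chosen only to make the shapes visible.

```python
import jax.numpy as jnp

K, dim = 3, 3
locs = jnp.arange(K * dim, dtype=jnp.float32).reshape(K, dim)  # stand-in (K, dim) means
sigma = jnp.array([1.0, 2.0, 3.0])                             # stand-in per-component values

# ASSUMED shape of `assignment` under parallel enumeration: (K, 1)
assignment = jnp.arange(K).reshape(K, 1)

picked = locs[assignment, :]
print(picked.shape)      # (3, 1, 3)
print(picked[1].shape)   # (1, 3)

# The trailing [1] selects the row for enumerated value 1, i.e. the mean of
# component 1, no matter which component is being scored.
print(jnp.allclose(picked[1], locs[1]))
print(sigma[assignment][1])
```

If that assumption about the enumerated shape holds, every term of the mixture is scored with component 1's mean and variance, so the likelihood no longer depends on the assignment.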

Why do we have ...[1] here? I suspect that locs[assignment,:] won't work for batches of locs and assignment. You might want to use the Vindex helper, as in the annotation example.

Thanks a lot for the help, that fixed it !