Label Switching - GMM

Hey, I implemented a Gaussian mixture model (GMM) on the Iris dataset (with 3 different features / clusters).
I am not sure how to deal with the “label switching” issue in my model.
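Label switching produces several posterior modes because, with exchangeable priors on the components, the posterior does not change when the component labels are permuted: the mixture likelihood is a sum over components. That gives each mode up to K! relabelled copies. Below is a small NumPy check of the likelihood part. The data and parameter values are made up, and `mixture_log_lik` is a hypothetical helper.

```python
import numpy as np

def mixture_log_lik(x, weights, means, sds):
    """Log-likelihood of 1-D data x under a Gaussian mixture (hypothetical helper)."""
    # Component log-densities, shape (n_points, n_components)
    comp = (
        -0.5 * ((x[:, None] - means) / sds) ** 2
        - np.log(sds)
        - 0.5 * np.log(2 * np.pi)
    )
    # log sum_k w_k N(x | mu_k, sd_k), computed stably for each point
    a = np.log(weights) + comp
    m = a.max(axis=1, keepdims=True)
    return float(np.sum(m[:, 0] + np.log(np.exp(a - m).sum(axis=1))))

rng = np.random.default_rng(0)
x = rng.normal(size=50)                 # made-up data
weights = np.array([0.2, 0.5, 0.3])     # made-up parameters
means = np.array([-1.0, 0.0, 2.0])
sds = np.array([0.5, 1.0, 0.8])

perm = np.array([2, 0, 1])              # relabel the components
ll_original = mixture_log_lik(x, weights, means, sds)
ll_permuted = mixture_log_lik(x, weights[perm], means[perm], sds[perm])
print(np.isclose(ll_original, ll_permuted))  # True
```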
I have seen several techniques for dealing with this issue. One is to constrain the prior, i.e. to formulate it so that the posterior has only one mode (e.g. by ordering the means of the mixture components), but I understand that this technique isn’t generally used.
Another strategy is to ignore the problem during sampling and then post-process the output, relabelling the components so that the labels stay consistent. I am not sure how to implement that.
I couldn’t find any solution in the NumPyro documentation or anywhere else.
http://num.pyro.ai/en/latest/examples/annotation.html

I would really be glad to get some help with this issue.
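The post-processing strategy can be done in several ways. One simple option relabels each draw by sorting the component means and then applies the same permutation to the other per-component parameters. This is only a sketch. `relabel_by_mean` and the toy arrays are hypothetical, and it assumes the components are well separated in their means. When they overlap, sorting can give a misleading picture of the posterior.

```python
import numpy as np

def relabel_by_mean(means, *others):
    """Reorder components in every draw so that the means are increasing.

    means:  array of shape (n_draws, K)
    others: arrays of shape (n_draws, K), e.g. weights or standard deviations,
            permuted with the same per-draw ordering as `means`
    """
    order = np.argsort(means, axis=1)  # per-draw permutation, shape (n_draws, K)
    sorted_means = np.take_along_axis(means, order, axis=1)
    sorted_others = [np.take_along_axis(a, order, axis=1) for a in others]
    return sorted_means, sorted_others

# Two fake draws of a K=2 mixture; in the second draw the labels are switched
means = np.array([[-2.0, 3.0], [3.1, -1.9]])
weights = np.array([[0.4, 0.6], [0.61, 0.39]])
sorted_means, (sorted_weights,) = relabel_by_mean(means, weights)
# Each row now lists the component with the smaller mean first
print(sorted_means)
print(sorted_weights)
```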

Hi @nai, you can find some suggestions in this Stan reference or in the reference in that annotation example. From what I know, label switching is still a tricky issue.


Hi @fehiepsi, sorry, I’m new to NumPyro. I’m still not sure what to do if I’d like to add constraints on mu. I’m trying to model a mixture of two Gaussian components N(mu_k, sigma_k) (k = 0, 1). Can I just sort them like this?

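import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(K=2):
    # priors
    mu = numpyro.sample("mu", dist.Normal(0, 1))
    nu = numpyro.sample("nu", dist.InverseGamma(1, 1))

    with numpyro.plate("status", K):
        mu_k_raw = numpyro.sample("mu_k_raw", dist.Normal(mu, jnp.sqrt(nu)))

    # sort the component means so that mu_k[0] <= mu_k[1]
    mu_k = numpyro.deterministic("mu_k", jnp.sort(mu_k_raw))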

I’m a little worried that this will affect convergence. What should I do if I just want to add a constraint during sampling? That is, if I get a sample with mu_k[0] > mu_k[1], I’d resample.

I’m having this issue in a model I’m working on too. The challenge with post-processing is that it’s not always easy to tell which cluster is which when the component distributions overlap. The other thing you need to be very careful about, especially with a GMM, is making sure the labels don’t switch mid-chain.
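One rough way to check for switching within a chain is to track the ordering of the component means draw by draw and count how often it changes. This assumes the means are a reasonable proxy for component identity. `count_label_switches` is a hypothetical helper. With overlapping components the ordering can also flip without a genuine switch, so a nonzero count is a prompt to look at trace plots rather than proof of a switch.

```python
import numpy as np

def count_label_switches(means):
    """Count draws at which the ordering of component means changes.

    means: array of shape (n_chains, n_draws, K) of per-component means.
    Returns one count per chain; a nonzero count suggests that the labels
    switched within that chain (or that the components overlap).
    """
    order = np.argsort(means, axis=-1)                    # (n_chains, n_draws, K)
    changed = np.any(order[:, 1:] != order[:, :-1], axis=-1)
    return changed.sum(axis=1)

# Toy example: chain 0 keeps its labels, chain 1 swaps them halfway through
means = np.empty((2, 6, 2))
means[0] = [-2.0, 2.0]
means[1, :3] = [-2.0, 2.0]
means[1, 3:] = [2.0, -2.0]
print(count_label_switches(means))  # [0 1]
```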

Betancourt has a blog post about this that you may have already come across, which is pretty helpful: Identifying Bayesian Mixture Models

I have some code to do the post-processing. Start by converting your fitted model into an ArviZ InferenceData object (called az_obj here), e.g. with az.from_numpyro(mcmc). The code assumes your model has a variable called assignment that holds each data point’s cluster label. You also need to supply a dict called adjustment_vars whose keys are the names of the variables you want to shuffle and whose values are the index of the dimension to shuffle them along (counting the chain and draw dimensions, so typically 2).

import numpy as np

# A data point is used to match labels across chains only if, within a chain,
# it is assigned to a single label in more than this fraction of the draws.
cutoff = 0.75

assignment = az_obj.posterior["assignment"].values  # (chain, draw, point)
n_chains, _, n_points = assignment.shape

# Modal label of every data point in every chain (NaN if below the cutoff)
chain_res = np.full((n_chains, n_points), np.nan)
for chain in range(n_chains):
    for ind in range(n_points):
        labels, counts = np.unique(assignment[chain, :, ind], return_counts=True)
        freqs = counts / counts.sum()
        if freqs.max() > cutoff:
            chain_res[chain, ind] = labels[np.argmax(freqs)]

# Each row of `mapping` is one cluster: its label in chain 0, chain 1, ...
# (assumes the confident points give a consistent one-to-one label matching)
u_res = np.unique(chain_res.T, axis=0)
mapping = u_res[~np.isnan(u_res).any(axis=1)].astype(int)

# Permute each variable along its component axis so that position j holds,
# in every chain, the component that chain 0 calls mapping[j, 0]
for var, axis in adjustment_vars.items():
    values = az_obj.posterior[var].values.copy()
    for chain in range(n_chains):
        rest = [i for i in range(values.shape[axis]) if i not in mapping[:, chain]]
        inds = np.concatenate((mapping[:, chain], rest)).astype(int)
        # values[chain] drops the chain dimension, hence axis - 1
        values[chain] = np.take(values[chain], inds, axis=axis - 1)
    az_obj.posterior[var] = az_obj.posterior[var].copy(data=values)

# Assignment: relabel chains 1.. to use chain 0's labels. Compare against the
# untouched original array so that swapped labels are not remapped twice.
new_assignment = assignment.copy()
for chain in range(1, n_chains):
    for label_0, label_chain in zip(mapping[:, 0], mapping[:, chain]):
        new_assignment[chain][assignment[chain] == label_chain] = label_0
az_obj.posterior["assignment"] = az_obj.posterior["assignment"].copy(data=new_assignment)
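If the xarray bookkeeping gets awkward, the same idea can be sketched on plain NumPy arrays: match each chain's labels to chain 0 using per-point assignment frequencies. Unlike the code above, this variant swaps the cutoff for a brute-force search. It scores every one of the K! label permutations by how strongly the chains agree, so it is only practical for small K. `align_chain_labels` and the toy data are hypothetical.

```python
import itertools
import numpy as np

def align_chain_labels(assignment, K):
    """For each chain, find the label permutation that best matches chain 0.

    assignment: integer array of shape (n_chains, n_draws, n_points) with
                labels in 0..K-1.
    Returns perms of shape (n_chains, K), where perms[c][j] is the label in
    chain c that corresponds to label j in chain 0.
    """
    n_chains = assignment.shape[0]
    # Per-chain point-by-label frequencies, shape (n_chains, n_points, K)
    freq = np.stack([(assignment == k).mean(axis=1) for k in range(K)], axis=-1)
    perms = np.zeros((n_chains, K), dtype=int)
    perms[0] = np.arange(K)
    for c in range(1, n_chains):
        # Brute force over all K! permutations (fine for small K)
        best = max(
            itertools.permutations(range(K)),
            key=lambda p: np.sum(freq[0] * freq[c][:, list(p)]),
        )
        perms[c] = best
    return perms

# Toy example: chain 1 uses the opposite labels to chain 0
true_labels = np.repeat([0, 1], 15)               # 30 points
assignment = np.stack([
    np.tile(true_labels, (20, 1)),                # chain 0
    np.tile(1 - true_labels, (20, 1)),            # chain 1, labels swapped
])                                                # shape (2, 20, 30)
perms = align_chain_labels(assignment, K=2)

# Relabel chain 1 so that its labels agree with chain 0
inv = np.argsort(perms[1])
assignment[1] = inv[assignment[1]]
```

Per-component parameters can be aligned with the same permutations. For example, `params[c][..., perms[c]]` puts chain c's component perms[c][j] at position j.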

EDIT - I didn’t spot this was a resurrected old thread when I posted this! Hopefully it’s still useful