I’m running into this issue in a model I’m working on too. The challenge when post-processing is that it’s not always easy to tell which cluster is which when the component distributions overlap. The other thing you need to be super careful about, especially with a GMM, is making sure the labels don’t switch mid-chain.
Betancourt has a blog post about this that you may have already come across; it’s pretty helpful: Identifying Bayesian Mixture Models
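One rough way to check a single chain for a mid-chain switch is to compare each observation's most frequent label in the first half of the chain with its most frequent label in the second half. This heuristic is not taken from the linked post; it is a minimal sketch that assumes one chain's assignment draws are available as an integer array of shape `(draw, obs)`, and the function name `halves_disagreement` is hypothetical.

```python
import numpy as np


def halves_disagreement(assignment_chain):
    """Fraction of observations whose most frequent label differs between the
    first and second half of a single chain.

    assignment_chain: int array of shape (draw, obs) holding cluster labels.
    """
    n_labels = assignment_chain.max() + 1
    half = assignment_chain.shape[0] // 2

    def modal(draws):
        # counts[k, j] = number of draws in which observation j has label k
        counts = np.stack([(draws == k).sum(axis=0) for k in range(n_labels)])
        return counts.argmax(axis=0)

    first, second = modal(assignment_chain[:half]), modal(assignment_chain[half:])
    return float(np.mean(first != second))


# Toy check: in `switched`, labels 0 and 1 swap halfway through the chain.
true = np.array([0, 0, 1, 1, 1])
stable = np.tile(true, (10, 1))
switched = np.vstack([np.tile(true, (5, 1)), np.tile(1 - true, (5, 1))])
```

On a real model this could be called per chain, e.g. `halves_disagreement(az_obj.posterior.assignment[0].values)`. Values near 0 suggest stable labels; with overlapping clusters expect some nonzero noise, and a switch that happens and reverses within one half will not be caught.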
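I have some code to do the post-processing. Start by making an ArviZ `InferenceData` object from your model (called `az_obj` here). This code assumes there is a variable called `assignment` in your model which holds each observation's cluster label. You also need to supply a dict called `adjustment_vars` whose keys are the names of the variables you want to reorder and whose values are the axis index of the cluster dimension in each one (e.g. `2` for a variable with dims `(chain, draw, mu_dim_0)`). The code also relies on ArviZ's default dimension names, so the cluster dimension of a variable `mu` must be called `mu_dim_0`.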
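To make the expected input layout concrete, here is a hypothetical toy posterior built directly with xarray; in practice `az_obj.posterior` would come from your sampler's output. The variable `mu`, the sizes and the random values are all illustrative assumptions. ArviZ names unlabelled dimensions `<var>_dim_0` by default, which is the naming the snippet below relies on.

```python
import numpy as np
import xarray as xr

# Hypothetical toy posterior: 2 chains, 100 draws, 3 clusters, 50 observations.
n_chains, n_draws, n_clusters, n_obs = 2, 100, 3, 50
rng = np.random.default_rng(0)

posterior = xr.Dataset(
    {
        # Cluster means: dims (chain, draw, mu_dim_0)
        "mu": (("chain", "draw", "mu_dim_0"),
               rng.normal(size=(n_chains, n_draws, n_clusters))),
        # Cluster label of each observation: dims (chain, draw, assignment_dim_0)
        "assignment": (("chain", "draw", "assignment_dim_0"),
                       rng.integers(0, n_clusters, size=(n_chains, n_draws, n_obs))),
    },
    coords={
        "chain": np.arange(n_chains),
        "draw": np.arange(n_draws),
        "mu_dim_0": np.arange(n_clusters),
        "assignment_dim_0": np.arange(n_obs),
    },
)

# Keys: variables to reorder; values: axis index of their cluster dimension.
adjustment_vars = {"mu": posterior["mu"].dims.index("mu_dim_0")}
```

Here `adjustment_vars` works out to `{"mu": 2}`; any other per-cluster variable (e.g. a scale or weight) would get its own entry.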
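```python
import numpy as np
import xarray as xr

# Assumes `az_obj` (ArviZ InferenceData) and `adjustment_vars` already exist.
cutoff = 0.75  # minimum within-chain frequency for a label to count as confident

# Step 1: per chain and observation, record the modal cluster label if it is
# assigned in more than `cutoff` of the draws (NaN otherwise).
n_chains, _, n_obs = az_obj.posterior.assignment.shape
chain_res = np.full((n_chains, n_obs), np.nan)
for chain in range(n_chains):
    for ind in range(n_obs):
        dat = np.asarray(az_obj.posterior.assignment[chain, :, ind])
        vals, counts = np.unique(dat, return_counts=True)
        freqs = counts / len(dat)
        if freqs.max() > cutoff:
            chain_res[chain, ind] = vals[np.argmax(freqs)]

# Step 2: each unique row (label in chain 0, label in chain 1, ...) among confidently
# assigned observations says which labels correspond across chains.
# Assumes every cluster is confidently identified, so mapping[:, 0] is 0, 1, ..., K-1.
u_res = np.unique(chain_res.T, axis=0)
mapping = u_res[np.isnan(u_res).sum(axis=1) == 0, :].astype(int)

# Step 3: reorder each variable's cluster dimension so that, in every chain,
# position j holds the cluster in row j of `mapping`; clusters missing from
# `mapping` keep their relative order at the end.
for k in adjustment_vars:
    n_clusters = az_obj.posterior[k].shape[adjustment_vars[k]]
    for chain in range(n_chains):
        inds = np.concatenate((mapping[:, chain], np.arange(n_clusters)))
        for i in mapping[:, chain]:
            # drop the later duplicate so each cluster index appears once
            inds = np.delete(inds, np.argwhere(inds == i)[-1])
        dim_name = f"{k}_dim_0"
        az_obj.posterior[k][[chain], :, :] = (az_obj.posterior[k][[chain], :, :]
                                              .sel({dim_name: inds})
                                              .assign_coords({dim_name: np.arange(len(inds))}))

# Step 4: relabel `assignment` in chains 1..n so its labels match chain 0.
for chain in range(1, mapping.shape[1]):
    assignment_mapping = {}
    for label in range(mapping.shape[0]):
        assignment_mapping[mapping[label, 0]] = mapping[label, chain]
    # .copy() matters: without it this is a view that changes as we write below
    immut_res = az_obj.posterior.assignment[chain, :, :].copy()
    for k in assignment_mapping:
        az_obj.posterior.assignment[chain, :, :] = xr.where(
            immut_res == assignment_mapping[k],
            k,
            az_obj.posterior.assignment[chain, :, :]
        )
```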
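For comparison, here is a rough NumPy-only sketch of the same idea that works on plain arrays instead of mutating `az_obj` in place. It is a simplified variant, not the code above: it takes chain 0 as the reference, matches labels through a contingency table of confidently assigned observations, and raises an error when no one-to-one mapping is found. The function names `modal_labels` and `align_chains` are hypothetical.

```python
import numpy as np


def modal_labels(assignment, cutoff=0.75):
    """Per-chain modal label of each observation, or -1 if no label exceeds `cutoff`.

    assignment: int array of shape (chain, draw, obs). Returns shape (chain, obs).
    """
    n_labels = assignment.max() + 1
    # freqs[c, k, j] = fraction of draws in chain c where observation j has label k
    freqs = np.stack([(assignment == k).mean(axis=1) for k in range(n_labels)], axis=1)
    return np.where(freqs.max(axis=1) > cutoff, freqs.argmax(axis=1), -1)


def align_chains(assignment, params, cutoff=0.75):
    """Permute clusters in chains 1..n so their labels match chain 0.

    assignment: int array (chain, draw, obs); params: array (chain, draw, cluster).
    Returns relabelled copies of both arrays.
    """
    n_chains, _, n_clusters = params.shape
    modal = modal_labels(assignment, cutoff)
    new_assignment = assignment.copy()
    new_params = params.copy()
    for c in range(1, n_chains):
        ok = (modal[0] >= 0) & (modal[c] >= 0)  # confident in both chains
        # table[k, m] = confident observations labelled k in chain 0 and m in chain c
        table = np.zeros((n_clusters, n_clusters), dtype=int)
        np.add.at(table, (modal[0][ok], modal[c][ok]), 1)
        perm = table.argmax(axis=1)  # chain-0 label k <-> chain-c label perm[k]
        if len(np.unique(perm)) != n_clusters:
            raise ValueError(f"chain {c}: no one-to-one label mapping found")
        new_params[c] = params[c][:, perm]
        # argsort of a permutation is its inverse: chain-c label m -> chain-0 label
        new_assignment[c] = np.argsort(perm)[assignment[c]]
    return new_assignment, new_params


# Toy check: chain 1 is chain 0 with its cluster labels permuted.
true = np.repeat([0, 1, 2], 10)
means = np.array([-5.0, 0.0, 5.0])
perm1 = np.array([2, 0, 1])  # true cluster k is called perm1[k] in chain 1
assignment = np.stack([np.tile(true, (20, 1)), np.tile(perm1[true], (20, 1))])
params = np.stack([np.tile(means, (20, 1)), np.tile(means[np.argsort(perm1)], (20, 1))])
new_a, new_p = align_chains(assignment, params)
```

With a real model this might be called as `align_chains(az_obj.posterior.assignment.values, az_obj.posterior["mu"].values)`, where `mu` stands in for whichever per-cluster variable you have. Raising on a non-unique match is a deliberate choice: with heavily overlapping clusters it is safer to fail loudly than to silently merge two components.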
EDIT: I didn’t notice this was a resurrected old thread when I posted! Hopefully it’s still useful.