# Posterior Mixture Probabilities for Enumerated Inference

If I run a simple Gaussian Mixture Model with enumerated MCMC in NumPyro, I can retrospectively sample (or choose the MAP of) the cluster assignment for each row, for each MCMC sample, with `infer_discrete`.

How do I instead get the posterior assignment probabilities for each row for each sample?

(e.g. if I have a 3-cluster mixture model and the MAP cluster assignment for row 0 is 1, I might expect the function to return `[.05, .9, .05]`. It should be possible to compute this for each MCMC sample/row/cluster as

$$p(\text{row}_i \in \text{cluster}_j \mid \text{latents}_k, \text{data}) = \frac{\text{density}(\text{row}_i, \text{cluster}_j \mid \text{latents}_k, \text{data})}{\sum_{j'=0}^{n_\text{clusters}-1}\text{density}(\text{row}_i, \text{cluster}_{j'} \mid \text{latents}_k, \text{data})}$$

à la Mixture Models | Dan Foreman-Mackey.)
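In plain JAX, the normalization in that formula is just a log-softmax over the cluster axis. A minimal sketch with made-up per-cluster log densities for one row at one MCMC sample (the numbers are purely illustrative):

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# hypothetical per-cluster log densities for one row at one MCMC sample
log_density = jnp.array([-4.0, -1.0, -4.0])  # clusters 0, 1, 2

# log p(row in cluster_j) = log density_j - logsumexp over the cluster axis
log_probs = log_density - logsumexp(log_density)
probs = jnp.exp(log_probs)  # sums to 1; most mass on cluster 1
```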

I think this would involve manipulating the `_log_enum_density` function to change how likelihoods are accumulated in funsor, but I don't understand the funsor code well enough to come up with a solution.

Here’s a solution for the simple GMM case.

```python
import jax

import funsor
from numpyro.contrib.funsor.enum_messenger import trace as packed_trace
from numpyro.contrib.funsor.infer_util import plate_to_enum_plate
from numpyro.distributions.util import is_identically_one
from numpyro.handlers import substitute

funsor.set_backend("jax")


def get_cluster_logits(model, model_args, model_kwargs, params):
    """
    Get the cluster assignment logits for each row.

    model: enumerated model
    model_args: tuple
    model_kwargs: dict
    params: dict of latent values (e.g. one MCMC sample)
    """
    # condition the model on the given latents and trace it with
    # enumeration-aware plates
    model = substitute(model, data=params)
    with plate_to_enum_plate():
        model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)

    log_factors = {}
    for site in model_trace.values():
        if site["type"] == "sample":
            value = site["value"]
            intermediates = site["intermediates"]
            scale = site["scale"]
            if intermediates:
                log_prob = site["fn"].log_prob(value, intermediates)
            else:
                log_prob = site["fn"].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            dim_to_name = site["infer"]["dim_to_name"]
            log_prob_factor = funsor.to_funsor(
                log_prob, output=funsor.Real, dim_to_name=dim_to_name
            )
            log_factors[site["name"]] = log_prob_factor

    # extract the raw arrays from the funsor objects
    log_factors = jax.tree_util.tree_map(lambda x: x.data, log_factors)

    # combine the cluster priors with the per-cluster data likelihood
    clust_priors_each_clust = log_factors["assignments"]
    data_lik_each_clust = log_factors["observed_data"]
    num = clust_priors_each_clust + data_lik_each_clust

    # normalize over the (leading) cluster axis:
    # p(row_i = cluster_0) + p(row_i = cluster_1) + ...
    denom = jax.scipy.special.logsumexp(num, axis=0)

    return num - denom
```

```python
# example usage
from numpyro.contrib.funsor import enum
from numpyro.contrib.funsor.infer_util import config_enumerate

enum_model = enum(config_enumerate(model), first_available_dim=...)

get_cluster_logits(enum_model, (...), {...}, ...)
```



Thanks for sharing, very informative.