Posterior Mixture Probabilities for Enumerated Inference

If I run a simple Gaussian Mixture Model with enumerated MCMC in NumPyro, I can retrospectively sample (or take the MAP of) the cluster assignment for each row and each MCMC sample with infer_discrete.

How do I instead get the posterior assignment probabilities for each row for each sample?

(E.g., if I have a 3-cluster mixture model and the MAP cluster assignment for row 0 is 1, I might expect the function to return [.05, .9, .05]. It should be possible to compute this for each MCMC sample/row/cluster by taking

p(row_i in cluster_j at sample_k) = \frac{density(row_i, cluster_j | latents_k, data)}{\sum_{j'=0}^{nclusters - 1} density(row_i, cluster_{j'} | latents_k, data)}

a la "Mixture Models" by Dan Foreman-Mackey.)
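For reference, here is a minimal sketch of that formula in plain JAX, outside of NumPyro, for a 1-D Gaussian mixture. The function name assignment_probs and all of the numbers are made up for illustration; the only assumption is that each MCMC draw provides mixture log-weights, locations, and scales.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def assignment_probs(x, log_weights, locs, scales):
    """Posterior cluster probabilities for one draw of the latents.

    x: (n_rows,); log_weights, locs, scales: (n_clusters,)
    returns: (n_rows, n_clusters), each row sums to 1
    """
    # numerator in log space: log pi_j + log N(x_i | mu_j, sigma_j)
    log_joint = log_weights[None, :] + norm.logpdf(
        x[:, None], locs[None, :], scales[None, :]
    )
    # denominator: log-sum-exp over the cluster axis
    return jnp.exp(log_joint - logsumexp(log_joint, axis=1, keepdims=True))


# hypothetical data and two hypothetical MCMC draws of the latents
x = jnp.array([-2.0, 0.1, 3.0])
log_w = jnp.log(jnp.array([[0.3, 0.4, 0.3], [0.2, 0.5, 0.3]]))
locs = jnp.array([[-2.0, 0.0, 3.0], [-1.8, 0.2, 2.9]])
scales = jnp.full((2, 3), 0.5)

# map over the leading (sample) axis of the latents; x is shared
probs = jax.vmap(assignment_probs, in_axes=(None, 0, 0, 0))(x, log_w, locs, scales)
# probs has shape (n_samples, n_rows, n_clusters)
```

This only covers the case where the densities can be written down by hand; the question is how to get the same quantity out of an arbitrary enumerated model.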

I think this would involve manipulating the _log_enum_density function to change the way likelihoods are accumulated in funsor, but I do not understand the funsor code well enough to come up with a solution.

Here’s a solution for the simple GMM case.

import jax

import funsor
from numpyro.contrib.funsor.enum_messenger import trace as packed_trace
from numpyro.contrib.funsor.infer_util import plate_to_enum_plate

from numpyro.distributions.util import is_identically_one
from numpyro.handlers import substitute

funsor.set_backend("jax")

def get_cluster_logits(model, model_args, model_kwargs, params):
    """
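    Compute log posterior cluster-assignment probabilities for each row.
    
    model: enumerated model
    model_args: tuple of positional arguments for the model
    model_kwargs: dict of keyword arguments for the model
    params: dict mapping latent site names to values (e.g. one MCMC sample)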
    """
    
    model = substitute(model, data=params)
    with plate_to_enum_plate():
        model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    
    log_factors = {}
    
    for site in model_trace.values():
        if site["type"] == "sample":
            value = site["value"]
            intermediates = site["intermediates"]
            scale = site["scale"]
            if intermediates:
                log_prob = site["fn"].log_prob(value, intermediates)
            else:
                log_prob = site["fn"].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            dim_to_name = site["infer"]["dim_to_name"]
            log_prob_factor = funsor.to_funsor(
                log_prob, output=funsor.Real, dim_to_name=dim_to_name
            )

            log_factors[site["name"]] = log_prob_factor

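    # unwrap each funsor Tensor into its underlying jax array
    log_factors = jax.tree_util.tree_map(lambda x: x.data, log_factors)

    # site names below are specific to this GMM: 'assignments' is the enumerated
    # cluster index and 'observed_data' is the likelihood site
    clust_priors_each_clust = log_factors['assignments']
    data_lik_each_clust = log_factors['observed_data']

    # unnormalized log joint: log p(cluster_j) + log p(row_i | cluster_j)
    num = clust_priors_each_clust + data_lik_each_clust

    # normalize over the cluster axis (assumed to be axis 0, the enumeration dim):
    # log[p(row_i, cluster_0) + p(row_i, cluster_1) + ...]
    denom = jax.scipy.special.logsumexp(num, axis=0)

    # log posterior assignment probabilities; exponentiate to get probabilities
    return num - denom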

# example usage
# import enum by name so the top-level `funsor` module used inside
# get_cluster_logits is not shadowed
from numpyro.contrib.funsor import enum
from numpyro.contrib.funsor.infer_util import config_enumerate

enum_model = enum(config_enumerate(model), first_available_dim=...)

get_cluster_logits(enum_model, (...), {...}, ...)
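
Note that the returned values are normalized log-probabilities rather than probabilities. A small sketch of the post-processing, using a made-up array in place of the real output and assuming the cluster axis comes first (as the logsumexp over axis=0 implies):

```python
import jax.numpy as jnp

# hypothetical stand-in for the output of get_cluster_logits:
# shape (n_clusters, n_rows), normalized over axis 0 in log space
log_probs = jnp.log(jnp.array([[0.05, 0.7],
                               [0.90, 0.2],
                               [0.05, 0.1]]))

# exponentiate and transpose to get one row of probabilities per data row
probs = jnp.exp(log_probs).T            # (n_rows, n_clusters)

# most probable cluster for each row
map_assignment = probs.argmax(axis=1)
```

Here row 0 gives [.05, .9, .05] with MAP cluster 1, matching the example in the question.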


Thanks for sharing. Informative.