Mixture model with discrete data in NumPyro

Hi fehiepsi,

Thank you so much for your help. This fixes the problem! For completeness, in case someone else is interested, here is the working code.

Custom distribution:

class MultivariateBernoulli(dist.Distribution):
    """Product of independent Bernoulli variables over the last axis of ``phi``.

    Each slice ``phi[..., :]`` holds the success probabilities of D binary
    variables, which are treated as independent. The joint log-probability is
    the sum of the per-variable Bernoulli log-probabilities, i.e. the same as
    ``dist.Bernoulli(phi).to_event(1)``. Correlations between variables are
    not modelled.

    Args:
        phi: Probabilities in [0, 1]. The last axis is the event dimension
            (one entry per binary variable); any leading axes are batch axes.
        validate_args: Forwarded to ``dist.Distribution``.
    """

    # Observations are 0/1 vectors; the last axis forms one event.
    support = constraints.independent(constraints.boolean, 1)
    # Declaring the parameter lets NumPyro validate it and treat it as data.
    arg_constraints = {"phi": constraints.independent(constraints.unit_interval, 1)}

    def __init__(self, phi, validate_args=None):
        # Set the parameter BEFORE the base __init__, which reads it when
        # validate_args is enabled.
        self.phi = phi
        super().__init__(
            batch_shape=jnp.shape(phi)[:-1],
            event_shape=jnp.shape(phi)[-1:],
            validate_args=validate_args,
        )

    def sample(self, key, sample_shape=()):
        """Draw independent 0/1 samples with shape sample_shape + batch + event."""
        from jax import random

        shape = tuple(sample_shape) + self.batch_shape + self.event_shape
        draws = random.bernoulli(key, self.phi, shape=shape)
        # Return integers rather than booleans, so the samples can be used
        # directly in arithmetic.
        return draws.astype(jnp.result_type(draws, int))

    def log_prob(self, value):
        """Return log p(value), summed over the event (last) axis."""
        # Clamp away from exactly 0/1 so log(p) and log(1 - p) stay finite.
        probs = clamp_probs(self.phi)
        # x*log(p) + (1-x)*log(1-p). xlogy/xlog1py evaluate 0*log(.) as 0.
        per_variable = xlogy(value, probs) + xlog1py(1 - value, -probs)
        # The variables are independent, so sum over the event axis.
        return jnp.sum(per_variable, axis=-1)

Model:

@config_enumerate
def discrete_mixture_model(
    K, X=None, *, cluster_concentration=0.5, phi_a=2.0, phi_b=2.0
):
    """Mixture of K independent-multivariate-Bernoulli clusters.

    Cluster assignments are discrete latent sites, marked for parallel
    enumeration by ``@config_enumerate``.

    Args:
        K: Number of mixture components (clusters).
        X: Observed binary data of shape (N, D). Despite the ``None``
            default, it is required, because N and D are read from it.
        cluster_concentration: Symmetric Dirichlet concentration for the
            mixing weights (default 0.5).
        phi_a, phi_b: Beta prior parameters for every per-cluster,
            per-variable success probability (default Beta(2, 2)).

    Raises:
        ValueError: If ``X`` is None.
    """
    if X is None:
        raise ValueError(
            "discrete_mixture_model requires observed data X to determine N and D"
        )
    N, D = X.shape

    cluster_proba = numpyro.sample(
        "cluster_proba", dist.Dirichlet(cluster_concentration * jnp.ones(K))
    )

    # The outer plate 'components' takes dim -1 and the inner plate 'cluster'
    # takes dim -2, so phi has shape (K, D).
    with numpyro.plate("components", D):
        with numpyro.plate("cluster", K):
            phi = numpyro.sample("phi", dist.Beta(phi_a, phi_b))

    with numpyro.plate("data", N):
        assignment = numpyro.sample(
            "assignment", dist.CategoricalProbs(cluster_proba)
        )
        # Select each point's cluster row of phi. Under enumeration, the extra
        # leading dims of `assignment` broadcast through this indexing.
        numpyro.sample(
            "obs",
            MultivariateBernoulli(phi[assignment, :]),
            obs=X,
        )

When I compare the log probabilities of `Bernoulli(phi).to_event(1)` with those of my `MultivariateBernoulli`, I get the same results, so you are right about that too. In the long run, I need a multivariate Bernoulli distribution that includes correlations between the variables (as in the paper "Multivariate Bernoulli distribution", arXiv:1206.1874), so I will probably stick with the custom class for that.

Thanks again! :slight_smile:

4 Likes