Hi fehiepsi,
Thank you so much for your help. This fixes the problem! For completeness, in case someone else is interested, here is the working code.
Custom distribution:
class MultivariateBernoulli(dist.Distribution):
    """Vector of Bernoulli variables scored jointly as a single event.

    The trailing axis of ``phi`` is used as the event dimension and all
    leading axes as batch dimensions.  The components are treated as
    independent, so the joint log-probability is the sum of the
    per-component Bernoulli log-probabilities.
    """

    support = constraints.real_vector

    def __init__(self, phi):
        """Store ``phi`` and derive batch/event shapes from ``phi.shape``."""
        shape = phi.shape
        super().__init__(batch_shape=shape[:-1], event_shape=shape[-1:])
        self.phi = phi

    def sample(self, key, sample_shape=()):
        """Sampling is not implemented; this class only provides ``log_prob``."""
        raise NotImplementedError

    def log_prob(self, value):
        """Return the log-probability of ``value``, summed over the last axis."""
        # presumably clamp_probs keeps the probabilities strictly inside (0, 1)
        # so the logarithms below stay finite -- confirm against its definition.
        probs = clamp_probs(self.phi)
        # Assuming independence of the variables: per-component Bernoulli terms,
        # value * log(p) and (1 - value) * log1p(-p), assuming xlogy / xlog1py
        # are the usual special-function helpers.
        success_term = xlogy(value, probs)
        failure_term = xlog1py(1 - value, -probs)
        return jnp.asarray(success_term + failure_term).sum(axis=-1)
Model:
@config_enumerate
def discrete_mixture_model(K, X=None):
    """Mixture model with ``K`` clusters over the rows of ``X``.

    Sample sites: ``cluster_proba`` (mixture weights), ``phi`` (per-cluster,
    per-component Bernoulli parameters), ``assignment`` (cluster index of each
    row) and ``obs`` (the observed rows).

    Args:
        K: number of mixture components.
        X: two-dimensional observation array; its shape supplies ``N`` and
            ``D``.  Despite the ``None`` default it is required, because the
            plate sizes cannot be determined without it.

    Raises:
        ValueError: if ``X`` is ``None``.
    """
    if X is None:
        # The original code failed here with an opaque AttributeError on
        # NoneType; fail early with an actionable message instead.
        raise ValueError(
            "discrete_mixture_model requires X: N and D are read from X.shape."
        )
    N, D = X.shape

    cluster_proba = numpyro.sample(
        'cluster_proba', dist.Dirichlet(0.5 * jnp.ones(K))
    )

    # One Beta(2, 2) draw per (cluster, component) pair.  phi is indexed below
    # as phi[assignment, :], i.e. it is expected to come out with clusters on
    # the first axis and components on the last -- TODO confirm plate dims.
    with numpyro.plate('components', D):
        with numpyro.plate("cluster", K):
            phi = numpyro.sample('phi', dist.Beta(2.0, 2.0))

    with numpyro.plate('data', N):
        assignment = numpyro.sample(
            'assignment', dist.CategoricalProbs(cluster_proba)
        )
        # Pick each row's cluster parameters and score the whole row as one event.
        numpyro.sample(
            'obs',
            MultivariateBernoulli(phi[assignment, :]),
            obs=X,
        )
When I check the log probs of `Bernoulli(phi).to_event(1)`,
I get the same results as for my `MultivariateBernoulli`,
so you are right about that too. In the long run, I need a multivariate Bernoulli distribution that includes correlations between the variables (as in the paper "Multivariate Bernoulli distribution", arXiv:1206.1874), so I will probably stick with the custom class for that.
Thanks again!