Hi there.

I’m trying to implement a multivariate Bernoulli distribution that models correlations between the variables, as described in this paper: https://arxiv.org/pdf/1206.1874

The model looks like this:

```
import itertools

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate


@config_enumerate
def discrete_mixture_model(
    K,
    N,
    D_discrete,
    X_discrete=None,
):
    # Mixture weights over the K clusters.
    cluster_proba = numpyro.sample('cluster_proba', dist.Dirichlet(jnp.ones(K)))
    # Each cluster gets a probability vector over all 2 ** D_discrete configurations.
    with numpyro.plate('discrete_cluster', K):
        phi = numpyro.sample('phi', dist.Dirichlet(jnp.ones(2 ** D_discrete)))
    with numpyro.plate('data', N):
        assignment = numpyro.sample('assignment', dist.CategoricalProbs(cluster_proba))
        numpyro.sample(
            'obs',
            MultivariateBernoulli(phi[assignment, :]),
            obs=X_discrete,
        )
```

And the distribution looks like this:

```
class MultivariateBernoulli(dist.Distribution):
    arg_constraints = {'phi': dist.constraints.simplex}
    support = dist.constraints.real_vector

    def __init__(self, phi):
        super().__init__(batch_shape=phi.shape[:-1], event_shape=phi.shape[-1:])
        self.phi = phi

    def sample(self, key, sample_shape=()):
        raise NotImplementedError

    def log_prob(self, value):
        # For each of the 2 ** D configurations, build an indicator that selects
        # exactly one configuration of the D binary variables, pair it with the
        # corresponding entry of phi via xlogy, and sum over configurations.
        return jax.scipy.special.xlogy(
            jnp.array(
                [jnp.array([v[j] for v, j in zip(zip(value.T, (1 - value).T), i)]).prod(axis=0)
                 for i in itertools.product(*[[0, 1] for _ in range(value.shape[-1])])]
            ).T,
            self.phi,
        ).sum(axis=-1)
```
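`sample` isn’t needed for my use case, so I left it unimplemented; for reference, here is a hypothetical sketch of how it could look (assuming configurations are indexed low-bit-first, i.e. `index = sum_d value[d] * 2 ** d`):

```
    def sample(self, key, sample_shape=()):
        D = self.event_shape[-1]
        # Draw a configuration index from the categorical over 2 ** D outcomes...
        index = dist.CategoricalProbs(self.phi).sample(key, sample_shape)
        # ...and unpack the index into D binary variables (least significant bit first).
        return (index[..., None] >> jnp.arange(D)) & 1
```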

Essentially, the distribution works by assigning a separate probability to each possible configuration of the variables, hence the `itertools.product` in the code above. In case this isn’t clear, I’ve also implemented the distribution for the specific case of two variables:

```
class MultivariateBernoulli2D(dist.Distribution):
    arg_constraints = {'phi': dist.constraints.simplex}
    support = dist.constraints.real_vector

    def __init__(self, phi):
        super().__init__(batch_shape=phi.shape[:-1], event_shape=phi.shape[-1:])
        self.phi = phi

    def sample(self, key, sample_shape=()):
        raise NotImplementedError

    def log_prob(self, value):
        inds = [slice(None)] * self.phi.ndim
        inv_value = 1 - value
        # One xlogy term per configuration: (0, 0), (1, 0), (0, 1), (1, 1).
        return (jax.scipy.special.xlogy(x=inv_value[:, 0] * inv_value[:, 1], y=self.phi[tuple(inds[:-1] + [0])]) +
                jax.scipy.special.xlogy(x=value[:, 0] * inv_value[:, 1], y=self.phi[tuple(inds[:-1] + [1])]) +
                jax.scipy.special.xlogy(x=inv_value[:, 0] * value[:, 1], y=self.phi[tuple(inds[:-1] + [2])]) +
                jax.scipy.special.xlogy(x=value[:, 0] * value[:, 1], y=self.phi[tuple(inds[:-1] + [3])]))
```
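As an aside, the per-configuration indicators can also be built without the Python-level loop. Here’s a rough vectorized sketch of the same `log_prob` (assuming the low-bit-first indexing of the 2D version above, i.e. `index = value[0] + 2 * value[1]`):

```
    def log_prob(self, value):
        D = value.shape[-1]
        # Map each binary vector to its configuration index in [0, 2 ** D).
        index = jnp.sum(value * 2 ** jnp.arange(D), axis=-1).astype(jnp.int32)
        # Look up the log-probability of that configuration directly in phi.
        return jnp.log(jnp.take_along_axis(self.phi, index[..., None], axis=-1)[..., 0])
```

Since `phi` comes from a Dirichlet its entries are strictly positive, so a plain `log` of the selected entry should be safe here, whereas the loop version needed `xlogy` to zero out the unselected terms.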

I think the `log_prob` calculation is correct, since the two implementations produce the same values for the same data, but when I try to fit the model using MCMC I don’t get anything like sensible results: the traces of `cluster_proba` jump back and forth between 0 and 1 during sampling when the clusters are about equally probable. The model works, with minor modifications (priors from a Beta distribution rather than a Dirichlet distribution, for example), when using the built-in Bernoulli distribution, which assumes independent variables, and also with a custom distribution that implements the independent Bernoulli.
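
In case it’s relevant, this is roughly how I’m running the inference (a sketch with placeholder settings rather than my exact configuration; `X` stands in for my binary data array):

```
from jax import random
from numpyro.infer import MCMC, NUTS

# Placeholder settings; X is my (N, D_discrete) binary data array.
kernel = NUTS(discrete_mixture_model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(
    random.PRNGKey(0),
    K=2,
    N=X.shape[0],
    D_discrete=X.shape[1],
    X_discrete=X,
)
mcmc.print_summary()
```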

Does anyone have any ideas about what’s going wrong?

Many thanks!