# Custom distribution for mixture model

Hi there.

I’m trying to implement a multivariate Bernoulli distribution that includes correlations between the variables (described in this paper: https://arxiv.org/pdf/1206.1874)

The model looks like this:

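``````
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate


@config_enumerate
def discrete_mixture_model(
    K,
    N,
    D_discrete,
    X_discrete=None,
):
    # Mixture weights over the K clusters
    cluster_proba = numpyro.sample('cluster_proba', dist.Dirichlet(1.0 * jnp.ones(K)))

    # One probability per joint configuration of the D_discrete binary variables
    with numpyro.plate("discrete_cluster", K):
        phi = numpyro.sample('phi', dist.Dirichlet(1.0 * jnp.ones(2 ** D_discrete)))

    with numpyro.plate('data', N):
        # Discrete cluster assignment, enumerated out during inference
        assignment = numpyro.sample('assignment', dist.CategoricalProbs(cluster_proba))

        numpyro.sample(
            'obs',
            MultivariateBernoulli(phi[assignment, :]),
            obs=X_discrete,
        )
``````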

And the distribution looks like this:

``````
import itertools

import jax


class MultivariateBernoulli(dist.Distribution):
    arg_constraints = {'phi': dist.constraints.simplex}
    support = dist.constraints.real_vector

    def __init__(self, phi):
        super(MultivariateBernoulli, self).__init__(batch_shape=phi.shape[:-1], event_shape=phi.shape[-1:])
        self.phi = phi

    def sample(self, key, sample_shape=()):
        raise NotImplementedError

    def log_prob(self, value):
        # For each configuration from itertools.product, build the indicator that
        # `value` matches it, then weight log(phi) by that indicator via xlogy.
        return jax.scipy.special.xlogy(
            jnp.array(
                [jnp.array([v[j] for v, j in zip(zip(value.T, (1-value).T), i)]).prod(axis=0)
                 for i in itertools.product(*[[0,1] for _ in range(value.shape[-1])])]
            ).T,
            self.phi
        ).sum(axis=-1)
``````

Essentially the way the distribution works is by having a separate probability for each possible configuration of the variables, hence the presence of the `itertools.product` in the above. In case this isn’t clear, I’ve implemented the distribution for the specific case of two variables

``````
class MultivariateBernoulli2D(dist.Distribution):
    arg_constraints = {'phi': dist.constraints.simplex}
    support = dist.constraints.real_vector

    def __init__(self, phi):
        super(MultivariateBernoulli2D, self).__init__(batch_shape=phi.shape[:-1], event_shape=phi.shape[-1:])
        self.phi = phi

    def sample(self, key, sample_shape=()):
        raise NotImplementedError

    def log_prob(self, value):

        inds = [slice(None)] * self.phi.ndim

        inv_value = 1 - value
        return (jax.scipy.special.xlogy(x=inv_value[:, 0] * inv_value[:, 1], y=self.phi[tuple(inds[0:-1] + [0])]) +
                jax.scipy.special.xlogy(x=value[:, 0] * inv_value[:, 1], y=self.phi[tuple(inds[0:-1] + [1])]) +
                jax.scipy.special.xlogy(x=inv_value[:, 0] * value[:, 1], y=self.phi[tuple(inds[0:-1] + [2])]) +
                jax.scipy.special.xlogy(x=value[:, 0] * value[:, 1], y=self.phi[tuple(inds[0:-1] + [3])]))
``````
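
For reference, the one-probability-per-configuration idea can be sketched in plain NumPy for any number of variables. This is only an illustration, not the code from the post: the helper name `config_log_prob` and the ordering convention (configurations in `itertools.product` order, so that `phi[k]` belongs to the k-th tuple) are assumptions made here.

```python
import itertools

import numpy as np


def config_log_prob(phi, value):
    """Log-probability of binary vectors under a table of configuration probabilities.

    phi:   shape (2 ** D,), one probability per configuration (assumed ordering:
           itertools.product([0, 1], repeat=D))
    value: shape (N, D), entries in {0, 1}
    """
    D = value.shape[-1]
    configs = np.array(list(itertools.product([0, 1], repeat=D)))  # (2 ** D, D)
    # match[n, k] is True when observation n equals configuration k
    match = (value[:, None, :] == configs[None, :, :]).all(axis=-1)
    # Exactly one configuration matches each observation, so this picks log(phi[k])
    return np.where(match, np.log(phi), 0.0).sum(axis=-1)


D = 3
phi = np.arange(1, 2 ** D + 1, dtype=float)
phi /= phi.sum()
x = np.array([[1, 0, 1], [0, 0, 0]])
lp = config_log_prob(phi, x)
print(lp)
```

Because each observation matches exactly one configuration, summing `exp(log_prob)` over all `2 ** D` configurations should give 1, which is a quick way to confirm that the table is being indexed consistently.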

I think I am doing the `log_prob` calculation correctly, as the two methods produce the same values for the same data, but when I try to fit the model using MCMC I don’t get anything like sensible results. The traces of `cluster_proba` jump back and forth between 0 and 1 during sampling when the clusters are about equally probable. This works with minor modifications (priors from a Beta distribution rather than a Dirichlet distribution, for example) when using the built-in Bernoulli distribution that assumes independent variables, and also with a custom distribution that implements the independent Bernoulli distribution.

Does anyone have any ideas about what’s going wrong?

Many thanks!

check if the gradients of log_prob look reasonable?

How would I do that? Would something like `jax.grad` be the thing to use for this?

yes
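
A gradient check along these lines could look like the following sketch. It does not use the distribution class from the post; `log_lik` is a hypothetical stand-in that parameterises `phi` through unconstrained logits and assumes the index convention `x0 + 2 * x1` for the four states of the 2D case.

```python
import jax
import jax.numpy as jnp


def log_lik(logits, value):
    """Summed log-probability of 2D binary data under a 4-state probability table."""
    log_phi = jax.nn.log_softmax(logits)   # log of phi, which lies on the simplex
    idx = value[:, 0] + 2 * value[:, 1]    # assumed state index: x0 + 2 * x1
    return log_phi[idx].sum()


x = jnp.array([[0, 0], [1, 0], [1, 1], [1, 1]])
logits = jnp.zeros(4)

grad = jax.grad(log_lik)(logits, x)
# With a softmax parameterisation the gradient is (state counts) - N * phi,
# so it should be finite and its entries should sum to zero.
print(grad, jnp.isfinite(grad).all())
```

NaN or infinite entries in such a gradient would point at the density itself, whereas finite, sensible values suggest looking elsewhere (shapes or the geometry of the posterior).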

Thanks for your help so far. I think the gradients are okay. I’ve done a bit more work and set up a google colab to demonstrate the problem here: Google Colab

I can run the model with one cluster, and it reproduces the correct parameters from the simulated data, but when I run the model with more than one cluster it puts all samples in one cluster and incorrectly estimates the coefficients. I assume the issue is that I’ve defined the distribution incorrectly. I would be very grateful if you could take a look and give me some pointers.

Thanks

Hi @martinjankowiak @fehiepsi . Sorry to be a pain, but I’d really appreciate it if you could have a look at this. I’m sure the solution will be something relatively simple but I don’t know where to look in the documentation to get any guidance on this.

Thanks

sorry but i have no idea what kind of distribution you’re aiming for and don’t have time to read up on it and code like

``````
[jnp.array([v[j] for v, j in zip(zip(value.T, (1-value).T), i)]).prod(axis=0)
 for i in itertools.product(*[[0,1] for _ in range(value.shape[-1])])]
``````

isn’t exactly transparent.

• does your log_prob return the correct shapes? e.g. compare to the analogous log_prob for vanilla bernoulli
• is `prod` going to be numerically stable? why not use log prod = sum log?
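
Both points can be illustrated with a small NumPy sketch, under assumptions that are not from the thread: the 2D case, the state index `x0 + 2 * x1`, and hypothetical helper names. When `phi` factorises into independent probabilities, the table-based log-probability should agree with a sum of per-variable Bernoulli log terms, both in value and in shape.

```python
import numpy as np


def independent_log_prob(p, value):
    """Independent Bernoulli log-probability, written as a sum of logs."""
    return (value * np.log(p) + (1 - value) * np.log1p(-p)).sum(axis=-1)


def table_log_prob(phi, value):
    """Look up log(phi) by state index instead of multiplying indicators."""
    idx = value[..., 0] + 2 * value[..., 1]  # assumed ordering of the 4 states
    return np.log(phi)[idx]


p = np.array([0.3, 0.8])
# Probability table of two independent variables, in the assumed ordering
phi = np.array([(1 - p[0]) * (1 - p[1]), p[0] * (1 - p[1]),
                (1 - p[0]) * p[1], p[0] * p[1]])

x = np.array([[0, 0], [1, 0], [0, 1], [1, 1], [1, 1]])
lp_indep = independent_log_prob(p, x)
lp_table = table_log_prob(phi, x)
print(lp_indep.shape, lp_table.shape)
print(np.allclose(lp_indep, lp_table))
```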

``````
class MultivariateBernoulli2D(dist.Distribution):
    # Note, this implementation is specifically
    # for two correlated Bernoulli variables
    support = dist.constraints.real_vector

    def __init__(self, phi_zero, phi_coef, phi_cross):
        super(MultivariateBernoulli2D, self).__init__(
            batch_shape=phi_coef.shape[:-1],
            event_shape=phi_coef.shape[-1:]
        )

        self.phi_zero = phi_zero
        self.phi_coef = phi_coef
        self.phi_cross = phi_cross

    def sample(self, key, sample_shape=()):
        raise NotImplementedError

    def log_prob(self, value):

        zeroth_term = self.phi_zero
        first_term = (value * self.phi_coef).sum(axis=-1)
        second_term = jnp.prod(value, axis=-1, keepdims=False) * self.phi_cross
        return zeroth_term + first_term + second_term
``````

`value` above can either be zero or one, so `prod(value)` will be one when both values are one and zero otherwise. I don’t think there are any problems with numerical instability but let me know if I’ve got that wrong.

Here is the model.

``````
@config_enumerate
def discrete_mixture_model(
    K,
    N,
    D_discrete,
    X_discrete=None,
):

    cluster_proba = numpyro.sample('cluster_proba', dist.Dirichlet(1.0 * jnp.ones(K)))

    # There are some places in the code to account for differences in tensor
    # shapes when the number of clusters is 1. Flagged here and in comments
    # below
    if K > 1:
        with numpyro.plate("discrete_cluster", K):
            phi = numpyro.sample('phi', dist.Dirichlet(1.0 * jnp.ones(2 ** D_discrete)))
    else:
        phi = numpyro.sample('phi', dist.Dirichlet(1.0 * jnp.ones(2 ** D_discrete)))

    # This is specific to the case of two correlated Bernoulli variables.
    phi_zero = jnp.log(phi[..., 0])
    # Note: jnp.array([...]) stacks along axis 0, so for K > 1 phi_coef has
    # shape (2, K) here, with the cluster axis last.
    phi_coef = jnp.array([jnp.log(phi[..., 1] / phi[..., 0]), jnp.log(phi[..., 2] / phi[..., 0])])
    phi_cross = jnp.log(phi[..., 0] * phi[..., 3] / (phi[..., 1] * phi[..., 2]))

    # K = 1
    if len(phi_zero.shape) == 0:
        phi_zero = jnp.array([phi_zero])
        phi_coef = phi_coef.reshape(1, -1)
        phi_cross = jnp.array([phi_cross])

    with numpyro.plate('data', N):

        assignment = numpyro.sample('assignment', dist.CategoricalProbs(cluster_proba))

        # K = 1
        if K == 1:
            target = MultivariateBernoulli2D(phi_zero, phi_coef, phi_cross)
        else:
            target = MultivariateBernoulli2D(phi_zero[assignment], phi_coef[assignment, :], phi_cross[assignment])

        numpyro.sample(
            'obs',
            target,
            obs=X_discrete,
        )
``````
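
As a sanity check on the log-linear parameterisation above, the following NumPy sketch confirms for a single `phi` that `phi_zero + value @ phi_coef + prod(value) * phi_cross` reproduces `log(phi)` on all four states. It assumes the state ordering implied by the model code: index 0 is (0, 0), 1 is (1, 0), 2 is (0, 1) and 3 is (1, 1).

```python
import numpy as np

rng = np.random.default_rng(0)
phi = rng.dirichlet(np.ones(4))

# Same transformations as in the model, for a single (unbatched) phi
phi_zero = np.log(phi[0])
phi_coef = np.array([np.log(phi[1] / phi[0]), np.log(phi[2] / phi[0])])
phi_cross = np.log(phi[0] * phi[3] / (phi[1] * phi[2]))

states = np.array([[0, 0], [1, 0], [0, 1], [1, 1]])  # assumed ordering of phi
log_prob = phi_zero + states @ phi_coef + states.prod(axis=-1) * phi_cross

print(np.allclose(log_prob, np.log(phi)))
print(np.exp(log_prob).sum())
```

This only exercises the unbatched case; it says nothing about whether the shapes line up once `phi` carries a cluster dimension.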

Here is the trace plot for the two cluster case

As you can see, all samples get erroneously dumped into one cluster.

For the 2D multivariate case, you can sum the probabilities over all possible values to check whether the implementation is correct. If the posterior is multimodal, then you will need other MCMC algorithms.

Thanks. I’ve been working on this and I’m pretty sure the problem is in the `log_prob` method, specifically with the shape of each of the terms. Is there any documentation anywhere with details on how to create custom distributions for use in numpyro models? I guess the enumeration over the discrete latent variable is adding some complexity, but I’m really struggling to understand how things are working differently when the distribution is called from the model.

Thanks!

You can check how the shapes work at Tensor shapes in Pyro — Pyro Tutorials 1.9.0 documentation

``````
d: batch_shape + event_shape
value: sample_batch_shape + event_shape
``````
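
To make these shape rules concrete, here is a small sketch that computes the shape `log_prob` is expected to return, namely the broadcast of the value's batch shape against the distribution's `batch_shape`. The helper name is hypothetical, and the `(K, 1)` batch shape for an enumerated `assignment` inside a plate of size `N` is an assumption about how the model in this thread is enumerated.

```python
import numpy as np


def expected_log_prob_shape(value_shape, batch_shape, event_shape):
    """Shape that log_prob(value) should have under the batch/event convention."""
    n_event = len(event_shape)
    split = len(value_shape) - n_event
    # The trailing dims of value must be the event dims
    assert tuple(value_shape[split:]) == tuple(event_shape)
    sample_batch_shape = tuple(value_shape[:split])
    return np.broadcast_shapes(sample_batch_shape, tuple(batch_shape))


K, N = 3, 100
# Plain case: one distribution, N observations of a 2-vector
plain = expected_log_prob_shape((N, 2), (), (2,))
# Enumerated case (assumed): phi[assignment] has batch shape (K, 1)
enumerated = expected_log_prob_shape((N, 2), (K, 1), (2,))
print(plain, enumerated)
```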

If you think this is the issue of log_prob, you can check:

• whether sum of probs over all possible values is 1: e.g. you can try `exp(d.log_prob(array([1, 0]))) + exp(d.log_prob(array([1, 1]))) + ...`
• whether the shapes satisfy the batch semantics above

The second one can be checked using vmap, i.e. comparing:

``````
d(batch_phi).log_prob(value)
``````

with

``````
jax.vmap(lambda phi: d(phi).log_prob(value))(batch_phi)
``````

If this is a batching issue, check out the “Writing parallelizable code” section of Tensor shapes in Pyro — Pyro Tutorials 1.9.0 documentation

Hi @fehiepsi . Thanks again for your help.

I’ve modified the distribution code so all the checks you suggested work (I think) i.e.

``````
phi_m = np.random.dirichlet(np.ones(4))
x = jnp.array([[0,0], [1,0], [0,1], [1,1]]) # this is all possible states
d_m = MultivariateBernoulli2D(phi_m)
print("sum check", jnp.exp(d_m.log_prob(x)).sum(axis=-1))
``````

prints 1.0, and all the shape parameters of the distribution match the independent variables Bernoulli distribution `dist.Bernoulli(phi).to_event(1)`.

Also, the jax vmap check that you suggested returns the same shape: `(50, 4)`

``````
phi_m = np.random.dirichlet(np.ones(4), size=50)
print(MultivariateBernoulli2D(phi_m).log_prob(x).shape)
print(jax.vmap(lambda phi: MultivariateBernoulli2D(phi).log_prob(x))(phi_m).shape)
``````

Unfortunately this still doesn’t fix the problem. Here is the distribution and model code:

``````
class MultivariateBernoulli2D(dist.Distribution):
    support = dist.constraints.real_vector

    def __init__(self, phi):
        super(MultivariateBernoulli2D, self).__init__(
            batch_shape=phi.shape[:-1],
            event_shape=(2,)
        )

        self.phi = phi

    def sample(self, key, sample_shape=()):
        raise NotImplementedError

    def log_prob(self, value):

        val_m = 1 - value

        first_term = jnp.prod(value, axis=-1, keepdims=False) * jnp.log(self.phi[..., [0]])
        second_term = jnp.prod(jnp.concatenate([value[:, [0]], val_m[:, [1]]], axis=-1), axis=-1, keepdims=False) * jnp.log(self.phi[..., [1]])
        third_term = jnp.prod(jnp.concatenate([val_m[:, [0]], value[:, [1]]], axis=-1), axis=-1, keepdims=False) * jnp.log(self.phi[..., [2]])
        last_term = jnp.prod(val_m, axis=-1, keepdims=False) * jnp.log(self.phi[..., [3]])

        out = first_term + second_term + third_term + last_term

        return out.squeeze()


@config_enumerate
def discrete_mixture_model(
    K,
    N,
    D_discrete,
    X_discrete=None,
):

    cluster_proba = numpyro.sample('cluster_proba', dist.Dirichlet(1.0 * jnp.ones(K)))

    if K > 1:
        with numpyro.plate("discrete_cluster", K):
            phi = numpyro.sample('phi', dist.Dirichlet(2.0 * jnp.ones(2 ** D_discrete)))
    else:
        phi = numpyro.sample('phi', dist.Dirichlet(2.0 * jnp.ones(2 ** D_discrete)))

    with numpyro.plate('data', N):

        assignment = numpyro.sample('assignment', dist.CategoricalProbs(cluster_proba))

        if K == 1:
            target = MultivariateBernoulli2D(phi)
        else:
            target = MultivariateBernoulli2D(phi[assignment])

        out = numpyro.sample(
            'obs',
            target,
            obs=X_discrete,
        )
``````

Trace plot when fitting to data with 2 clusters:

Could you compare vmap results, rather than just shapes? Then gradually add more dimensions. This will confirm that your log_prob is correct (in the sense of representing the log density of a distribution). Inference is another story (you might need to deal with label switching issues - i’m not sure).
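
A value-level comparison could be sketched as follows. The function `log_prob` here is a hypothetical stand-alone stand-in for the distribution's method (2D case, state index `x0 + 2 * x1` assumed), so it is the pattern rather than the exact code that carries over.

```python
import jax
import jax.numpy as jnp


def log_prob(phi, value):
    """Log-probability of a 2D binary value under a 4-state table phi."""
    idx = value[..., 0] + 2 * value[..., 1]   # assumed state ordering
    onehot = jax.nn.one_hot(idx, 4)           # value batch shape + (4,)
    return jnp.sum(onehot * jnp.log(phi), axis=-1)


key = jax.random.PRNGKey(0)
batch_phi = jax.random.dirichlet(key, jnp.ones(4), shape=(5,))  # shape (5, 4)

# Step 1: a single observation against a batch of phi
value = jnp.array([1, 0])
direct = log_prob(batch_phi, value)
mapped = jax.vmap(lambda phi: log_prob(phi, value))(batch_phi)
print(jnp.allclose(direct, mapped))

# Step 2: add a dimension, pairing observation i with batch_phi[i]
values = jnp.array([[0, 0], [1, 0], [0, 1], [1, 1], [1, 1]])
direct2 = log_prob(batch_phi, values)
mapped2 = jax.vmap(log_prob)(batch_phi, values)
print(jnp.allclose(direct2, mapped2))
```

Comparing with `allclose` rather than `.shape` catches the case where the batched call returns an array of the right shape whose entries are paired with the wrong parameters.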

The vmap method produces the same results as calling the log_prob directly, so I think that is working correctly. Do you know why there would be label switching problems with this but not with the built in Bernoulli distribution `dist.Bernoulli(phi).to_event(1)`?

Given `cluster_proba` and `assignment`, I feel like `new_cluster_proba = 1 - cluster_proba` and `new_phi = 1 - phi (or its permutation)` will give you the same density. (I’m not sure what `phi[..., 3]` is in your code, or what you mean by “such issue does not happen with the built in `dist.Bernoulli(phi).to_event(1)`”.)
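
The symmetry behind label switching can be checked numerically. The sketch below (plain NumPy, hypothetical helper, state index `x0 + 2 * x1` assumed) evaluates the marginal mixture log-likelihood with the cluster assignment summed out, then swaps the two cluster labels. The two values coincide, which is one reason a sampler may see equivalent modes.

```python
import numpy as np


def mixture_log_lik(cluster_proba, phi, value):
    """Marginal log-likelihood with the discrete assignment summed out.

    cluster_proba: (K,), phi: (K, 4), value: (N, 2) with entries in {0, 1}
    """
    idx = value[:, 0] + 2 * value[:, 1]  # (N,)
    # log p(z = k) + log p(x_n | z = k), shape (N, K)
    log_joint = np.log(cluster_proba) + np.log(phi)[:, idx].T
    return np.logaddexp.reduce(log_joint, axis=-1).sum()


rng = np.random.default_rng(0)
cluster_proba = np.array([0.3, 0.7])
phi = rng.dirichlet(np.ones(4), size=2)
x = rng.integers(0, 2, size=(20, 2))

original = mixture_log_lik(cluster_proba, phi, x)
# Relabel the clusters: swap both the weights and the rows of phi
swapped = mixture_log_lik(cluster_proba[::-1], phi[::-1], x)
print(np.isclose(original, swapped))
```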

If I modify the model to use the distribution we discussed before in this thread: Mixture model with discrete data in Numpyro - #8 by jim (assuming independent Bernoullis) then I don’t get the label switching behaviour and all the model parameters are estimated pretty well. Switching to this custom distribution does take a bit of modification to the model: the parameters in the independent case are simple probabilities, for which I’ve used Beta priors, whereas here the parameters are constrained to a simplex, so I need to use something like a Dirichlet prior instead.

Even without the label switching here, the model looks like it would still be dumping all observations in the same cluster, which is not what I want.

Thanks.

Do you think the posterior is easy to sample from (e.g. single modal)? If not, you might need Example: Neural Transport — NumPyro documentation or something else.

I’m pretty sure the posterior should be well behaved, ie single modal.