Sampling from a MixtureSameFamily distribution

Hello!

I want to sample a variable from a Gaussian mixture distribution, but I cannot find a class in the numpyro distributions module similar to PyTorch's MixtureSameFamily.

Is there a way to create a mixture-same-family distribution in numpyro?

Thank you!

Have you tried MixtureSameFamily? It is a wrapper around TensorFlow Probability's MixtureSameFamily.

I tried

import numpyro.contrib.tfp.distributions as tfd
gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(
        probs=[0.3, 0.7]),
    components_distribution=tfd.Normal(
      loc=[-1., 1],       # One for each component.
      scale=[0.1, 0.5]))  # And same here.

and got the following error

Hmm, I am not sure about the error. I know there is a feature request (FR) to add a version of Pyro's MixtureSameFamily.

In the original thread that spawned the above FR, @fehiepsi gave a snippet that generalizes a mixture of distributions from the same family; it can be found here. Maybe it helps.
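The core idea behind any same-family mixture is that its log-density is a log-sum-exp over components: log p(x) = logsumexp_k(log w_k + log p_k(x)). Below is a minimal plain-JAX sketch of that formula, not @fehiepsi's snippet; the helper name `mixture_log_prob` and its arguments are hypothetical, and the component family is passed in as any callable returning per-component log-densities along a trailing axis.

```python
import jax.numpy as jnp
from jax.nn import log_softmax
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def mixture_log_prob(component_log_prob, mixing_logits, value):
    """Log-density of a same-family mixture (hypothetical helper).

    component_log_prob: callable mapping value[..., None] to per-component
        log-densities whose trailing axis has size K.
    mixing_logits: possibly unnormalized log-weights, shape (..., K).
    """
    log_weights = log_softmax(mixing_logits, axis=-1)  # normalize the weights
    per_component = component_log_prob(value[..., None])  # shape (..., K)
    # Stable log(sum_k w_k * p_k(x)).
    return logsumexp(log_weights + per_component, axis=-1)


# Two Gaussian components with weights 0.3 / 0.7, as in the example above.
loc = jnp.array([-1.0, 1.0])
scale = jnp.array([0.1, 0.5])
logits = jnp.log(jnp.array([0.3, 0.7]))
x = jnp.array([-1.0, 0.0, 1.0])

lp = mixture_log_prob(lambda v: norm.logpdf(v, loc, scale), logits, x)

# Direct (less numerically stable) computation for comparison.
direct = jnp.log(0.3 * norm.pdf(x, -1.0, 0.1) + 0.7 * norm.pdf(x, 1.0, 0.5))
```

Swapping `norm.logpdf` for another family's log-density is all it takes to reuse the helper; working in log space avoids underflow when a component's density is tiny.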

Thank you for the answer!

I adapted the code from there for a Gaussian mixture, and it suits my needs:

import jax
import jax.numpy as jnp
from jax import random
from jax.scipy.special import logsumexp
import numpyro.distributions as dist
from numpyro.distributions import constraints

class MixtureGaussian(dist.Distribution):
    # Mixture of K Gaussians; the last axis of loc/scale/mixing_probs indexes components.
    support = constraints.real

    def __init__(self, loc, scale, mixing_probs, validate_args=None):
        expand_shape = jax.lax.broadcast_shapes(
            jnp.shape(loc), jnp.shape(scale), jnp.shape(mixing_probs))
        self._gaussian = dist.Normal(loc=loc, scale=scale).expand(expand_shape)
        self._categorical = dist.Categorical(jnp.broadcast_to(mixing_probs, expand_shape))
        super().__init__(batch_shape=expand_shape[:-1], validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        key, key_idx = random.split(key)
        # Draw from every component, then keep the one chosen by the categorical index.
        samples = self._gaussian.sample(key, sample_shape)
        ind = self._categorical.sample(key_idx, sample_shape)
        return jnp.take_along_axis(samples, ind[..., None], -1)[..., 0]

    def log_prob(self, value):
        # log p(x) = logsumexp_k(log w_k + log Normal(x | loc_k, scale_k))
        probs_mixture = self._gaussian.log_prob(value[..., None])
        sum_probs = self._categorical.logits + probs_mixture
        return logsumexp(sum_probs, axis=-1)
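The `sample` method above draws a value from every component and then selects one. An alternative is plain ancestral sampling: draw a component index first, then draw only from that component. The sketch below shows that variant in plain JAX under the assumption of 1-D `loc`/`scale`/`mixing_probs`; the function name `sample_gaussian_mixture` is hypothetical and not part of numpyro.

```python
import jax.numpy as jnp
from jax import random


def sample_gaussian_mixture(key, loc, scale, mixing_probs, n):
    """Ancestral sampling from a 1-D Gaussian mixture (hypothetical helper)."""
    key_idx, key_noise = random.split(key)
    # Component index for each draw, shape (n,).
    idx = random.categorical(key_idx, jnp.log(mixing_probs), shape=(n,))
    # Standard normal noise, shifted and scaled by the chosen component.
    eps = random.normal(key_noise, (n,))
    return loc[idx] + scale[idx] * eps


loc = jnp.array([-1.0, 1.0])
scale = jnp.array([0.1, 0.5])
probs = jnp.array([0.3, 0.7])
samples = sample_gaussian_mixture(random.PRNGKey(0), loc, scale, probs, 100_000)

# The mixture mean is sum_k w_k * mu_k = 0.3 * (-1) + 0.7 * 1 = 0.4,
# so the empirical mean should be close to 0.4.
print(samples.mean())
```

This draws one normal per sample instead of K, which matters when the number of components is large; for two components either approach is fine.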

Glad it helped! =)

The current wrapper does not seem to work for special distributions like MixtureSameFamily. For now, to make it work, you need to wrap the TFP distribution yourself:

from numpyro import sample
from numpyro.contrib.tfp.distributions import TFPDistribution
import tensorflow_probability...distributions as tfd

def model():
    x = sample("x", TFPDistribution(tfd.MixtureSameFamily(...)))

In the long term, I think it would be nice to use TFP distributions directly in the sample primitive:

from numpyro import sample
import tensorflow_probability...distributions as tfd

def model():
    x = sample("x", tfd.MixtureSameFamily(...))

I just made a FR for this. :slight_smile:

@ChernovAndrey FYI: thanks to @tcbegley, in the numpyro master branch you can now use TFP distributions directly in the sample primitive; there is no need to import them from contrib.tfp.distributions.