I want to sample a variable from a Gaussian mixture distribution, but I cannot find a class in the NumPyro distributions module similar to MixtureSameFamily in PyTorch.
Is there a way to create a mixture-same-family distribution in NumPyro?
import numpyro.contrib.tfp.distributions as tfd

gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(
        probs=[0.3, 0.7]),
    components_distribution=tfd.Normal(
        loc=[-1., 1.],       # One for each component.
        scale=[0.1, 0.5]))  # And same here.
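For reference, what a MixtureSameFamily distribution does can be sketched in plain JAX: sampling is ancestral (draw a component index from the categorical, then draw from that component's Normal), and the log-density is a logsumexp over components. This is a minimal illustration using the parameters above, not NumPyro's or TFP's implementation; the helper names `sample_mixture` and `mixture_log_prob` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

# Parameters from the example above (two components).
probs = jnp.array([0.3, 0.7])
loc = jnp.array([-1.0, 1.0])
scale = jnp.array([0.1, 0.5])


def sample_mixture(key, n):
    """Hypothetical helper: ancestral sampling of n draws from the mixture."""
    key_cat, key_norm = jax.random.split(key)
    # Pick a component index for each draw according to the mixing weights.
    idx = jax.random.categorical(key_cat, jnp.log(probs), shape=(n,))
    # Draw from the selected Normal via loc + scale * standard normal noise.
    eps = jax.random.normal(key_norm, shape=(n,))
    return loc[idx] + scale[idx] * eps


def mixture_log_prob(x):
    """Hypothetical helper: log p(x) = logsumexp_k [log pi_k + log N(x | mu_k, sigma_k)]."""
    x = jnp.asarray(x)[..., None]  # add a component axis to broadcast against loc/scale
    return logsumexp(jnp.log(probs) + norm.logpdf(x, loc, scale), axis=-1)


samples = sample_mixture(jax.random.PRNGKey(0), 10_000)
```

The mixture mean here is 0.3 * (-1) + 0.7 * 1 = 0.4, so the empirical mean of `samples` should land close to that.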
The current wrapper does not seem to work for special distributions like MixtureSameFamily. For now, to make it work, you'll need to do:
from numpyro import sample
from numpyro.contrib.tfp.distributions import TFPDistribution
import tensorflow_probability...distributions as tfd

def model():
    x = sample("x", TFPDistribution(tfd.MixtureSameFamily(...)))
In the long term, I think it would be nice to use TFP distributions directly in the sample primitive:
from numpyro import sample
import tensorflow_probability...distributions as tfd

def model():
    x = sample("x", tfd.MixtureSameFamily(...))
@ChernovAndrey FYI, thanks to @tcbegley, the NumPyro master branch now lets you use TFP distributions directly in the sample primitive; there is no need to import them from contrib.tfp.distributions.