MixtureGeneral in pyro?

Hi!
I see there’s a `MixtureSameFamily` class in both Pyro and NumPyro, whereas NumPyro also has `MixtureGeneral`. Is this difference due to some underlying limitation of Pyro vis-à-vis NumPyro? How should I go about achieving an equivalent of `MixtureGeneral` in Pyro? Below is a toy NumPyro model that I’d like to reproduce in Pyro as well. It is a mixture of an Exponential and a LogNormal, i.e. components from different families, so `MixtureSameFamily` does not apply. Thanks!

import numpy as np
from scipy.stats import expon, lognorm, bernoulli
from matplotlib import pyplot as plt

from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import numpyro.handlers as handl
from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO, autoguide

numpyro.util.set_host_device_count(4)

# Two separate random streams: a JAX key for numpyro, a NumPy Generator for scipy.
rng_key = random.PRNGKey(0)
rng = np.random.default_rng(0)

# DATA GENERATION

## ground truth
w = 0.3  # probability that an observation comes from the exponential component
expon_lambda = 0.5  # NOTE: passed to scipy as `scale` (the mean, i.e. 1/rate), not as a rate
lognorm_mu = 1.5  # mean of log(X) for the lognormal component
lognorm_sigma = 0.2  # standard deviation of log(X) for the lognormal component

# Frozen component distributions, reused for both the pdf and the draws.
expon_dist = expon(scale=expon_lambda)
lognorm_dist = lognorm(scale=np.exp(lognorm_mu), s=lognorm_sigma)

## pdf
x = np.linspace(0, 10, 1_000)
y = w * expon_dist.pdf(x) + (1 - w) * lognorm_dist.pdf(x)

## rvs
N = 10_000
# The three draws share `rng`, so their order fixes the generated dataset.
_w = bernoulli(p=w).rvs(size=N, random_state=rng)
expon_rvs = expon_dist.rvs(size=N, random_state=rng)
lognorm_rvs = lognorm_dist.rvs(size=N, random_state=rng)
# Keep the exponential draw where the indicator is 1, the lognormal draw otherwise.
rvs = np.where(_w == 1, expon_rvs, lognorm_rvs)
# MODEL

def model(n_obs, obs=None):
    """Two-component mixture model: an Exponential and a LogNormal component.

    Args:
        n_obs: Number of observations; sets the size of ``obs_plate``.
        obs: Optional observed values (one per plate element). When ``None``
            (the default) the ``"obs"`` site is sampled, as in the prior and
            posterior predictive calls. Conditioning with
            ``handlers.condition(model, {"obs": ...})`` keeps working as before.

    Returns:
        The value of the ``"obs"`` sample site.
    """
    # Used as the exponential's *scale* (mean), hence ``rate=1/scale`` below.
    latent_expon_lambda = numpyro.sample(
        "latent_expon_lambda",
        dist.TruncatedNormal(loc=0, scale=1.0, low=0, high=3.0),
    )
    latent_lognorm_mu = numpyro.sample("latent_lognorm_mu", dist.Normal(loc=1, scale=1.0))
    latent_lognorm_sigma = numpyro.sample("latent_lognorm_sigma", dist.HalfNormal(scale=1.0))
    # Mixture weights: index 0 -> Exponential, index 1 -> LogNormal.
    latent_w = numpyro.sample("latent_w", dist.Dirichlet(concentration=jnp.array([1.0, 1.0])))
    component_dists = [
        dist.Exponential(rate=1 / latent_expon_lambda),
        dist.LogNormal(loc=latent_lognorm_mu, scale=latent_lognorm_sigma),
    ]
    # Components come from different families, so MixtureSameFamily cannot be used.
    mixture = dist.MixtureGeneral(
        mixing_distribution=dist.Categorical(probs=latent_w),
        component_distributions=component_dists,
    )
    with numpyro.plate("obs_plate", n_obs):
        return numpyro.sample("obs", mixture, obs=obs)
# PRIOR PREDICTIVE

# Draw 2,000 prior-predictive datasets, each holding a single observation.
rng_key, rng_key_ = random.split(rng_key)
prior_predictive = Predictive(model, num_samples=2_000)
prior_pred = prior_predictive(rng_key_, n_obs=1)

# Flatten the draws and drop the top 5% so the heavy right tail of the prior
# does not stretch the histogram's x-axis.
_prior_obs = prior_pred["obs"].reshape(-1)
prior_cutoff = np.percentile(_prior_obs, 95)
prior_obs = _prior_obs[_prior_obs < prior_cutoff]

# MCMC

# The "obs" site is pinned to the generated data by conditioning the model.
conditioned_model = handl.condition(model, {"obs": rvs})
mcmc = MCMC(
    sampler=NUTS(conditioned_model),
    num_warmup=1_000,
    num_samples=1_000,
    num_chains=4,
)
rng_key, rng_key_ = random.split(rng_key)
mcmc.run(rng_key_, n_obs=len(rvs))
mcmc.print_summary()
mcmc_samples = mcmc.get_samples()

# POSTERIOR PREDICTIVE

# One fresh observation per posterior draw.
rng_key, rng_key_ = random.split(rng_key)
posterior_predictive = Predictive(model, posterior_samples=mcmc_samples)
post_pred = posterior_predictive(rng_key_, n_obs=1)

# PLOT GROUND TRUTH, OBSERVATIONS, PRIOR- AND POSTERIOR PREDICTIVES

fig, ax = plt.subplots()
ax.plot(x, y, label="ground truth pdf")
# Histograms are drawn in this fixed order so each keeps its color from the
# default property cycle.
histogram_specs = (
    (rvs, 80, "observations"),
    (prior_obs, 60, "prior predictive"),
    (post_pred["obs"].reshape(-1), 60, "posterior predictive"),
)
for hist_data, n_bins, hist_label in histogram_specs:
    ax.hist(hist_data, density=True, bins=n_bins, histtype="step", label=hist_label)
ax.legend();  # trailing semicolon suppresses the Legend repr in notebooks

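You can compute the observation log-probability by hand from `Categorical.log_prob()` and each component's `log_prob()`. For a mixture with weights $w_k$ and component densities $p_k$ it is $\log p(x) = \log \sum_k w_k\, p_k(x)$, computed stably as a `logsumexp` over $\log w_k + \log p_k(x)$. Then use a `pyro.factor` statement to encode the observation. There is no `MixtureGeneral` equivalent in Pyro.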

1 Like