I’m trying to learn how to infer discrete latent variables from observed continuous data properly in numpyro. The basic example I’m attempting to do is a distribution of heights where the sex of the observation is unknown and inferred.
Here is the code to generate the data.
import pandas as pd
import numpy as np
import jax.numpy as jnp
import numpyro
from numpyro.infer import MCMC, NUTS, HMC, MixedHMC
import numpyro.distributions as dist
from jax import ops, random, vmap
import arviz as az
import jax
import seaborn as sns
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=6'
numpyro.set_host_device_count(4)
N = 100
var_num = 2
with numpyro.handlers.seed(rng_seed=0):
mixing_dist = dist.Categorical(probs=jnp.ones(var_num) / var_num)
component_dist = dist.Normal(loc=jnp.array([5*12+10, 5*12+4]), scale=jnp.array([3, 2]))
mixture = dist.MixtureSameFamily(mixing_dist, component_dist).expand([N])
y = mixture.sample(jax.random.PRNGKey(42))
sns.histplot(y)
And here are two different models to estimate the model:
def model3(heights):
num_classes = 2
N = len(heights)
with numpyro.plate('class', num_classes):
mu = numpyro.sample(
'mu',
dist.Normal(0, 100).expand([num_classes])
)
sigma = numpyro.sample(
'sigma',
dist.HalfNormal(10).expand([num_classes])
)
with numpyro.plate('items', N):
pi = numpyro.sample("pi", dist.Beta(0.5, 0.5))
c = numpyro.sample('mixing_dist', dist.Bernoulli(probs=pi))
numpyro.sample(
'component_dist',
dist.Normal(loc=mu[c], scale=sigma[c]),
obs=heights
)
kernel = MixedHMC(HMC(model3))
m = MCMC(kernel, num_warmup=8000, num_samples=2000, num_chains=4, progress_bar=True)
m.run(random.PRNGKey(2), y)
samples = m.get_samples()
data = az.from_numpyro(m)
az.plot_trace(data, var_names=['mu', 'sigma'])
az.summary(data, var_names=['mu', 'sigma'])
And also this (which is more easily extended past the 2 class case):
def model4(heights):
num_classes = 2
N = len(heights)
with numpyro.plate('class', num_classes):
mu = numpyro.sample(
'mu',
dist.Normal(0, 100).expand([num_classes])
)
sigma = numpyro.sample(
'sigma',
dist.HalfNormal(10).expand([num_classes])
)
with numpyro.plate('items', N):
pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
c = numpyro.sample('mixing_dist', dist.Categorical(probs=pi))
numpyro.sample(
'component_dist',
dist.Normal(loc=mu[c], scale=sigma[c]),
obs=heights
)
kernel = MixedHMC(HMC(model4))
m = MCMC(kernel, num_warmup=8000, num_samples=2000, num_chains=4, progress_bar=True)
m.run(random.PRNGKey(2), y)
samples = m.get_samples()
data = az.from_numpyro(m)
az.plot_trace(data, var_names=['mu', 'sigma'])
az.summary(data, var_names=['mu', 'sigma'])
Both of those seem to work (although the second one gives some warnings). However the r_hats are not good, because which category represents male and female can switch places within different chains.
mean sd hdi_3% hdi_97% mcse_m mcse_sd blk tail r_hat
mu[0] 67.703 3.281 63.955 71.907 1.618 1.237 6.0 149.0 1.74
mu[1] 67.666 3.205 63.991 71.771 1.579 1.207 6.0 129.0 1.74
sigma[0] 2.315 0.546 1.532 3.388 0.186 0.136 9.0 200.0 1.37
sigma[1] 2.358 0.562 1.554 3.423 0.197 0.144 8.0 126.0 1.41
Is there a better way to handle this to get consistent labeling of the discrete variables so the r_hats aren’t thrown off by different index assignments to the latent variables? Would it be better to frame the second mu as being the first mu plus some positive only parameter? And how to move to 3+ classes and retain consistency?
Also, is there a way to do inference with MixtureSameFamily
? I couldn’t get that to work. I would get an error message:
ValueError: The mixing distribution need to be a numpyro.distributions.Categorical. However, it is of type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>