Hello,
I am having difficulty constructing a model that supports enumeration over discrete latent variables.
The model takes a known distribution (`p_x` below) and a point (`obs` below). Roughly, it is meant to detect whether each coordinate of the point is an outlier under the distribution, using a Bernoulli prior, and, where a coordinate is an outlier, replace it with a more realistic value (`adjusted` below). I have a simplified version of the model running using MixedHMC:
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro import deterministic, sample
from numpyro.distributions import Bernoulli, Normal, TransformedDistribution
from numpyro.distributions.transforms import AffineTransform
from numpyro.infer import MCMC, HMC, NUTS, MixedHMC
def model(
    obs: jnp.ndarray, p_x: dist.Distribution, outlier_probs: jnp.ndarray
):
    """Spike-and-slab outlier model for a single observed point.

    Every coordinate of ``obs`` gets a Bernoulli outlier indicator and a
    Normal(0, 5) error term. The error is zeroed wherever the indicator is
    off, and ``obs`` is then scored under ``p_x`` shifted by that error.

    Args:
        obs: The observed point; ``len(obs)`` sets the number of error terms.
        p_x: The known distribution the point is compared against.
        outlier_probs: Prior outlier probability for the Bernoulli indicator.

    Sample sites:
        ``is_outlier``: discrete outlier indicator.
        ``error_slab``: continuous error term.
        ``adjusted``: deterministic site holding ``obs - error_slab`` (the
            error has already been zeroed where ``is_outlier`` is off).
        ``y``: observed site conditioned on ``obs``.
    """
    dim = len(obs)
    is_outlier = sample("is_outlier", Bernoulli(probs=outlier_probs))
    error_slab = sample("error_slab", Normal(0, jnp.full((dim,), 5)))

    # Spike: coordinates not flagged as outliers carry no error at all.
    error_slab = jnp.where(is_outlier, error_slab, 0)

    # Recorded purely for inspection; the value is not needed further down.
    deterministic("adjusted", obs - error_slab)

    # y = x + error_slab with x ~ p_x, so obs - error_slab is scored under p_x.
    sample(
        "y",
        TransformedDistribution(p_x, AffineTransform(error_slab, 1)),
        obs=obs,
    )
# Problem size and inputs for the two-dimensional toy example.
dim = 2
probs = jnp.full((dim,), 0.1)
y = jnp.array([0, 4])
p_x = dist.Normal(0, jnp.ones(dim))

# MixedHMC wraps a plain HMC kernel so the discrete `is_outlier` site can be
# updated alongside the continuous `error_slab` site.
kernel = HMC(model, target_accept_prob=0.9)
kernel = MixedHMC(kernel)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=5000, progress_bar=True)

key, mcmc_key = random.split(random.PRNGKey(0))
model_args = [y, p_x, probs]
mcmc.run(mcmc_key, *model_args)
samples = mcmc.get_samples()

# Reference draws from p_x, one per retained MCMC draw. The sample shape is
# only the number of draws: p_x already carries the per-coordinate dimension,
# so passing the full `adjusted` shape would append that dimension twice.
key, p_x_key = random.split(key, 2)
num_draws = samples["adjusted"].shape[0]
p_x_samps = p_x.sample(p_x_key, (num_draws,))
I can see the model roughly works as expected:
import matplotlib.pyplot as plt

# Reference draws from p_x.
plt.scatter(p_x_samps[:, 0], p_x_samps[:, 1], s=0.1, alpha=0.5)
# Values recorded at the "adjusted" site, overlaid on the reference draws.
plt.scatter(samples["adjusted"][:, 0], samples["adjusted"][:, 1], s=0.1, alpha=0.5)
# The original observed point.
plt.scatter(y[0], y[1])
However, in low-dimensional examples, I want to enumerate over the Bernoulli random variable. I understand that I have to use plates somehow, but I’m not sure exactly how. I made the following adjustments, and the model runs with standard HMC:
def model(
    obs: jnp.ndarray, p_x: dist.Distribution, outlier_prob: float = 0.2
):
    """Outlier model written with a plate so ``is_outlier`` can be enumerated.

    All per-coordinate sites, including the observed site ``y``, live inside
    the single plate ``"d"``.

    Args:
        obs: The observed point; ``len(obs)`` is the plate size.
        p_x: The known distribution the point is compared against.
            NOTE(review): this layout assumes ``p_x`` factorises over
            coordinates (batch shape broadcastable to ``(len(obs),)``, empty
            event shape), as in the ``Normal(0, ones(2))`` example. Confirm
            before using a joint ``p_x``.
        outlier_prob: Prior probability that a coordinate is an outlier.
    """
    d = len(obs)
    with numpyro.plate("d", d, dim=-1):
        error_slab = numpyro.sample("error_slab", dist.Normal(0, 5))
        is_outlier = numpyro.sample(
            "is_outlier", dist.Bernoulli(probs=outlier_prob)
        )

        # Spike: coordinates not flagged as outliers carry no error at all.
        error_slab = jnp.where(is_outlier, error_slab, 0)

        # NOTE(review): when `is_outlier` is enumerated, this site is likely
        # evaluated for every enumerated value rather than for a posterior
        # draw of the indicator, so the recorded values would carry an extra
        # enumeration dimension. Verify; drawing the discrete site afterwards
        # (e.g. Predictive with infer_discrete=True) looks like the intended
        # route to posterior `adjusted` values.
        numpyro.deterministic("adjusted", obs - error_slab)

        # The likelihood depends on the per-coordinate `is_outlier`, so it is
        # kept inside plate "d". The earlier layout scored `y` in a separate
        # size-1 plate that reused dim=-1, which never declared the
        # per-coordinate independence for this term.
        numpyro.sample(
            "y",
            TransformedDistribution(p_x, AffineTransform(error_slab, 1)),
            obs=obs,
        )
However, it does not give me the expected results (a lot of the “adjusted” samples are very unlikely under `p_x`).
Thanks!