Converting a model to support enumeration of discrete latent variables

Hello,

I am having difficulty constructing a model that supports enumeration over discrete latent variables.

The model takes a known distribution (p_x below) and a point (obs below). Roughly, the model is meant to detect whether each coordinate of the point is an outlier under the distribution, using a Bernoulli prior, and to replace outlying coordinates with more realistic values (adjusted below). I have a simplified version of the model running using MixedHMC:

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro import deterministic, sample
from numpyro.distributions import Bernoulli, Normal, TransformedDistribution
from numpyro.distributions.transforms import AffineTransform
from numpyro.infer import HMC, MCMC, MixedHMC, NUTS

def model(
    obs: jnp.ndarray,
    p_x: dist.Distribution,
    outlier_probs: jnp.ndarray,
    slab_std: float = 5,
):
    """Outlier model for a single observed point with per-coordinate flags.

    Each coordinate of ``obs`` has a Bernoulli outlier flag. Where the flag
    is set, a Normal(0, ``slab_std``) error term is active; elsewhere the
    error is zero. ``obs`` is scored as ``x + error`` with ``x ~ p_x``
    (an AffineTransform with loc=error, scale=1). Therefore the recorded
    deterministic site ``"adjusted" = obs - error`` is the implied draw
    from ``p_x``.

    Args:
        obs: observed point; its length sets the number of coordinates.
        p_x: reference distribution the clean point is assumed to come from.
        outlier_probs: prior outlier probability for each coordinate.
        slab_std: standard deviation of the Normal "slab" error (default 5,
            the value previously hard-coded).
    """
    dim = len(obs)
    is_outlier = sample("is_outlier", Bernoulli(probs=outlier_probs))
    error_slab = sample("error_slab", Normal(0, jnp.full((dim,), slab_std)))
    # Only flagged coordinates carry an error; the rest are left untouched.
    error_slab = jnp.where(is_outlier, error_slab, 0)
    deterministic("adjusted", obs - error_slab)
    sample(
        "y",
        TransformedDistribution(p_x, AffineTransform(error_slab, 1)),
        obs=obs,
    )

dim = 2
probs = jnp.full((dim,), 0.1)  # prior outlier probability for each coordinate
kernel = HMC(model, target_accept_prob=0.9)
# MixedHMC wraps the HMC kernel so the discrete is_outlier site is sampled
# together with the continuous ones.
kernel = MixedHMC(kernel)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=5000, progress_bar=True)
key, mcmc_key = random.split(random.PRNGKey(0))

y = jnp.array([0.0, 4.0])
p_x = dist.Normal(0, jnp.ones(dim))

model_args = [y, p_x, probs]
mcmc.run(mcmc_key, *model_args)


samples = mcmc.get_samples()
key, p_x_key = random.split(key, 2)
# p_x already has batch shape (dim,), so request one draw per posterior sample.
# Passing the full (num_samples, dim) shape would yield (num_samples, dim, dim).
num_draws = samples["adjusted"].shape[0]
p_x_samps = p_x.sample(p_x_key, (num_draws,))

I can see the model roughly works as expected:

import matplotlib.pyplot as plt

# Reference draws from p_x, posterior "adjusted" values, and the observed point.
plt.scatter(p_x_samps[:, 0], p_x_samps[:, 1], s=0.1, alpha=0.5, label="p_x samples")
plt.scatter(
    samples["adjusted"][:, 0],
    samples["adjusted"][:, 1],
    s=0.1,
    alpha=0.5,
    label="adjusted",
)
plt.scatter(y[0], y[1], label="observed y")
plt.legend()

However, in low dimensional examples, I want to enumerate over the Bernoulli random variable. I understand that I have to use plates somehow, but I’m not particularly sure how. I made the following adjustments and the model runs with standard HMC:

def model(
    obs: jnp.ndarray, p_x: dist.Distribution, outlier_prob: float = 0.2
):
    """Second, plate-based attempt at an enumerable version of the outlier model.

    One Normal slab error and one Bernoulli outlier flag are drawn per
    coordinate inside plate "d"; a single scalar ``outlier_prob`` is shared
    by all coordinates. ``obs`` is scored under ``p_x`` shifted by the
    masked errors.

    NOTE(review): the accompanying text reports that this runs under plain
    HMC but gives unexpected "adjusted" samples. Likely causes (to confirm):
    - The "y" site is placed under a different plate ("obs", size 1) that
      also claims dim=-1, so the coordinate axis of its log-prob is not tied
      to plate "d" when the discrete site is enumerated.
    - The deterministic "adjusted" site depends on the discrete
      ``is_outlier`` site.
    The fix suggested in the thread is one Bernoulli site per coordinate.
    """
    d = len(obs)
    # Adds a leading axis, e.g. (d,) -> (1, d) for a 1-D obs.
    obs = jnp.expand_dims(obs, 0)

    with numpyro.plate("d", d, dim=-1):
        error_slab = numpyro.sample("error_slab", dist.Normal(0, 5))
        is_outlier = numpyro.sample("is_outlier", dist.Bernoulli(probs=outlier_prob))

    # Keep the slab error only for coordinates flagged as outliers.
    error_slab = jnp.where(is_outlier, error_slab, 0)
    # Depends on the discrete is_outlier; under enumeration this site carries
    # enumerated dimension(s) instead of a single value.
    adjusted_obs = numpyro.deterministic("adjusted", obs - error_slab)
    
    with numpyro.plate("obs", 1, dim=-1):
        numpyro.sample(
            "y",
            TransformedDistribution(p_x, AffineTransform(error_slab, 1)),
            obs=obs,
        )

However, it does not give me the expected results (a lot of the “adjusted” samples are very unlikely in p_x).

Thanks!

I think you need to loop over d, using this pattern (a separate discrete sample site for each coordinate).

Thanks for the help! I think I’ve got it working now. Here is the updated model:

from numpyro.infer import MCMC, NUTS
import numpyro
import numpyro.distributions as dist
import jax.numpy as jnp
from jax import random

def model(
    obs: jnp.ndarray,
    p_x: dist.Distribution,
    outlier_probs: jnp.ndarray,
    slab_std: float = 10):
    """Outlier model whose per-coordinate outlier flags can be enumerated.

    Instead of one vector-valued Bernoulli site inside a plate, each
    coordinate k of ``obs`` gets its own site ``is_outlier_{k}`` marked for
    parallel enumeration. The slab error for coordinate k is kept only when
    ``is_outlier_{k}`` is 1. ``obs`` is then scored under ``p_x`` shifted by
    the masked errors (AffineTransform with loc=errors, scale=1).

    Args:
        obs: observed point; its length sets the number of coordinates ``d``.
        p_x: reference distribution the clean point is assumed to come from.
        outlier_probs: prior outlier probability for each coordinate, indexed by k.
        slab_std: standard deviation of the Normal slab error.
    """
    d = len(obs)

    with numpyro.plate("d", d, dim=-1):
        error_slab = numpyro.sample("error_slab", dist.Normal(0, slab_std))

    masked_errors = []
    for k in range(d):
        # One site per coordinate, so each flag is enumerated separately.
        is_outlier_k = numpyro.sample(
            f"is_outlier_{k}",
            dist.Bernoulli(probs=outlier_probs[k]),
            infer={'enumerate': 'parallel'},
        )
        # error_slab[None, k] keeps a trailing length-1 axis so the pieces can
        # be joined back along the last (coordinate) axis below.
        masked_errors.append(error_slab[None, k] * is_outlier_k)
    # Under enumeration the pieces can have different leading shapes, so
    # broadcast them to a common shape before concatenating.
    masked_errors = jnp.concatenate(jnp.broadcast_arrays(*masked_errors), axis=-1)

    # Plate of size 1 for the single observed point.
    with numpyro.plate("obs", 1):
        numpyro.sample(
            "y",
            # Qualified through `dist`: this snippet does not import
            # TransformedDistribution / AffineTransform by bare name.
            dist.TransformedDistribution(
                p_x, dist.transforms.AffineTransform(masked_errors, 1)
            ),
            obs=obs,
        )

# Fit the enumerated model with NUTS, then recover the discrete latents.
from numpyro.infer import Predictive

dim = 2
probs = jnp.full((dim,), 0.2)  # prior outlier probability per coordinate
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=5000)
key, mcmc_key = random.split(random.PRNGKey(0))

# All-zero point except the last coordinate, an outlier under N(0, 1).
y = jnp.zeros(dim).at[-1].set(4)
p_x = dist.Normal(jnp.zeros(dim), 1)
model_args = [y, p_x, probs]

mcmc.run(mcmc_key, *model_args)
samples = mcmc.get_samples()

# infer_discrete=True asks Predictive to draw the is_outlier_{k} sites
# (enumerated out during MCMC) for each posterior sample.
predictive = Predictive(model, samples, infer_discrete=True)
key, subkey = random.split(key)
hidden = predictive(subkey, *model_args)

One pitfall I think I fell into was that I had a deterministic statement that relied on the value of the discrete latent variable (the "adjusted" site in the earlier models). I assume this cannot work as expected when using enumeration? Now I instead infer the discrete latents after the fact and then apply my deterministic transform outside the model.

I think you can also obtain deterministic sites when running infer_discrete. Maybe add a flag to your model that decides when to record the deterministic sites?

def model(..., get_deterministic=False):
    # Sketch only: "..." stands for the model's existing arguments and body.
    # Record deterministic sites only when requested (e.g. when running the
    # model with infer_discrete), so they are not traced with enumerated
    # dimensions while the discrete latents are being enumerated.
    if get_deterministic:
        numpyro.deterministic(...)

> this cannot work as expected when using enumeration

I think so. Under enumeration, your deterministic sites will have enumerated dimensions. It is better to get deterministic values when running infer_discrete.

1 Like