Latent categorical

It seems numpyro can model latent integer variables using funsor, which is great. In this example I want to model 2 latent groups using a Dirichlet and a Categorical.

I am struggling with the ii variable, which needs to have length N=100. I tried various shape parameters but failed.

I previously got this to work in pymc3 but am curious how to do it in numpyro.

# pip install funsor

import jax.numpy as np
import numpyro
import numpy.random as npr
import numpyro.distributions as dist
from jax import random, ops
from numpyro.infer import MCMC, NUTS

N = 100
x = npr.normal(size=N)
alpha = np.array([1., 1.]) # prior
group = npr.choice(2, size=N, p=[.7, .3]) # unequal groups
y = 10 + 4 * (group == 0) + x + npr.normal(size=N)

def model(x=None, y=None, alpha=None):
  a = numpyro.sample('a', dist.Normal([10, 0], 10))
  b = numpyro.sample('b', dist.Normal(0, 1))
  sigma = numpyro.sample('sigma', dist.Gamma(1, 1))
  theta = numpyro.sample('theta', dist.Dirichlet(alpha))
  # ii should be length N
  ii = np.ones(N)
  # ii = numpyro.sample('ii', dist.Categorical(theta))
  for i in range(N):
    ops.index_update(ii, i, numpyro.sample('ii' + str(i), dist.Categorical(theta)))
  pred = numpyro.sample('pred', dist.Normal(a[ii] + b * x, sigma), obs=y)

# print(dist.Dirichlet(np.repeat(2, 7)).sample(random.PRNGKey(1805), (10,)))

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000, num_chains=1)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, x=x, y=y, alpha=alpha)

posterior = mcmc.get_samples()

@dirknbr In Pyro, you can use plate to declare batch dimensions, like this:

def model(x=None, y=None, alpha=None):
    a = numpyro.sample('a', dist.Normal(np.array([10, 0]), 10).to_event(1))
    # or with numpyro.plate("groups", 2):
    #     a = numpyro.sample('a', dist.Normal(np.array([10, 0]), 10))
    b = numpyro.sample('b', dist.Normal(0, 1))
    sigma = numpyro.sample('sigma', dist.Gamma(1, 1))
    theta = numpyro.sample('theta', dist.Dirichlet(alpha))
    with numpyro.plate("N", N):
        # one latent group label per observation
        ii = numpyro.sample('ii', dist.Categorical(theta))
        numpyro.sample('pred', dist.Normal(a[ii] + b * x, sigma), obs=y)

NUTS will marginalize the latent ii variable, so you won't see it in the output (we are working on a utility to infer those marginalized variables). If you also want to obtain ii during the MCMC run, you can use the DiscreteHMCGibbs kernel, which should offer functionality similar to PyMC3's.


Thank you. I did what you suggested, but now I get this error (excerpt):

value_scaled = (value - self.loc) / self.scale
TypeError: unsupported operand type(s) for -: 'JVPTracer' and 'list'

Did you use np.array([10, 0]) or [10, 0] at site a? (The error says that one of your parameters is a list.)

You are right, my mistake. Fixed. Thank you.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS

def model36(a, b, probs, k=None):
    # priors: th, n
    n = numpyro.sample('n', dist.Categorical(probs=probs))
    th = numpyro.sample('th', dist.Beta(a, b))
    # observation
    size = len(k)
    with numpyro.plate(f'i=1..{size}', size=size):
        obs = numpyro.sample('k', dist.Binomial(total_count=n, probs=th), obs=k)

I am trying a very simple model with DiscreteHMCGibbs:

a, b = 1, 1
k = jnp.array([16, 18, 22, 25, 27])
nmax = 500
probs = jnp.array([1.]*nmax) / nmax
kernel = numpyro.infer.DiscreteHMCGibbs(NUTS(model36), modified=True)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=5000)
mcmc.run(random.PRNGKey(0), a=a, b=b, probs=probs, k=k)
mcmc.print_summary()


                mean       std    median      5.0%     95.0%     n_eff     r_hat
         n    159.23    101.53    131.00     33.00    307.00     13.25      1.01
        th      0.20      0.13      0.17      0.04      0.39     16.38      1.01

This shows a very low n_eff for both n and th, and an r_hat slightly above 1.0, so I am not sure what to do.

With plain NUTS the result was satisfactory but, as you mentioned, we do not get posterior samples of 'n'.

What happens if n is modeled as a continuous variable? I found this approach in a pymc3 example (

def modelu(a, b, nmax, k=None):
    u = numpyro.sample('u', dist.Uniform())
    n = u * nmax
    numpyro.deterministic('n', n)
    th = numpyro.sample('th', dist.Beta(a, b))
    size = len(k)
    with numpyro.plate(f'i=1..{size}', size=size):
        obs = numpyro.sample('k', dist.Binomial(total_count=n, probs=th), obs=k)

When used with NUTS, the result looked reasonably good:

                mean       std    median      5.0%     95.0%     n_eff     r_hat
        th      0.20      0.14      0.15      0.04      0.42    609.16      1.00
         u      0.36      0.25      0.28      0.06      0.75    607.43      1.00

The KDE of the posterior looks very similar to the one from DiscreteHMCGibbs, although the summary statistics differ a little.

I guess numpyro does something internally during NUTS sampling.

Do you think this is an acceptable way to get joint samples of (n, th)?

Any comment will be very much appreciated. Thanks in advance.

Hi @yongduek, I think that result is expected for HMC within Gibbs, especially with your setup: the support of n has size 500, so it takes at least 500 MCMC steps to walk through the possible values of n (assuming a new value of n is drawn at each step), which makes Gibbs updates extremely inefficient here. We just added infer_discrete, which can be used after running MCMC with enumeration:

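import jax
import numpyro
from numpyro.contrib.funsor import config_enumerate, infer_discrete
from numpyro.infer import Predictive

# assumes mcmc was run with NUTS(model36), so its samples contain th but not n
def f(th_sample):
    # fix th at one posterior draw, then infer n (temperature=0 gives the MAP value)
    predictive = Predictive(infer_discrete(config_enumerate(
        numpyro.handlers.condition(model36, th_sample)),
        first_available_dim=-2, temperature=0), {}, num_samples=1, return_sites=["n"])
    n = predictive(jax.random.PRNGKey(0), a, b, probs, k)
    return n["n"][0]

n_map = jax.vmap(f)(mcmc.get_samples())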

Please let us know if it gives the expected result. This is a cool new feature, but the API is a bit inconvenient to use; we'll improve it based on feedback.

Using infer_discrete=True in Predictive handles this tricky posterior better than either DiscreteHMCGibbs or the continuous prior on n.

The code for the figures can be found in the repo. I am not sure how many chains @yongduek ran for his example with the Uniform prior on n, but I get a lot of divergences and an r_hat > 1. Still, this seems consistent with the example from PyMC3.

(Figures not reproduced here: trace plots for each approach, one of them titled "Continuous n". In each figure, the panel titled x shows the trace for n.)
