Latent categorical

It seems numpyro can model latent integers using funsor, great. In this example I want to model 2 latent groups using Dirichlet and Categorical.

I am struggling with the ii variable, which needs to be length N=100. I tried various shape parameters but I failed.

I previously got this to work in pymc3 but am curious how to do it in numpyro.

# pip install funsor

import jax.numpy as np
import numpyro
import numpy.random as npr
import numpyro.distributions as dist
from jax import random, ops
from numpyro.infer import MCMC, NUTS

N = 100
x = npr.normal(size=N)
alpha = np.array([1., 1.]) # prior
group = npr.choice(2, size=N, p=[.7, .3]) # unequal groups
y = 10 + 4 * (group == 0) + x + npr.normal(size=N)

def model(x=None, y=None, alpha=None):
  a = numpyro.sample('a', dist.Normal([10, 0], 10))
  b = numpyro.sample('b', dist.Normal(0, 1))
  sigma = numpyro.sample('sigma', dist.Gamma(1, 1))
  theta = numpyro.sample('theta', dist.Dirichlet(alpha))
  # ii should be length N
  ii = np.ones(N)
  # ii = numpyro.sample('ii', dist.Categorical(theta))
  for i in range(N):
    ops.index_update(ii, i, numpyro.sample('ii' + str(i), dist.Categorical(theta)))
  pred = numpyro.sample('pred', a[ii] + b * x, sigma, obs=y)

# print(dist.Dirichlet(np.repeat(2, 7)).sample(random.PRNGKey(1805), (10,)))

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000, num_chains=1)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, x=x, y=y, alpha=alpha)

posterior = mcmc.get_samples()

@dirknbr In Pyro, you can use plate to declare batch dimensions, like this

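def model(x=None, y=None, alpha=None):
    a = numpyro.sample('a', dist.Normal(np.array([10, 0]), 10).to_event(1))
    # or with numpyro.plate("groups", 2):
    #     a = numpyro.sample('a', dist.Normal(np.array([10, 0]), 10))
    # b, sigma and theta are sampled exactly as in the question's model
    b = numpyro.sample('b', dist.Normal(0, 1))
    sigma = numpyro.sample('sigma', dist.Gamma(1, 1))
    theta = numpyro.sample('theta', dist.Dirichlet(alpha))
    # one latent group index per observation
    with numpyro.plate("N", N):
        ii = numpyro.sample('ii', dist.Categorical(theta))
        numpyro.sample('pred', dist.Normal(a[ii] + b * x, sigma), obs=y)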

NUTS will marginalize the latent ii variable, so you won’t see it in the output (we are working on a utility to infer those marginalized variables). If you also want to obtain samples of ii during the MCMC run, you can use the DiscreteHMCGibbs kernel, which should offer functionality similar to PyMC3's.
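To make the marginalization concrete: for fixed continuous parameters, summing ii out of the likelihood is a per-observation log-sum-exp over the two groups, and the same quantities give the posterior group probabilities. The following is a plain-Python sketch of that computation, not numpyro's implementation; the function names and the numeric values are made up for illustration.

```python
import math

def normal_logpdf(y, mean, sigma):
    """Log density of Normal(mean, sigma) at y."""
    z = (y - mean) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)

def logsumexp(values):
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def marginal_loglik_and_resp(x, y, a, b, sigma, theta):
    """Sum ii out of the mixture likelihood, one observation at a time.

    Returns the total marginal log-likelihood and, per observation,
    the posterior probability of each group given the parameters.
    """
    total = 0.0
    resp = []
    for xi, yi in zip(x, y):
        # joint log p(ii = k, y_i) for every group k
        joint = [math.log(theta[k]) + normal_logpdf(yi, a[k] + b * xi, sigma)
                 for k in range(len(theta))]
        norm = logsumexp(joint)  # log p(y_i) with ii marginalized out
        total += norm
        resp.append([math.exp(j - norm) for j in joint])
    return total, resp

# Hypothetical data and parameter values, chosen only for illustration
x = [0.0, 0.5, -1.0]
y = [14.2, 10.3, 12.9]
total, resp = marginal_loglik_and_resp(
    x, y, a=[14.0, 10.0], b=1.0, sigma=1.0, theta=[0.7, 0.3])
```

Each row of resp is a probability vector over the two groups, which is the kind of per-observation information that is not reported when ii is marginalized.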


Thank you, I did what you suggested but got this error now (extract)

value_scaled = (value - self.loc) / self.scale
TypeError: unsupported operand type(s) for -: 'JVPTracer' and 'list'

Did you use np.array([10, 0]) or the plain list [10, 0] at site a? (The error says that one of your parameters is a list.)

You are right, my mistake. Fixed. Thank you.
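The same kind of TypeError can be reproduced without JAX. In this plain-Python analogue (the helper name scaled is hypothetical), a float stands in for the JVPTracer and the list plays the role of the loc parameter that was not converted to an array:

```python
def scaled(value, loc, scale):
    """Mimics the failing expression (value - loc) / scale."""
    return (value - loc) / scale

try:
    scaled(1.0, [10, 0], 10)  # loc passed as a plain Python list
    raised = False
except TypeError:
    raised = True  # subtraction between a float and a list is undefined
```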

def model36(a, b, probs, k=None):
    # priors: th, n
    n = numpyro.sample('n', dist.Categorical(probs=probs))
    th = numpyro.sample('th', dist.Beta(a, b))
    # observation
    size = len(k)
    with numpyro.plate(f'i=1..{size}', size=size):
        obs = numpyro.sample('k', dist.Binomial(total_count=n, probs=th), obs=k)

Trying a very simple model with DiscreteHMCGibbs()

a, b = 1, 1
k = jnp.array([16, 18, 22, 25, 27])
nmax = 500
probs = jnp.array([1.]*nmax) / nmax
kernel = numpyro.infer.DiscreteHMCGibbs(NUTS(model36), modified=True)
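mcmc = MCMC(kernel, num_warmup=1000, num_samples=5000)
# any PRNG key works here; 0 is an arbitrary choice
mcmc.run(random.PRNGKey(0), a=a, b=b, probs=probs, k=k)
mcmc.print_summary()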


                mean       std    median      5.0%     95.0%     n_eff     r_hat
         n    159.23    101.53    131.00     33.00    307.00     13.25      1.01
        th      0.20      0.13      0.17      0.04      0.39     16.38      1.01

This shows a very low n_eff for both n and th, with r_hat slightly different from 1.0, which makes me wonder what to do.

Definitely, with NUTS, the result was satisfactory, but as you mentioned we do not have posterior samples of ‘n’.

What happens if n is modeled as a continuous variable? I found this approach in a PyMC3 example:

def modelu(a, b, nmax, k=None):
    u = numpyro.sample('u', dist.Uniform())
    n = u * nmax
    numpyro.deterministic('n', n)
    th = numpyro.sample('th', dist.Beta(a, b))
    size = len(k)
    with numpyro.plate(f'i=1..{size}', size=size):
        obs = numpyro.sample('k', dist.Binomial(total_count=n, probs=th), obs=k)

and when it was used with NUTS-MCMC, the result was kind of good:

                mean       std    median      5.0%     95.0%     n_eff     r_hat
        th      0.20      0.14      0.15      0.04      0.42    609.16      1.00
         u      0.36      0.25      0.28      0.06      0.75    607.43      1.00

and the KDE posterior density looks very similar to the one from DiscreteHMCGibbs, even though the summary statistics are a little different.

My guess is that numpyro does something internally during NUTS MCMC.

Do you think this is an acceptable way to get joint samples of (n, th)?

Any comment will be very much appreciated. Thanks in advance.
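One possible explanation for why a non-integer n is accepted at all, stated as an assumption rather than a description of numpyro's internals: if the binomial coefficient is evaluated through log-gamma functions, the log-probability is defined and smooth for real-valued n >= k. A stdlib-only sketch (the function name is hypothetical):

```python
import math

def binom_logpmf_continuous(k, n, p):
    """Binomial log-pmf with the coefficient written via log-gamma,
    so that it is also defined for non-integer n >= k."""
    return (math.lgamma(n + 1.0) - math.lgamma(k + 1.0)
            - math.lgamma(n - k + 1.0)
            + k * math.log(p) + (n - k) * math.log1p(-p))

# At an integer n the formula agrees with the usual binomial coefficient
exact = math.log(math.comb(100, 16)) + 16 * math.log(0.2) + 84 * math.log(0.8)
at_100 = binom_logpmf_continuous(16, 100.0, 0.2)

# Between integers it interpolates smoothly, which is what a
# gradient-based sampler such as NUTS requires
at_100_5 = binom_logpmf_continuous(16, 100.5, 0.2)
at_101 = binom_logpmf_continuous(16, 101.0, 0.2)
```

Under that assumption the sampler explores a continuous relaxation of n, so the resulting draws of n are real numbers rather than integers.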

Hi @yongduek, I think that result is expected for HMC within Gibbs, especially for your setup: the support of n has size 500, so even if a new value of n were drawn at every MCMC step, it would take at least 500 steps to walk through the possible values of n. A Gibbs update is therefore extremely inefficient here. We just added infer_discrete, which can be used after running MCMC with enumeration:

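import jax
from numpyro.contrib.funsor import config_enumerate, infer_discrete
from numpyro.infer import Predictive

def f(th_sample):
    # condition on one posterior draw of th, then infer n by enumeration
    predictive = Predictive(infer_discrete(config_enumerate(
        numpyro.handlers.condition(model36, th_sample)),
        first_available_dim=-2, temperature=0), {}, num_samples=1, return_sites=["n"])
    n = predictive(jax.random.PRNGKey(0), a, b, probs, k)
    return n["n"][0]

# vmap wraps the function first; the wrapped function is then called on the samples
jax.vmap(f)(mcmc.get_samples())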

Please let us know if it gives you the expected result. This is a cool new feature, but the API is a bit inconvenient to use (we’ll enhance it through feedback).
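Conceptually, once th is fixed to a posterior draw, the conditional posterior of n can be computed exactly by enumerating its 500 possible values, which is why enumeration avoids the slow Gibbs walk. A plain-Python sketch of that idea follows; it is not how numpyro or funsor implement it, the helper names are hypothetical, and th=0.2 is just an illustrative value.

```python
import math

def binom_logpmf(k, n, p):
    """Log Binomial(n, p) pmf at k; -inf when k is outside 0..n."""
    if k < 0 or k > n:
        return -math.inf
    return (math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
            + k * math.log(p) + (n - k) * math.log1p(-p))

def posterior_over_n(ks, th, nmax):
    """Exact p(n | th, ks) under a uniform prior on n = 0..nmax-1."""
    logw = [sum(binom_logpmf(k, n, th) for k in ks) for n in range(nmax)]
    m = max(logw)
    w = [math.exp(lw - m) for lw in logw]  # impossible n get weight 0.0
    z = sum(w)
    return [wi / z for wi in w]

ks = [16, 18, 22, 25, 27]
post = posterior_over_n(ks, th=0.2, nmax=500)  # th fixed for illustration

# temperature=0 corresponds to picking the most probable value;
# temperature=1 would instead draw n from `post`
n_map = max(range(len(post)), key=post.__getitem__)
```

Repeating this for every posterior draw of th yields joint samples of (n, th) without ever running a Gibbs step on n.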

Using infer_discrete=True in Predictive handles this tricky posterior better than either DiscreteHMCGibbs or the continuous prior on n.

The code for the figures can be found in the repo. I am not sure how many chains @yongduek ran for his example with the Uniform prior on n, but I get a lot of divergences and an r_hat > 1. Yet this seems consistent with the PyMC3 example.



[Figures: trace plots, where the panel titled x displays the trace for n; one figure is titled "Continuous n".]


Hey everyone!

I happen to have a similar question about a simple HMM:

def model_scan(trans, obs=None, n=None):
    """ Vanilla HMM with Gaussian observations using numpyro's scan """
    s0 = numpyro.sample(
        "s0",
        dist.Categorical(np.ones(2) / 2),
    )
    probs_trans = numpyro.sample(
        "trans", dist.Dirichlet(trans).to_event(1)
    )
    # ... (definitions of mu and sigma are not shown)

    def transition(s_prev, y_obs):
        s = numpyro.sample("s", dist.Categorical(probs_trans[s_prev]))
        y = numpyro.sample("y", dist.Normal(mu[s], sigma[s]), obs=y_obs)
        return s, (s, y)

    _, (s, y) = scan(transition, s0, obs, length=n)
    return (s, y)

I have tried a number of strategies:

  1. Using Predictive
    1. With infer_discrete=False (got samples for s but they don’t seem conditioned on the observations. What are they actually?)
    2. With infer_discrete=True (errors)
  2. Using contrib.funsor.infer_discrete
    1. On the model as is (errors the same way)
    2. On the model without continuous variables (errors differently)
    3. On the model with contrib.funsor.markov for-loop instead of contrib.control_flow.scan (actually works)

Here is a notebook with a self-contained example and all of the attempts listed above.

Admittedly, it lacks the combination of Predictive and infer_discrete as shown by @fehiepsi above, but I could neither quite understand it nor make it work naively :no_mouth:

I was wondering whether I am misusing anything here. I am glad to see contrib.funsor.infer_discrete working with the markov loop, but is the failure with numpyro.contrib.control_flow.scan expected?

Besides these questions, big thanks for all the work on numpyro (and funsor), it all looks great and very curious!

Great questions and insights - thanks for trying various options, @renat.s!

  1. With infer_discrete=False (got samples for s but they don’t seem conditioned on the observations. What are they actually?)

This is just the usual usage of Predictive: it runs the model, conditioned on posterior samples if provided, and collects samples. So those s do not come from the joint posterior of s and the other latent variables.

  2. With infer_discrete=True (errors)

The PR which attempts to support this feature is here. The difference between markov and scan is that scan uses a parallel-scan algorithm to speed up performance on GPUs. It is not working yet. :frowning:
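For intuition about what discrete inference with temperature=0 targets in an HMM, here is a stdlib-only Viterbi sketch for a Gaussian HMM whose parameters are held fixed (for example at one posterior draw). It is an illustration under those assumptions, not the funsor-based algorithm that infer_discrete uses; all names and numbers are hypothetical.

```python
import math

def normal_logpdf(y, mean, sigma):
    """Log density of Normal(mean, sigma) at y."""
    z = (y - mean) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)

def viterbi(ys, init, trans, mu, sigma):
    """Most likely state path of a Gaussian HMM with known parameters.

    Mirrors the structure of model_scan: an unobserved initial state
    s0 ~ Categorical(init), then at each step
    s_t ~ Categorical(trans[s_{t-1}]) and y_t ~ Normal(mu[s_t], sigma[s_t]).
    """
    if not ys:
        return []
    K = len(init)
    # score[j]: best log-probability of any path that ends in state j
    score = [math.log(init[j]) for j in range(K)]
    back = []
    for y in ys:
        new_score, ptr = [], []
        for j in range(K):
            cands = [score[i] + math.log(trans[i][j]) for i in range(K)]
            best = max(range(K), key=cands.__getitem__)
            ptr.append(best)
            new_score.append(cands[best] + normal_logpdf(y, mu[j], sigma[j]))
        score = new_score
        back.append(ptr)
    # trace the best path backwards from the best final state
    state = max(range(K), key=score.__getitem__)
    path = [state]
    for ptr in reversed(back[1:]):
        state = ptr[state]
        path.append(state)
    path.reverse()
    return path

# Hypothetical two-state example with well-separated emission means
ys = [0.1, -0.2, 5.1, 4.8, 0.0]
path = viterbi(ys, init=[0.5, 0.5],
               trans=[[0.9, 0.1], [0.1, 0.9]],
               mu=[0.0, 5.0], sigma=[1.0, 1.0])
```

Unlike the Predictive call with infer_discrete=False, the path returned here depends on the observations ys, which is the property that was missing from the unconditioned samples of s.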

Thanks for the quick reply and the clarifications @fehiepsi!

Got it. I missed this PR. I have yet to familiarise myself with it properly, but naively I see two seemingly separate issues:

  1. that captured by test_scan_hmm_smoke (hitting an assertion in _get_support_value_contraction, if the model has no latent continuous variables), and
  2. that from scenarios 1.2 or 2.1 above when Predictive(... infer_discrete=True) or infer_discrete are applied to a model containing both discrete and continuous latent variables, when I hit the first assertion in numpyro.distributions.continuous.sample (captured in the notebook), making me assume I might be misusing both calls somehow. Am I missing something? (Although I understand, the first issue would apply anyway…)