Automatic guides for a model with discrete latent variables, as in Pyro

I understand that when using Stochastic Variational Inference in Pyro, we can define autoguides like this for a model with discrete latent variables (source):

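from pyro import poutine
from pyro.infer.autoguide import AutoGuideList, AutoDiagonalNormal
guide = AutoGuideList(model)
guide.append(AutoDiagonalNormal(poutine.block(model, hide=["assignment"])))  # continuous sites only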

Can we do the same in NumPyro? If so, is that the recommended approach over defining manual guides when using TraceEnum_ELBO?

Hi @pankajb64, currently, we don’t have support for this feature. Please feel free to make a feature request.

@fehiepsi Unless I’m missing something, it seems this is possible? This post did it using SVI by hiding the discrete variables, and so did this one.

I’m using a setup like this: I have two versions of my model, one with the discrete latent variable site hidden/blocked

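import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def my_model(data):
    with numpyro.plate("L", L):
        # block() hides "c" from SVI; seed() fixes its value to a draw from the prior
        with numpyro.handlers.block(), numpyro.handlers.seed(rng_seed=jrng_key):
            c = numpyro.sample("c", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})
        ...  # rest of the model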

and one where the discrete latent variable isn’t blocked/hidden.

def my_model_no_block_discrete(data):
    with numpyro.plate("L", L):
        c = numpyro.sample("c", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})
        ...  # rest of the model

I use SVI with an AutoDelta guide to infer all the other variables (everything except the discrete latent variable):

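import numpyro
from numpyro import infer
from numpyro.infer.autoguide import AutoDelta

# point estimates for every site AutoDelta can see; "c" is blocked inside my_model
auto_guide = AutoDelta(my_model)
optimizer = numpyro.optim.Adam(step_size=0.0005)
svi = infer.SVI(my_model, auto_guide, optimizer, loss=infer.TraceEnum_ELBO(max_plate_nesting=2))
svi_result = svi.run(jrng_key, 2000, data)
# get posterior samples of the continuous sites from the fitted guide
predictive = infer.Predictive(auto_guide, params=svi_result.params, num_samples=2000)
samples = predictive(jrng_key, data)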

and then use infer.Predictive with infer_discrete=True to get samples of the discrete latent variable:

discrete_predictive = infer.Predictive(my_model_no_block_discrete, samples, infer_discrete=True)
discrete_samples = discrete_predictive(jrng_key, data)

This gives decent results, although not perfect ones for the discrete variable: precision and recall are both about 0.8.

Also, this may be a slightly convoluted setup; let me know if it can be simplified!

I think AutoDelta gets a point estimate conditioned on a single constant value of c, whereas in Pyro we enumerated over c.

Hi @fehiepsi, could you elaborate on that last comment? Are you suggesting a mistake in my approach?

I’m using AutoDelta to learn estimates for all the continuous variables in my model. AutoDelta seems to work better than AutoNormal.

Additionally, this approach (using an automatic guide to learn the remaining variables and then inferring the discrete variable separately) works better than defining a manual guide that includes the discrete variable (this guide in particular). That said, the results are still not as good as they could be, so any suggestions on how to improve things would be much appreciated.

When c is enumerated, the autoguide targets the marginal density p(x) = sum_c p(x, c), e.g. p(x, c=0) + p(x, c=1) + p(x, c=2) for three categories. When c is a constant, as it is under the handlers.block(), handlers.seed(...) context managers, the autoguide targets the density p(x, c=c0), where c0 is the constant generated from the prior under handlers.seed. This way, your autoguide will find the best parameters for that one constant.
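The gap between the two targets can be checked numerically. A minimal sketch in JAX, assuming a made-up one-dimensional mixture (the prior `pi`, the component means `locs`, the unit scale and the observation `x` are illustrative only):

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

# Toy mixture (assumption): prior over c and one unit-variance Normal per value of c.
pi = jnp.array([0.5, 0.3, 0.2])
locs = jnp.array([-2.0, 0.0, 2.0])
x = 1.8


def log_joint(x, c):
    """log p(x, c) = log pi_c + log N(x | locs[c], 1)."""
    return jnp.log(pi[c]) + norm.logpdf(x, locs[c], 1.0)


# Enumerated target: log p(x) = log sum_c p(x, c)
log_px = logsumexp(jnp.stack([log_joint(x, c) for c in range(3)]))

# Blocked/seeded target: the joint at one fixed value of c (here c = 0)
log_px_c0 = log_joint(x, 0)

print(log_px, log_px_c0)
```

With enumeration, the guide is fit against the marginal, which mixes all components; with the seeded constant, it is fit against a single component, whichever one the prior draw happened to select.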

Instead of using block, you can pass c in as an observed value and alternate between fitting SVI and updating c:

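def model(c=None):
    # passing c as obs clamps the discrete site to a given value
    c = numpyro.sample("c", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'}, obs=c)

svi = SVI(model, guide, ...)
svi_results = svi.run(rng_key, num_steps, c=0)  # e.g.
c_samples = infer_discrete(...)
c_new = pick_the_most_popular_one_from(c_samples)
# refit with the new value of c, warm-starting from the previous SVI state
svi_results = svi.run(rng_key, num_steps, c=c_new, init_state=svi_results.state)
c_samples_new = infer_discrete(...)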

But I’m not sure it helps much. If your target density is multi-modal, with each value of c corresponding to a different component (as in a Gaussian mixture model), then I think the current autoguides might not be helpful.
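The `pick_the_most_popular_one_from` step in the sketch above is left as a placeholder. One possible implementation, assuming the samples are integer category indices of shape `(num_samples, L)` (or `(num_samples,)` for a scalar c), takes the per-position mode. The function body and the extra `num_categories` argument are hypothetical, not part of NumPyro:

```python
import jax
import jax.numpy as jnp


def pick_the_most_popular_one_from(c_samples, num_categories):
    """Hypothetical helper: most frequent value of c per position.

    c_samples: int array of shape (num_samples, L) or (num_samples,).
    Returns an int array of shape (L,) (or a scalar) holding the mode.
    """
    one_hot = jax.nn.one_hot(c_samples, num_categories)  # (..., K)
    counts = one_hot.sum(axis=0)                          # count each value per position
    return jnp.argmax(counts, axis=-1)                    # ties go to the smaller value


c_samples = jnp.array([[0, 2, 1],
                       [0, 2, 2],
                       [1, 2, 2]])
print(pick_the_most_popular_one_from(c_samples, num_categories=3))  # [0 2 2]
```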

@fehiepsi thanks, that makes sense. It is indeed the case that the posterior is like a mixture model with different modes corresponding to different values of c.

However, I did find that with a manual guide that looks very similar to AutoNormal, except that it constrains certain params to be positive, I can correctly learn all the continuous variables using SVI, and also correctly learn the discrete latent variable c using the manual enumeration method you described above.

So my current setup looks like this:

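import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax.scipy.special import logsumexp
from numpyro.infer import SVI, Trace_ELBO

def model(c=None):
    with numpyro.plate("L", L):
        c = numpyro.sample("c", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'}, obs=c)
    # ... rest of the model

svi = SVI(model, custom_guide, optimizer, loss=Trace_ELBO())
svi_c0 = svi.run(rng_key, num_steps, c=0)  # fit with every c fixed to 0
svi_c1 = svi.run(rng_key, num_steps, c=1)  # fit with every c fixed to 1

Lambda = ...  # matrix of log posteriors for discrete variable c, shape (#c, L)
# normalize over the values of c (axis 0), then exponentiate to get probabilities
c_probabilities = jnp.exp(Lambda - logsumexp(Lambda, axis=0))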

I’m using Trace_ELBO instead of TraceEnum_ELBO because the discrete variable is never actually enumerated in the model (it is always passed in as observed). If I modify my manual guide to include c, the results I get for c from SVI aren’t great, so instead I’m using the manual enumeration above, which gives me precision/recall of 0.9+ on my dataset. I’m therefore guessing that including infer={'enumerate': 'parallel'} in the model is pointless.
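The post does not show how Lambda is built. One common way, sketched here under assumptions, is to plug the point estimates from each fixed-c fit into the likelihood: Lambda[k, l] = log pi_k + log p(data_l | c_l = k, params fitted with c = k). The two-value prior, the unit-variance Normal likelihood and the `fitted_means` standing in for the `svi_c0`/`svi_c1` estimates are all hypothetical:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

# Assumptions for illustration only.
pi = jnp.array([0.6, 0.4])               # prior over the two values of c
fitted_means = jnp.array([-1.0, 2.0])    # stand-ins for the svi_c0 / svi_c1 estimates
data = jnp.array([-1.2, 1.9, 0.4, 2.5])  # L = 4 positions

# Lambda[k, l] = log pi_k + log p(data_l | c_l = k, fitted params for k)
Lambda = jnp.log(pi)[:, None] + norm.logpdf(data[None, :], fitted_means[:, None], 1.0)

# Normalize over the values of c (axis 0) to get p(c_l = k | data_l)
c_log_probs = Lambda - logsumexp(Lambda, axis=0, keepdims=True)
c_probabilities = jnp.exp(c_log_probs)   # equivalent to jax.nn.softmax(Lambda, axis=0)
c_map = jnp.argmax(c_probabilities, axis=0)
print(c_map)  # [0 1 0 1]
```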

I’m very new to NumPyro, so if there is a better way to do this, please let me know. I’m actually quite surprised that automatic enumeration does not work but manual enumeration does.

I’m not sure there is a better way. The two approaches are different, and which one works better probably depends on the dataset. Maybe your manual guide is not flexible enough to fit the posterior (with enumeration).

Thanks. This seems good enough for my purpose, both in terms of accuracy and runtime.