Automatic guides for a model with discrete latent variables like Pyro

@fehiepsi Unless I'm missing something, it seems this is possible? This post did it using SVI by hiding the discrete variables, and so did this one.

I'm using the following setup. I have two versions of my model, one with the discrete latent variable site hidden (blocked):

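```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def my_model(data):
    with numpyro.plate("L", L):
        # block() hides "c" from every enclosing handler, so SVI never sees this site
        with numpyro.handlers.block(), numpyro.handlers.seed(rng_seed=jrng_key):
            c = numpyro.sample("c", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})
        ...  # rest of the model
```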

and one where the discrete latent variable isn’t blocked/hidden.

```python
def my_model_no_block_discrete(data):
    with numpyro.plate("L", L):
        c = numpyro.sample("c", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})
        ...  # rest of the model
```

I use SVI with an AutoDelta guide to infer all the other latent variables (everything except the discrete one):

```python
from numpyro import infer

auto_guide = infer.autoguide.AutoDelta(my_model)
optimizer = numpyro.optim.Adam(step_size=0.0005)
svi = infer.SVI(my_model, auto_guide, optimizer, loss=infer.TraceEnum_ELBO(max_plate_nesting=2))
svi_result = svi.run(jrng_key, 2000, data)
# get posterior samples of the continuous latent variables from the fitted guide
predictive = infer.Predictive(auto_guide, params=svi_result.params, num_samples=2000)
samples = predictive(jrng_key, data)
```

Then I use `infer.Predictive` with `infer_discrete=True` to sample the discrete latent variable, conditioned on the samples of the other variables:

```python
discrete_predictive = infer.Predictive(my_model_no_block_discrete, samples, infer_discrete=True)
discrete_samples = discrete_predictive(jrng_key, data)
```

This seems to give decent results, although not perfect ones for the discrete variable: precision and recall are both about 0.8.
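For reference, here is one way such numbers could be computed from the discrete samples. It is a hypothetical sketch, not necessarily how the figures above were obtained: it takes the most frequent class per item across posterior draws and scores it against known labels, treating class 1 as the positive class. With mixture-type models the learned class labels may also be permuted relative to the ground truth, which would need to be handled before scoring.

```python
import jax.numpy as jnp

def posterior_mode_assignment(c_samples, num_classes):
    # c_samples: (num_draws, L) integer draws of the discrete site
    one_hot = jnp.eye(num_classes, dtype=jnp.int32)[c_samples]  # (num_draws, L, K)
    counts = one_hot.sum(axis=0)                                 # (L, K) votes per class
    return jnp.argmax(counts, axis=-1)                           # most frequent class per item

def precision_recall(pred, truth, positive=1):
    tp = jnp.sum((pred == positive) & (truth == positive))
    fp = jnp.sum((pred == positive) & (truth != positive))
    fn = jnp.sum((pred != positive) & (truth == positive))
    return tp / (tp + fp), tp / (tp + fn)

# Hypothetical toy draws (3 posterior draws, 4 items) and labels
c_samples = jnp.array([[1, 0, 1, 1], [1, 0, 0, 1], [1, 1, 0, 1]])
truth = jnp.array([1, 0, 1, 0])
pred = posterior_mode_assignment(c_samples, num_classes=2)
precision, recall = precision_recall(pred, truth)
print(pred, float(precision), float(recall))
```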

Also, this may be a slightly convoluted setup, so let me know if it can be simplified!