@fehiepsi Unless I’m missing something, it seems this is possible? This post did it using SVI with the discrete variables hidden, and so did this one.
I’m using a setup like this: I have two versions of my model, one with the discrete latent variable site hidden (blocked),

    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist

    def my_model(data):
        with numpyro.plate("L", L):
            # block() hides "c" from SVI; seed() supplies the RNG key for the hidden draw
            with numpyro.handlers.block(), numpyro.handlers.seed(rng_seed=jrng_key):
                c = numpyro.sample("c", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})
            ...  # rest of the model

and one where the discrete latent variable is not blocked:

    def my_model_no_block_discrete(data):
        with numpyro.plate("L", L):
            c = numpyro.sample("c", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})
            ...  # rest of the model

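For reference, this is what a Categorical site marked infer={'enumerate': 'parallel'} amounts to when it is visible to TraceEnum_ELBO (as in the unblocked version): instead of drawing c, the likelihood of each data point is summed over all K values of c, log p(x_i) = logsumexp_k [log pi_k + log p(x_i | c_i = k)]. Below is a minimal NumPy sketch of that sum, assuming a hypothetical 1-D Gaussian likelihood (component means locs, shared scale) in place of the elided "rest of the model"; log_normal, marginal_log_lik and locs are illustrative names only, not NumPyro functions.

```python
import numpy as np

def log_normal(x, loc, scale):
    # elementwise log density of Normal(loc, scale)
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)

def marginal_log_lik(x, pi, locs, scale):
    """Per-point log p(x_i) with the categorical c_i summed out (hypothetical Gaussian likelihood)."""
    # shape (N, K): one column per enumerated value of c
    log_joint = np.log(pi)[None, :] + log_normal(x[:, None], locs[None, :], scale)
    # numerically stable log-sum-exp over the enumerated dimension
    m = log_joint.max(axis=1, keepdims=True)
    return (m + np.log(np.exp(log_joint - m).sum(axis=1, keepdims=True)))[:, 0]

x = np.array([-2.0, 0.1, 3.0])
pi = np.array([0.3, 0.7])
locs = np.array([-2.0, 3.0])
print(marginal_log_lik(x, pi, locs, scale=1.0))
```

The log-sum-exp form is used rather than summing probabilities directly so that points far from every component do not underflow to log(0).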
I use SVI with an AutoDelta guide to infer all the other variables (everything except the discrete latent variable):

    from numpyro import infer
    import numpyro.infer.autoguide  # ensures infer.autoguide is loaded

    auto_guide = infer.autoguide.AutoDelta(my_model)
    optimizer = numpyro.optim.Adam(step_size=0.0005)
    svi = infer.SVI(my_model, auto_guide, optimizer, loss=infer.TraceEnum_ELBO(max_plate_nesting=2))
    svi_result = svi.run(jrng_key, 2000, data)

    # get posterior samples of the guide's (non-discrete) sites
    predictive = infer.Predictive(auto_guide, params=svi_result.params, num_samples=2000)
    samples = predictive(jrng_key, data)

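Since AutoDelta places a point mass on every site it covers, the SVI objective reduces to a MAP objective. A short derivation, where \theta stands for the sites covered by the guide and \hat\theta for the fitted point estimate; the convention that the guide's log density term is constant relies on NumPyro's Delta distribution assigning a log density of 0 at its own value by default:

```latex
% AutoDelta guide: point mass at \hat\theta
q(\theta) = \delta(\theta - \hat\theta)

\mathrm{ELBO}(\hat\theta)
  = \mathbb{E}_{q}\big[\log p(x, \theta) - \log q(\theta)\big]
  = \log p(x, \hat\theta) - \log q(\hat\theta)

% \log q(\hat\theta) does not depend on \hat\theta, so maximizing the ELBO is MAP estimation
\hat\theta^{\star} = \arg\max_{\hat\theta} \big[\log p(x \mid \hat\theta) + \log p(\hat\theta)\big]
```

One practical consequence: every draw from a Delta guide equals \hat\theta, so the 2000 samples returned by Predictive(auto_guide, ...) are identical copies of the point estimate, and any spread in the discrete samples afterwards comes only from the conditional posterior of c given \hat\theta.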
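Then I use infer.Predictive on the unblocked model, with infer_discrete=True, to get samples of the discrete latent variable:

    # sample "c" conditioned on the data and on the other latent values in `samples`
    discrete_predictive = infer.Predictive(my_model_no_block_discrete, samples, infer_discrete=True)
    discrete_samples = discrete_predictive(jrng_key, data)
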
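Conceptually, Predictive(..., infer_discrete=True) draws c from its conditional posterior given the data and the other latent values. When the likelihood factorizes over data points, that conditional is the normalized product pi_k p(x_i | c_i = k) for each point. A hand-written NumPy sketch of this computation, again assuming a hypothetical 1-D Gaussian likelihood (locs, scale) in place of the elided model body; responsibilities and sample_assignments are illustrative names, not NumPyro functions:

```python
import numpy as np

def log_normal(x, loc, scale):
    # elementwise log density of Normal(loc, scale)
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)

def responsibilities(x, pi, locs, scale):
    """p(c_i = k | x_i, theta) for every data point i (rows) and component k (columns)."""
    log_joint = np.log(pi)[None, :] + log_normal(x[:, None], locs[None, :], scale)
    log_joint -= log_joint.max(axis=1, keepdims=True)  # stabilise before exponentiating
    w = np.exp(log_joint)
    return w / w.sum(axis=1, keepdims=True)

def sample_assignments(rng, x, pi, locs, scale, num_samples):
    """Draw c from its conditional posterior num_samples times; shape (num_samples, N)."""
    r = responsibilities(x, pi, locs, scale)
    cdf = np.cumsum(r, axis=1)                    # (N, K) per-point cumulative probabilities
    u = rng.random((num_samples, x.shape[0], 1))  # one uniform per (draw, point)
    idx = (u > cdf[None, :, :]).sum(axis=-1)      # inverse-CDF lookup
    return np.minimum(idx, r.shape[1] - 1)        # guard against round-off in the last cdf entry

rng = np.random.default_rng(0)
x = np.array([-2.1, -1.8, 0.2, 2.9, 3.2])
pi = np.array([0.5, 0.5])
locs = np.array([-2.0, 3.0])
draws = sample_assignments(rng, x, pi, locs, scale=1.0, num_samples=2000)
map_c = responsibilities(x, pi, locs, 1.0).argmax(axis=1)  # MAP assignment instead of a sample
```

Taking the argmax instead of sampling gives the MAP assignment; NumPyro's infer_discrete exposes the same choice through its temperature argument (1 samples, 0 takes the MAP).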
This gives decent, though not perfect, results for the discrete variable: precision and recall are both about 0.8.
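Because the discrete predictive returns many draws per site (presumably discrete_samples["c"] with shape (num_samples, L)), the draws have to be collapsed to one label per point before computing precision and recall. A minimal sketch, assuming a binary c with known ground-truth labels, label 1 as the positive class, a per-point majority vote as the (hypothetical) collapsing rule, and component labels already aligned with the ground truth (mixture labels can otherwise be permuted):

```python
import numpy as np

def majority_vote(draws):
    """Collapse (num_samples, N) binary draws to one label per point."""
    return (draws.mean(axis=0) >= 0.5).astype(int)

def precision_recall(pred, truth):
    # label 1 is treated as the positive class
    tp = np.sum((pred == 1) & (truth == 1))
    fp = np.sum((pred == 1) & (truth == 0))
    fn = np.sum((pred == 0) & (truth == 1))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

truth = np.array([1, 1, 0, 0, 1])
draws = np.array([[1, 0, 0, 1, 1],
                  [1, 1, 0, 1, 0],
                  [1, 1, 0, 0, 1]])
pred = majority_vote(draws)  # [1, 1, 0, 1, 1]
print(precision_recall(pred, truth))
```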
Also, this may be a slightly convoluted setup, so let me know if it can be simplified!