Hi there.
I’m trying to get some parameter and credible interval estimates for a mixture model using numpyro. The model looks like this:
@config_enumerate
def discrete_mixture_model(K, X=None):
    """Mixture model with K clusters over the D columns of ``X``.

    Sample sites, in declaration order:

    * ``cluster_proba`` -- mixture weights, ``Dirichlet(0.5 * ones(K))``.
    * ``phi`` -- ``Beta(2.0, 2.0)`` parameters drawn inside the nested
      ``components`` (size D) and ``cluster`` (size K) plates.
    * ``assignment`` -- one categorical cluster index per element of the
      ``data`` plate (size N), drawn from ``cluster_proba``.
    * ``obs`` -- ``X``, observed under ``MultivariateBernoulli`` using the
      rows of ``phi`` selected by ``assignment``.

    Args:
        K: Number of mixture components.
        X: Observed data; ``X.shape`` must unpack into ``(N, D)``.

    Raises:
        ValueError: If ``X`` is None. N and D are read from ``X.shape``, so
            the model cannot be built without data.
    """
    if X is None:
        raise ValueError(
            "discrete_mixture_model needs X: N and D are taken from X.shape"
        )
    N, D = X.shape

    cluster_proba = numpyro.sample(
        'cluster_proba', dist.Dirichlet(0.5 * jnp.ones(K))
    )

    # The outer plate is `components`, the inner one is `cluster`.
    # NOTE(review): the indexing `phi[assignment, :]` below assumes phi comes
    # out with shape (K, D) -- confirm against the plate dim allocation.
    with numpyro.plate('components', D):
        with numpyro.plate("cluster", K):
            phi = numpyro.sample('phi', dist.Beta(2.0, 2.0))

    with numpyro.plate('data', N):
        assignment = numpyro.sample(
            'assignment', dist.CategoricalProbs(cluster_proba)
        )
        # Each observation is scored against the phi row of its cluster.
        numpyro.sample(
            'obs',
            MultivariateBernoulli(phi[assignment, :]),
            obs=X,
        )
And the fitting looks like this:
k = 3
# X = some data

# Seed the model first, then hide every site whose name is not one of the
# global ones, so the autoguide below is built only over those.
seeded_model = numpyro.handlers.seed(
    discrete_mixture_model, jax.random.PRNGKey(0)
)
global_model = numpyro.handlers.block(
    seeded_model,
    hide_fn=lambda site: site["name"]
    not in ["cluster_proba", "components", "cluster", "phi"],
)

# Starting point: uniform mixture weights and random phi, one row per cluster
# and one column per column of the data frame.
n_features = disc_data["dataframe"].to_numpy().shape[1]
init_vals = {
    "cluster_proba": jnp.ones(k) / float(k),
    "phi": np.random.rand(k, n_features),
}

# NOTE(review): per the NumPyro docs AutoDelta is a Delta (MAP) guide, i.e. it
# learns a single point per latent site rather than a distribution -- confirm
# this is what is wanted if posterior spread is needed downstream.
guide = ag.AutoDelta(
    global_model,
    init_loc_fn=init_to_value(values=init_vals),
)

elbo = TraceEnum_ELBO()
optimizer = numpyro.optim.Adam(step_size=0.005)
svi = SVI(discrete_mixture_model, guide, optimizer, loss=elbo)
svi_result = svi.run(jax.random.PRNGKey(0), 10000, X=X, K=k)
The estimates for the parameters are fine, but I also need an estimate of the credible intervals (or the posterior distribution of the parameters), and I can’t figure out how to get it with SVI. I’ve had a go at using `Predictive` to generate some posterior samples:
params = svi_result.params

# NOTE(review): discrete_mixture_model is declared as (K, X=None), yet the
# calls below pass K, N and D -- confirm which signature is intended.
model_kwargs = {"K": K, "N": N, "D": D}

# First pass: run the model with its latent sites supplied by the fitted guide.
predictive = Predictive(
    discrete_mixture_model,
    guide=guide,
    params=params,
    num_samples=N,
)
rng_key, rng_subkey = jax.random.split(key=rng_key)
posterior_samples = predictive(rng_subkey, **model_kwargs)

# Second pass: feed those draws back in and sample from the guide itself.
# NOTE(review): if `guide` is a point-estimate guide, every draw it produces
# will be the same value -- confirm the guide family before reading intervals
# off `samples`.
predictive_post = Predictive(
    guide,
    posterior_samples,
    params=params,
    num_samples=N,
)
samples = predictive_post(jax.random.PRNGKey(1), **model_kwargs)
but this just gives me many copies of the same sample. There are some threads here and on GitHub about this, but I have not had any joy solving it so far. If anyone has any insights, I would be most grateful. Thanks!