Credible intervals for parameters using SVI

Hi there.

I’m trying to get some parameter and credible interval estimates for a mixture model using numpyro. The model looks like this:

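import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate

# MultivariateBernoulli is a custom distribution whose definition is not shown in this post.

@config_enumerate
def discrete_mixture_model(K, X=None):
    N, D = X.shape
    # Mixture weights over the K clusters.
    cluster_proba = numpyro.sample('cluster_proba', dist.Dirichlet(0.5 * jnp.ones(K)))

    # One Bernoulli probability per (cluster, component) pair; phi has shape (K, D).
    with numpyro.plate('components', D):
        with numpyro.plate("cluster", K):
            phi = numpyro.sample('phi', dist.Beta(2.0, 2.0))

    with numpyro.plate('data', N):
        # Discrete cluster assignment, enumerated out by TraceEnum_ELBO.
        assignment = numpyro.sample('assignment', dist.CategoricalProbs(cluster_proba))

        numpyro.sample(
            'obs',
            MultivariateBernoulli(phi[assignment, :]),
            obs=X,
        )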

And the fitting looks like this:

from numpyro.infer import SVI, TraceEnum_ELBO, init_to_value
from numpyro.infer import autoguide as ag  # assumed meaning of the `ag` alias

k = 3
# X = some data

# Expose only the global (continuous) sites to the guide; the discrete
# 'assignment' site is hidden so that it can be enumerated instead.
global_model = numpyro.handlers.block(
    numpyro.handlers.seed(discrete_mixture_model, jax.random.PRNGKey(0)),
    hide_fn=lambda site: site["name"]
    not in ["cluster_proba", "components", "cluster", "phi"]
)

init_vals = {
    "cluster_proba": jnp.ones(k) / float(k),
    "phi": np.random.rand(k, disc_data["dataframe"].to_numpy().shape[1])
}

guide = ag.AutoDelta(
    global_model,
    init_loc_fn=init_to_value(values=init_vals)
)

elbo = TraceEnum_ELBO()

optimizer = numpyro.optim.Adam(step_size=0.005)
svi = SVI(discrete_mixture_model, guide, optimizer, loss=elbo)
svi_result = svi.run(jax.random.PRNGKey(0), 10000, X=X, K=k)

The estimates for the parameters are fine, but I will need an estimate of the credible intervals (or the posterior distribution for the parameters) and I can’t figure out how to do it with SVI. I’ve had a go at using Predictive to generate some posterior samples:

params = svi_result.params
predictive = Predictive(discrete_mixture_model, guide=guide, params=params, num_samples=N)
rng_key, rng_subkey = jax.random.split(key=rng_key)
posterior_samples = predictive(rng_subkey, K=K, N=N, D=D)
predictive_post = Predictive(guide, posterior_samples, params=params, num_samples=N)
samples = predictive_post(jax.random.PRNGKey(1), K=K, N=N, D=D)

but this just gives me many copies of the same sample. There are some threads on here and on GitHub about that, but I haven’t had any joy solving it so far. If anyone has any insights I would be most grateful. Thanks 🙂

Hi @jim, the AutoDelta guide gives you point estimates (MAP) of the variables, so every draw from it is the same value. You might want to use AutoNormal instead.

Thanks! For the AutoNormal guide does it matter that the parameters are not normally distributed variables? In fact, the phis are between 0 and 1.

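No, it doesn’t matter. We transform variables to unconstrained domains under the hood, so the normal distribution lives in the unconstrained space and the samples are mapped back to the variable’s support.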
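
As a rough illustration of that idea (plain Python, not NumPyro's internal code): a normal distribution is placed on an unconstrained value `z`, and each draw is pushed through a map onto (0, 1). The sigmoid used here and the values of `loc` and `scale` are assumptions chosen for the sketch.

```python
import math
import random


def sigmoid(z):
    # Maps any real number into the open interval (0, 1).
    return 1.0 / (1.0 + math.exp(-z))


rng = random.Random(0)

# Hypothetical variational parameters in the unconstrained space.
loc, scale = 0.8, 0.5

# Draw from the normal in unconstrained space, then map back to (0, 1).
z_draws = [rng.gauss(loc, scale) for _ in range(5000)]
phi_draws = sorted(sigmoid(z) for z in z_draws)

# Empirical 90% interval for phi from the transformed draws.
lower = phi_draws[int(0.05 * len(phi_draws))]
upper = phi_draws[int(0.95 * len(phi_draws))]

print("all draws in (0, 1):", all(0.0 < p < 1.0 for p in phi_draws))
print("90% interval for phi:", (lower, upper))
```

Every transformed draw stays inside (0, 1) even though the underlying distribution is normal, which is why the constrained support of phi is not a problem for this kind of guide.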

Awesome, thanks 🙂. I was getting confused by what svi_result.params returns. For the AutoDelta guide, the params are good estimates of the model parameters. For the AutoNormal guide they aren’t, but sampling from the posterior lets me estimate the parameters, and that works fine with the AutoNormal guide. I should probably read the manual!