Inference of Mixture of LogNormals


I am trying to build a model of LogNormal mixtures.
Here is the code:

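import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

K = 2  # number of mixture components

def betabin_mixture_lognormal_model_1group(total_obs: int, non_zero_obs: int, non_zero_data: jnp.ndarray):
    log_data = jnp.log(non_zero_data)
    # Normal/HalfNormal take a standard deviation as `scale`, so use the std of the log-data
    sample_std = log_data.std()
    # probability that an observation is non-zero, informed by the count of non-zero observations
    theta = numpyro.sample('theta_prior', dist.Beta(1, 1))
    numpyro.sample('zero_prop_obs', dist.Binomial(total_count=total_obs, probs=theta), obs=non_zero_obs)
    # mixture weights over the K LogNormal components
    comp_weights = numpyro.sample('components_weights', dist.Dirichlet(jnp.ones(K)))
    assignment = dist.Categorical(probs=comp_weights)
    with numpyro.plate('components', K):
        mu = numpyro.sample('mu_hyperprior', dist.Normal(loc=log_data.mean(), scale=sample_std))
        sigma = numpyro.sample('sigma_hyperprior', dist.HalfNormal(scale=sample_std))
    with numpyro.plate("data", len(non_zero_data)):
        # MixtureSameFamily sums out the component assignment of each data point
        numpyro.sample("non_zero_data_obs", dist.MixtureSameFamily(assignment, dist.LogNormal(loc=mu, scale=sigma)), obs=non_zero_data)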
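A minimal sketch, in plain JAX, of the per-point log-likelihood that the `MixtureSameFamily` observation computes: the categorical assignment is summed out with a log-sum-exp over the components, so no discrete latent variable is left for the sampler. `lognormal_logpdf` and `mixture_lognormal_logpdf` are hypothetical helper names for illustration, not NumPyro API.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def lognormal_logpdf(x, loc, scale):
    # log density of LogNormal(loc, scale) at x > 0 (loc/scale are on the log scale)
    log_x = jnp.log(x)
    return (-log_x - jnp.log(scale) - 0.5 * jnp.log(2 * jnp.pi)
            - 0.5 * ((log_x - loc) / scale) ** 2)


def mixture_lognormal_logpdf(x, weights, loc, scale):
    # x: (N,) data; weights, loc, scale: (K,) per-component parameters
    # component log-densities for every (point, component) pair -> shape (N, K)
    comp = lognormal_logpdf(x[:, None], loc[None, :], scale[None, :])
    # marginalize the assignment: log sum_k w_k * p_k(x), computed stably
    return logsumexp(jnp.log(weights)[None, :] + comp, axis=-1)
```

With NumPyro installed, `dist.MixtureSameFamily(dist.Categorical(probs=w), dist.LogNormal(loc, scale)).log_prob(x)` should return the same values. Enumerating an explicitly sampled assignment would sum out the same quantity, so in this model enumeration matters mainly if the assignments are sampled as separate sites.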

The inference works, but:

  • It is very slow: for 2 components, 6k data points, 3 chains and 5000 samples, it takes more than a minute.
  • n_eff is very low, rarely above 1000.

I am fairly new to mixture models and might be missing some tricks that make inference faster.

Do you use parallel enumeration? Have you checked the tutorial on enumeration with discrete variables? Another way to improve performance is to avoid the Dirichlet prior and use an approximation of it instead; see the ProdLDA tutorial. You can approximate the Dirichlet distribution with a logistic-normal distribution (more precisely, a softmax-normal distribution).
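A minimal sketch of the softmax-normal approximation, assuming the Laplace approximation of Dirichlet(α) in the softmax basis from the ProdLDA paper (Srivastava and Sutton): a diagonal Normal over K logits with loc_k = log α_k − (1/K) Σ_j log α_j and variance_k = (1/α_k)(1 − 2/K) + (1/K²) Σ_j 1/α_j, pushed through a softmax. `dirichlet_laplace_approx` and `sample_softmax_normal` are hypothetical helper names.

```python
import jax
import jax.numpy as jnp


def dirichlet_laplace_approx(alpha):
    # Laplace approximation of Dirichlet(alpha) in the softmax basis:
    # returns loc and diagonal scale of a Normal over unnormalized logits
    K = alpha.shape[-1]
    log_alpha = jnp.log(alpha)
    loc = log_alpha - log_alpha.mean(-1, keepdims=True)
    inv_alpha = 1.0 / alpha
    var = inv_alpha * (1.0 - 2.0 / K) + inv_alpha.sum(-1, keepdims=True) / K**2
    return loc, jnp.sqrt(var)


def sample_softmax_normal(key, alpha, num_samples):
    # draw Gaussian logits, then map them onto the probability simplex
    loc, scale = dirichlet_laplace_approx(alpha)
    logits = loc + scale * jax.random.normal(key, (num_samples,) + alpha.shape)
    return jax.nn.softmax(logits, axis=-1)
```

In the model, the Dirichlet line could then become something like `logits = numpyro.sample('components_logits', dist.Normal(loc, scale).to_event(1))` followed by `assignment = dist.Categorical(logits=logits)`, so the sampler sees a Gaussian prior on the logits rather than a stick-breaking-transformed Dirichlet. Whether this raises n_eff for this particular model would need to be checked.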