Hi

I am trying to build a model of LogNormal mixtures.

Here is the code:

```
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

K = 2  # number of mixture components

def betabin_mixture_lognormal_model_1group(total_obs: int, non_zero_obs: int, non_zero_data: jnp.ndarray):
    log_data = jnp.log(non_zero_data)
    sample_var = log_data.var()
    # Probability that an observation is non-zero
    theta = numpyro.sample('theta_prior', dist.Beta(1, 1))
    numpyro.sample('zero_prop_obs', dist.Binomial(total_count=total_obs, probs=theta), obs=non_zero_obs)
    # Mixture weights and the (marginalized) component assignment
    comp_weights = numpyro.sample('components_weights', dist.Dirichlet(jnp.ones(K)))
    assignment = dist.Categorical(probs=comp_weights)
    # Per-component LogNormal parameters, shape (K,)
    with numpyro.plate('components', K):
        mu = numpyro.sample('mu_hyperprior', dist.Normal(loc=log_data.mean(), scale=sample_var))
        sigma = numpyro.sample('sigma_hyperprior', dist.HalfNormal(scale=sample_var))
    with numpyro.plate('data', len(non_zero_data)):
        numpyro.sample('non_zero_data_obs', dist.MixtureSameFamily(assignment, dist.LogNormal(loc=mu, scale=sigma)), obs=non_zero_data)
```

The inference works, but:

- It is very slow: for 2 components, 6k data points, 3 chains and 5000 samples it takes more than a minute.
- `n_eff` is very low, rarely above 1000.

I am fairly new to mixture models, so there may be some tricks I'm missing to make inference more performant.