I am trying to create a model for a dataset where whether an observation is greater than zero follows a Bernoulli distribution. For the non-zero observations, I model the observed value with a LogNormal distribution. My dataset has shape (n_obs, n_features), and the features can be treated as independent.
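Written out for observation i and feature j, this is a hurdle-style likelihood. The notation here is my own sketch: y_ij is the observed value, p_j, mu_j and sigma_j are the per-feature parameters, n is the number of observations and k_j is the number of non-zero entries in column j.

$$
p(y_{ij} \mid p_j, \mu_j, \sigma_j) =
\begin{cases}
1 - p_j, & y_{ij} = 0, \\
p_j \, \mathrm{LogNormal}(y_{ij} \mid \mu_j, \sigma_j), & y_{ij} > 0,
\end{cases}
$$

so that, summed over the observations of one feature,

$$
\sum_{i=1}^{n} \log p(y_{ij} \mid p_j, \mu_j, \sigma_j)
= k_j \log p_j + (n - k_j) \log(1 - p_j)
+ \sum_{i \,:\, y_{ij} > 0} \log \mathrm{LogNormal}(y_{ij} \mid \mu_j, \sigma_j).
$$

The Bernoulli part depends on the data only through k_j, and the LogNormal part only involves the non-zero entries.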
Here is the model I have defined:
```python
import numpyro
import numpyro.distributions as dist
from jax import numpy as jnp
import numpy as np
from jax.random import PRNGKey
from numpyro.infer import MCMC, NUTS


def model(obs):
    num_obs, num_features = obs.shape
    # Independent priors for each feature
    with numpyro.plate("feature", num_features):
        prob_priors = numpyro.sample("p", dist.Beta(0.5, 0.5))
        vol_stds = numpyro.sample("vol_std", dist.InverseGamma(1, 1.0))
        vol_mus = numpyro.sample("vol_mu", dist.Normal(0, 1))
        locs = obs > 0  # True where the observation is non-zero
        with numpyro.plate("obs", num_obs):
            # Bernoulli likelihood for "is this observation non-zero?"
            stocked = numpyro.sample(
                "stocked",
                dist.Bernoulli(probs=prob_priors),
                obs=locs,
            )
            # LogNormal likelihood, only counted where the observation is non-zero
            with numpyro.handlers.mask(mask=locs):
                volume = numpyro.sample(
                    "volume",
                    dist.LogNormal(vol_mus, vol_stds),
                    obs=obs,
                )
    return volume, stocked


def generate_data(num_obs, probs, vol_mus, vol_stds):
    num_features = len(probs)
    # Which entries are non-zero, drawn independently per feature
    locs = np.random.binomial(1, probs, size=(num_obs, num_features))
    # LogNormal values for every entry, then zeroed where the entry is "not stocked"
    volumes = np.random.lognormal(vol_mus, vol_stds, size=(num_obs, num_features))
    volumes[locs == 0] = 0
    return volumes


obs = generate_data(10000, [0.1, 0.2, 0.3, 0.4, 0.5], [0, 1, 2, 3, 4], [1, 1, 1, 1, 1])
kernel = NUTS(model)
mcmc = MCMC(kernel, num_samples=1000, num_warmup=1000)
mcmc.run(rng_key=PRNGKey(0), obs=obs)
mcmc.print_summary()
```
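Because the Beta(0.5, 0.5) prior on `p` is conjugate to the Bernoulli likelihood, the posterior of each p_j is available in closed form as Beta(0.5 + k_j, 0.5 + n - k_j), where k_j is the number of non-zero entries in column j. A small NumPy sketch (the helper names `beta_posterior` and `beta_mean_sd` are my own) gives a reference to compare against the `p` rows of `mcmc.print_summary()`:

```python
import numpy as np


def beta_posterior(obs, a=0.5, b=0.5):
    """Per-feature Beta posterior parameters for p under a Beta(a, b) prior."""
    n = obs.shape[0]
    k = (obs > 0).sum(axis=0)  # number of non-zero entries per feature
    return a + k, b + (n - k)


def beta_mean_sd(alpha, beta):
    """Mean and standard deviation of a Beta(alpha, beta) distribution."""
    total = alpha + beta
    mean = alpha / total
    sd = np.sqrt(alpha * beta / (total**2 * (total + 1)))
    return mean, sd


# Tiny example: 4 observations, 2 features.
toy_obs = np.array([[0.0, 1.2],
                    [2.5, 0.0],
                    [0.0, 3.1],
                    [0.7, 0.4]])
alpha_post, beta_post = beta_posterior(toy_obs)
mean_post, sd_post = beta_mean_sd(alpha_post, beta_post)
print(alpha_post, beta_post)  # [2.5 3.5] [2.5 1.5]
print(mean_post)              # [0.5 0.7]
```

Applied to the array returned by `generate_data`, the posterior mean and standard deviation from `beta_mean_sd` should match the `p` rows of the MCMC summary up to Monte Carlo error, which makes it easy to tell whether slow sampling is also producing wrong answers.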
I’ve found that this model scales poorly with the number of observations. With this toy data, 1000 observations take around 4 s on my machine, while 5000 observations take around 5 minutes.
With my actual data, the performance is slightly worse, probably due to a poor choice of priors.
I think that the problem lies with the “stocked” observations, because the MCMC sampler runs much faster when I comment out that part of the model.
I’m fairly new to probabilistic programming, so any tips on how I could improve the definition of my model would be greatly appreciated.