Hi all,
I am trying to create a model for a dataset where whether an observation is greater than zero follows a Bernoulli distribution. For the non-zero entries, I then model the observed value with a LogNormal distribution. My dataset has shape (n_obs, n_features), and I can treat each feature as independent.
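Written out, this is what is sometimes called a hurdle model. For observation $i$ and feature $j$, with $N$ observations in total and $k_j$ non-zero entries in column $j$, the likelihood described above is:

$$
p(y_{ij} \mid p_j, \mu_j, \sigma_j) =
\begin{cases}
1 - p_j, & y_{ij} = 0, \\
p_j \, \mathrm{LogNormal}(y_{ij} \mid \mu_j, \sigma_j), & y_{ij} > 0,
\end{cases}
$$

$$
\log p(Y) = \sum_{j} \Big[ k_j \log p_j + (N - k_j) \log(1 - p_j) + \sum_{i:\, y_{ij} > 0} \log \mathrm{LogNormal}(y_{ij} \mid \mu_j, \sigma_j) \Big],
$$

so the zero/non-zero part depends on the data only through the counts $k_j$.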
Here is the model I have defined:
```python
import numpyro
import numpyro.distributions as dist
from jax import numpy as jnp
import numpy as np
from jax.random import PRNGKey
from numpyro.infer import MCMC, NUTS


def model(obs):
    num_obs, num_features = obs.shape
    # One set of parameters per feature (plate at dim -1)
    with numpyro.plate("feature", num_features):
        prob_priors = numpyro.sample("p", dist.Beta(0.5, 0.5))
        vol_stds = numpyro.sample("vol_std", dist.InverseGamma(1, 1.0))
        vol_mus = numpyro.sample("vol_mu", dist.Normal(0, 1))
        locs = obs > 0
        # Observations (plate at dim -2), giving shape (num_obs, num_features)
        with numpyro.plate("obs", num_obs):
            # Zero vs. non-zero: one Bernoulli per entry
            stocked = numpyro.sample(
                "stocked", dist.Bernoulli(probs=prob_priors), obs=locs,
            )
            # LogNormal likelihood, counted only for the non-zero entries
            with numpyro.handlers.mask(mask=locs):
                volume = numpyro.sample(
                    "volume", dist.LogNormal(vol_mus, vol_stds), obs=obs,
                )
    return volume, stocked


def generate_data(num_obs, probs, vol_mus, vol_stds):
    num_features = len(probs)
    locs = np.random.binomial(1, probs, size=(num_obs, num_features))
    volumes = np.random.lognormal(vol_mus, vol_stds, size=(num_obs, num_features))
    volumes[locs == 0] = 0
    return volumes


obs = generate_data(10000, [0.1, 0.2, 0.3, 0.4, 0.5], [0, 1, 2, 3, 4], [1, 1, 1, 1, 1])
kernel = NUTS(model)
mcmc = MCMC(kernel, num_samples=1000, num_warmup=1000)
mcmc.run(rng_key=PRNGKey(0), obs=obs)
mcmc.print_summary()
```
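For reference, here is a minimal NumPy sketch of the log-likelihood the model above is meant to compute (an illustration, not NumPyro's internal implementation). One detail it makes explicit: the `volume` site is observed at every entry, including the zeros, where the LogNormal log-density is `-inf`. The mask removes those values from the total, but with automatic differentiation the gradient of a masked-out `-inf` branch can still come out as NaN (the usual `jnp.where` pitfall). Whether NumPyro already substitutes a valid value for masked entries may depend on the version, so treat this as a hypothesis to rule out rather than a diagnosis. The sketch substitutes a dummy value (1.0, an arbitrary choice) before taking the log; in the model, the analogous change would be passing `obs=jnp.where(locs, obs, 1.0)` to the `volume` site. The function name `hurdle_lognormal_loglik` is hypothetical.

```python
import numpy as np


def hurdle_lognormal_loglik(obs, p, mu, sigma):
    """Total log-likelihood of a zero/LogNormal hurdle model.

    obs: (num_obs, num_features) array of non-negative values.
    p, mu, sigma: per-feature parameters, each of shape (num_features,).
    """
    obs = np.asarray(obs, dtype=float)
    p, mu, sigma = (np.asarray(a, dtype=float) for a in (p, mu, sigma))
    nonzero = obs > 0
    # Bernoulli part: log p for non-zero entries, log(1 - p) for zeros.
    bern = np.where(nonzero, np.log(p), np.log1p(-p))
    # Replace zeros by 1.0 *before* the log so log(0) = -inf is never computed;
    # the LogNormal values at these entries are discarded below anyway.
    log_y = np.log(np.where(nonzero, obs, 1.0))
    lognormal = (
        -0.5 * ((log_y - mu) / sigma) ** 2
        - np.log(sigma)
        - 0.5 * np.log(2 * np.pi)
        - log_y  # Jacobian term of the log transform
    )
    # Keep the LogNormal term only where the observation is non-zero.
    return float(np.sum(bern) + np.sum(np.where(nonzero, lognormal, 0.0)))


# Usage on a small simulated dataset with two features.
rng = np.random.default_rng(0)
toy = np.where(
    rng.random((1000, 2)) < [0.1, 0.5],
    rng.lognormal([0.0, 4.0], 1.0, size=(1000, 2)),
    0.0,
)
print(hurdle_lognormal_loglik(toy, [0.1, 0.5], [0.0, 4.0], [1.0, 1.0]))
```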
I’ve found that this model scales poorly with the number of observations. With this toy data, 1000 observations take around 4 s on my machine, while 5000 observations take around 5 minutes.
With my actual data, the performance is slightly worse, probably due to a poor choice of priors.
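One way to check what the priors imply is a quick prior predictive draw. The sketch below samples the three priors from the model in plain NumPy; it assumes NumPyro's `InverseGamma(concentration, rate)` parameterisation, under which `InverseGamma(1, 1)` is the reciprocal of a `Gamma(shape=1, scale=1)` draw. The function name `sample_priors` is hypothetical.

```python
import numpy as np


def sample_priors(num_draws, seed=0):
    """Draw from the per-feature priors used in the model above."""
    rng = np.random.default_rng(seed)
    p = rng.beta(0.5, 0.5, size=num_draws)  # Beta(0.5, 0.5)
    # InverseGamma(concentration=1, rate=1) = 1 / Gamma(shape=1, scale=1)
    vol_std = 1.0 / rng.gamma(shape=1.0, scale=1.0, size=num_draws)
    vol_mu = rng.normal(0.0, 1.0, size=num_draws)  # Normal(0, 1)
    return p, vol_std, vol_mu


p, vol_std, vol_mu = sample_priors(10_000)
print("share of p outside [0.05, 0.95]:", np.mean((p < 0.05) | (p > 0.95)))
print("vol_std quantiles (50%, 90%, 99%):", np.quantile(vol_std, [0.5, 0.9, 0.99]))
print("share of |vol_mu| > 3:", np.mean(np.abs(vol_mu) > 3))
```

For context: Beta(0.5, 0.5) is the arcsine distribution and puts about 29% of its mass below 0.05 or above 0.95; InverseGamma(1, 1) has no finite mean, and about 10% of its draws exceed 10, which is a very wide spread on the log scale. The toy data's largest `vol_mu` (4) also sits four prior standard deviations from zero. Whether any of this matters for the real data is exactly what a check like this can help decide.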
I think that the problem lies with the “stocked” observations, because the MCMC sampler runs much faster when I comment out that part of the model.
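A possibly relevant observation about the `stocked` site: it adds num_obs × num_features Bernoulli terms, but because each feature shares a single `p` and entries are conditionally independent, their sum depends only on the per-column count of non-zero entries. It equals the Binomial(N, p) log-pmf of that count minus log C(N, k), a constant that does not involve `p`, so the posterior over `p` is unchanged. The NumPy sketch below checks this numerically (function names are hypothetical). In the model, the corresponding change would be a single site per feature inside the `feature` plate (outside the `obs` plate), e.g. `numpyro.sample("stocked", dist.Binomial(total_count=num_obs, probs=prob_priors), obs=locs.sum(axis=0))`; this is untested here, and it reduces the cost per gradient evaluation rather than being a confirmed fix for the scaling.

```python
import math

import numpy as np


def bernoulli_sum_loglik(locs, p):
    """Per-column sum of Bernoulli log-probs, with one shared p per column."""
    locs = np.asarray(locs, dtype=bool)
    p = np.asarray(p, dtype=float)
    return np.sum(np.where(locs, np.log(p), np.log1p(-p)), axis=0)


def binomial_count_loglik(locs, p):
    """Per-column Binomial(N, p) log-pmf of the count of non-zero entries."""
    locs = np.asarray(locs, dtype=bool)
    p = np.asarray(p, dtype=float)
    n = locs.shape[0]
    k = locs.sum(axis=0)
    # log of the binomial coefficient C(n, k) for each column
    log_choose = np.array(
        [math.lgamma(n + 1) - math.lgamma(ki + 1) - math.lgamma(n - ki + 1) for ki in k]
    )
    return log_choose + k * np.log(p) + (n - k) * np.log1p(-p)


rng = np.random.default_rng(0)
true_p = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
locs = rng.random((1000, true_p.size)) < true_p

# The difference is log C(N, k_j) for every p, so both forms
# lead to the same posterior over p.
for p in (true_p, np.full(5, 0.3)):
    print(binomial_count_loglik(locs, p) - bernoulli_sum_loglik(locs, p))
```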
I’m fairly new to probabilistic programming, so any tips on how I could improve the definition of my model would be greatly appreciated.