Poor scaling on number of observations with sparse data

Hi all,
I am trying to build a model for a dataset in which whether an observation is greater than zero follows a Bernoulli distribution. For the entries that are non-zero, I then model the observed value with a LogNormal distribution. My dataset has shape (n_obs, n_features), and I can treat each feature as independent.

Here is the model I have defined:

import numpyro
import numpyro.distributions as dist
from jax import numpy as jnp
import numpy as np
from jax.random import PRNGKey
from numpyro.infer import MCMC, NUTS


def model(obs):
    num_obs, num_features = obs.shape
    # Indicator of which entries are non-zero ("stocked").
    locs = obs > 0
    with numpyro.plate("feature", num_features):
        # Per-feature probability that an observation is non-zero.
        prob_priors = numpyro.sample(
            "p", dist.Beta(0.5, 0.5)
        )
        # Per-feature scale and location of log(volume).
        vol_stds = numpyro.sample(
            "vol_std", dist.InverseGamma(1.0, 1.0)
        )
        vol_mus = numpyro.sample(
            "vol_mu", dist.Normal(0, 1)
        )
        with numpyro.plate("obs", num_obs):
            stocked = numpyro.sample(
                "stocked", dist.Bernoulli(probs=prob_priors), obs=locs,
            )
            # Only non-zero entries contribute to the LogNormal likelihood.
            with numpyro.handlers.mask(mask=locs):
                volume = numpyro.sample(
                    "volume", dist.LogNormal(vol_mus, vol_stds), obs=obs,
                )
    return volume, stocked

def generate_data(num_obs, probs, vol_mus, vol_stds):
    num_features = len(probs)
    # Draw the zero/non-zero indicator for every entry, one probability per feature.
    locs = np.random.binomial(1, probs, size=(num_obs, num_features))
    volumes = np.random.lognormal(vol_mus, vol_stds, size=(num_obs, num_features))
    # Zero out the entries that are not "stocked".
    volumes[locs == 0] = 0
    return volumes


obs = generate_data(10000, [0.1, 0.2, 0.3, 0.4, 0.5], [0, 1, 2, 3, 4], [1, 1, 1, 1, 1])
kernel = NUTS(model)
mcmc = MCMC(kernel, num_samples=1000, num_warmup=1000)
mcmc.run(rng_key=PRNGKey(0), obs=obs)

mcmc.print_summary()

I’ve found that this model tends to scale poorly with the number of observations. With this toy data, 1000 observations takes around 4s on my machine and 5000 observations takes around 5 minutes.
With my actual data, the performance is slightly worse, probably due to a poor choice of priors.

I think that the problem lies with the “stocked” observations, because the MCMC sampler runs much faster when I comment out that part of the model.

I’m fairly new to probabilistic programming, but any tips on how I could improve the definition of my model would be greatly appreciated.

In terms of runtime performance, if your data is very sparse you might benefit from expressing your observe statements as a hand-coded factor statement, i.e. indexing instead of masking.

It’s also possible that dist.Bernoulli would be more stable if specified in terms of logits instead of probs.

It’s also generally a good idea to try 64-bit precision.

Thanks for the advice. I was able to speed up the runtime by 2x by using logits instead of the probability.

By modelling the number of non-zero observations per feature with a binomial distribution, I’ve also seen further improvements in convergence. I think I will stick with this, as I would like to extend my model to consider correlations between features, which seems more straightforward using the counts across all observations.

Here is what I have so far:

def model(obs: np.ndarray):
    num_obs, num_features = obs.shape

    with numpyro.plate("features", num_features):
        # "p" now holds the per-feature logit of the non-zero probability.
        prob_priors = numpyro.sample(
            "p", dist.Normal(0, 1)
        )
        vol_stds = numpyro.sample(
            "vol_std", dist.InverseGamma(1.0, 1.0)
        )
        vol_mus = numpyro.sample(
            "vol_mu", dist.Normal(0, 1)
        )
        if obs is not None:
            locs = obs > 0
            # Number of non-zero observations per feature.
            counts = jnp.sum(locs, axis=0)
        else:
            locs = None
            counts = None
        stocked = numpyro.sample("stocked", dist.Binomial(num_obs, logits=prob_priors), obs=counts)
        with numpyro.plate("obs", num_obs):
            # Only non-zero entries contribute to the LogNormal likelihood.
            with numpyro.handlers.mask(mask=locs):
                volume = numpyro.sample(
                    "volume", dist.LogNormal(vol_mus, vol_stds), obs=obs,
                )