How to speed up Predictive


I am using Predictive to predict Y for a given set of parameters.

predictive = numpyro.infer.Predictive(model, samples, parallel=True)
pred = predictive(rng_key, X=X, D_Y=D_Y, Y=None, D_H=D_H, prior_std=prior_std)

Here samples is a single set of sampled parameters and X contains many data points. I need to make sequential predictions, so I loop over timesteps with the same set of parameters.

Prediction is surprisingly slow, and it only uses a single core. Do I need to use vmap to split the data up to get parallelism for a single set of parameters?

Maybe it would be worth exporting the parameters to a model clone in jax/pytorch for speedup?

exporting the parameters to a model clone in jax/pytorch for speedup

Sounds reasonable to me if your model uses distributions whose samplers are slow (such as the Gamma distribution). If you are using a GPU, setting parallel=True will help. If you are using a CPU, parallel=False will be faster. If you want to distribute the computation across your cores, you can set batch_ndims=0 and use pmap.

def get_pred(sample, rng_key):
    # batch_ndims=0: `sample` is one draw with no leading sample dimension
    predictive = numpyro.infer.Predictive(model, sample, batch_ndims=0)
    return predictive(rng_key, ...)

# The leading axis of samples and rng_keys is split across devices.
pred = jax.pmap(get_pred)(samples, rng_keys)
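
To make the three options in this reply concrete, here is a rough pure-JAX sketch. As far as I understand, parallel=True corresponds to a vectorized jax.vmap over the draws, parallel=False to a sequential map, and the pmap route places one draw on each device. The function predict_one and the toy weights ws are hypothetical stand-ins for a Predictive call and the posterior draws, not part of NumPyro.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for one Predictive call on a single parameter draw.
def predict_one(w, rng_key):
    noise = 0.1 * jax.random.normal(rng_key, (4,))
    return w * jnp.arange(4.0) + noise

# One draw per available device: pmap requires the leading axis
# to be no larger than the number of local devices.
n_dev = jax.local_device_count()
ws = jnp.arange(1.0, n_dev + 1.0)                          # shape (n_dev,)
rng_keys = jax.random.split(jax.random.PRNGKey(0), n_dev)  # one key per draw

# Vectorized over draws (roughly what parallel=True does).
vectorized = jax.vmap(predict_one)(ws, rng_keys)

# Sequential over draws (roughly what parallel=False does).
sequential = jax.lax.map(lambda args: predict_one(*args), (ws, rng_keys))

# One draw per device.
distributed = jax.pmap(predict_one)(ws, rng_keys)
```

All three calls return an array of shape (n_dev, 4). Note that on CPU JAX exposes a single device by default, so pmap only spreads work across cores if the host device count was raised before JAX initialises, for example with numpyro.set_host_device_count.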