I am using NumPyro's Predictive to predict Y for a given set of parameters:
    predictive = numpyro.infer.Predictive(model, samples, parallel=True)
    pred = predictive(rng_key, X=X, D_Y=D_Y, Y=None, D_H=D_H, prior_std=prior_std)
Here, samples holds a single set of sampled parameters, while X has many rows. I need to make sequential predictions, so I loop over timesteps and reuse the same set of parameters at every step.
Prediction is surprisingly slow, and it only uses a single CPU core. Do I need to use vmap to split the data up in order to get parallelism with a single set of parameters?
Alternatively, would it be worth exporting the parameters to a clone of the model written in plain JAX or PyTorch for a speedup?
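If the parameters are exported to a plain-JAX clone, one possible sketch is a hand-written forward pass that is jitted together with the whole timestep loop, using jax.lax.scan so the rollout compiles once. Everything here is an assumption for illustration: the architecture (a tanh MLP with weights w1 and w2), the parameter names, and the feedback rule that uses each prediction as the next input are hypothetical, and the clone must reproduce the real model's mean computation exactly. It returns mean predictions only; observation noise would have to be added separately (for example with jax.random.normal scaled by the noise level).

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical deterministic clone of the model's mean function.
def forward(params, X):
    return jnp.tanh(X @ params["w1"]) @ params["w2"]

@partial(jax.jit, static_argnames="num_steps")
def rollout(params, x0, num_steps):
    def step(x, _):
        y = forward(params, x)
        # Hypothetical feedback rule: the prediction becomes the next input.
        # Replace with however the next timestep's X is actually built.
        return y, y
    _, ys = jax.lax.scan(step, x0, None, length=num_steps)
    return ys  # shape (num_steps, rows, output_dim)

D, D_H, N = 3, 8, 1000
key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
# A single parameter set, without a leading sample axis.
params = {
    "w1": jax.random.normal(key1, (D, D_H)),
    "w2": jax.random.normal(key2, (D_H, D)),
}
x0 = jax.random.normal(key3, (N, D))
ys = rollout(params, x0, num_steps=20)
```

Staying in JAX keeps the same numerics and avoids a port; the speedup in this sketch comes from compiling the entire loop once rather than from the choice of framework.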