Generating posterior predictive samples

Hey! I really struggle with generating posterior predictives. My actual model is slightly more complex, but I have the same problems with simpler versions, such as the following:

Data
A simple Gaussian:

```python
samples = np.array(Normal(0, 2).sample(PRNGKey(42), (1000, )))
```


Model

```python
import numpyro
from numpyro.distributions import HalfNormal, Normal

def model(y=None):
    mu = numpyro.sample('mu', Normal(0, 5))
    sigma = numpyro.sample('sigma', HalfNormal(5))

    with numpyro.plate('data', len(y)):
        numpyro.sample('obs', Normal(mu, sigma), obs=y)
```


With the model fitted (using NUTS), I would now like to generate synthetic datasets, compute some statistics on them, and use them to calculate some target values.

So basically I would like to generate $n$ models $\mathcal{N}_i(\mu_i, \sigma_i)$, each initialised by drawing $\mu_i$ and $\sigma_i$ from the posterior distributions of $\mu$ and $\sigma$. From each $\mathcal{N}_i$ I would then like to draw $m$ samples, leaving me with $m \times n$ samples in total.
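The two-level scheme described above can be sketched directly with NumPy, assuming `mu_post` and `sigma_post` are 1-D arrays of posterior draws (for example `np.asarray(samples_1['mu'])` and `np.asarray(samples_1['sigma'])`). The helper `sample_predictive` and the stand-in draws in the usage lines are hypothetical, not part of NumPyro. Broadcasting a standard-normal matrix of shape (n, m) against per-row parameters gives row $i$ exactly $m$ samples from $\mathcal{N}(\mu_i, \sigma_i)$.

```python
import numpy as np


def sample_predictive(mu_post, sigma_post, n, m, seed=0):
    """Hypothetical helper: pick n posterior draws of (mu, sigma), m samples from each.

    Returns an (n, m) array whose row i holds m draws from N(mu_i, sigma_i).
    """
    rng = np.random.default_rng(seed)
    mu_post = np.asarray(mu_post)
    sigma_post = np.asarray(sigma_post)
    # Choose n posterior draws (with replacement, so n may exceed the number of draws).
    # The same index is used for mu and sigma, keeping each (mu_i, sigma_i) pair together.
    idx = rng.integers(0, len(mu_post), size=n)
    mu_i = mu_post[idx]
    sigma_i = sigma_post[idx]
    # Standard-normal noise, shifted and scaled row by row via broadcasting.
    z = rng.standard_normal((n, m))
    return mu_i[:, None] + sigma_i[:, None] * z


# Usage with stand-in "posterior" draws (placeholders, not real MCMC output).
rng = np.random.default_rng(1)
mu_post = rng.normal(0.0, 0.1, size=2000)
sigma_post = np.abs(rng.normal(2.0, 0.1, size=2000))
y_rep = sample_predictive(mu_post, sigma_post, n=500, m=1000)
print(y_rep.shape)  # (500, 1000)
```

Sampling posterior indices with replacement is just one choice; if $n$ equals the number of MCMC draws, each draw can instead be used exactly once.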

Model call
I think this is fairly standard:

```python
import jax
from jax.random import PRNGKey
from numpyro.infer import MCMC, NUTS

rng_key = PRNGKey(0)
rng_key, rng_key_ = jax.random.split(rng_key)

# Run NUTS.
kernel = NUTS(model)
num_samples = 2000
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
mcmc.run(rng_key_, y=samples)
mcmc.print_summary()
samples_1 = mcmc.get_samples()
```


I then tried to generate some posterior predictives using:

```python
from numpyro.infer import Predictive

predictive = Predictive(model, samples_1)
predictions = predictive(rng_key_)
```


This did not work, because `len(y)` raises a `TypeError` when `y` is `None`. I then deleted the `numpyro.plate` statement (I read that using it is good practice because it enables optimisations in the background; other than that, I don't have a reason to use it). This way I was able to generate at most 2000 samples (`num_samples`). I also do not understand how these samples are generated: do they use a single draw of the parameters, and therefore one model, or 2000?

Thank you very much!

P.S.: I read other questions and tried to work through the Divorce rate example, but I could not transfer it to my problem.

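You can change this to `plate('data', len(y) if y is not None else 1000)`. When `y` is `None`, as in your `Predictive` call, the plate then has 1000 entries, so each of the 2000 posterior draws of `mu` and `sigma` generates its own set of 1000 samples.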
