Generating posterior predictive samples

Hey! I am really struggling with generating posterior predictive samples. My actual model is slightly more complex, but I run into the same problem with simpler versions, for example the following:

Data
simple gaussian

samples = np.array(Normal(0, 2).sample(PRNGKey(42), (1000, )))

Model

def model(y=None):
    mu = numpyro.sample('mu', Normal(0, 5))
    sigma = numpyro.sample('sigma', HalfNormal(5))
    
    with numpyro.plate('data', len(y)):
        numpyro.sample('obs', Normal(mu, sigma), obs=y)

With my fitted model (using NUTS), I would now like to generate synthetic datasets, compute some statistics on them, and use them to calculate some target values.

So basically I would like to generate n models \mathcal{N}_i(\mu_i, \sigma_i), each initialised by drawing \mu_i and \sigma_i from the posterior distributions of \mu and \sigma. From each \mathcal{N}_i I would then like to generate m samples, leaving me with m x n samples in total.
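
To make the intended scheme concrete, here is a minimal sketch of the n x m sampling in plain JAX. The mu and sigma arrays below are random stand-ins, not real posterior draws, and the sizes n and m are arbitrary values chosen for illustration.

```python
import jax.numpy as jnp
from jax import random

n, m = 5, 3  # illustrative sizes: n parameter draws, m samples per draw
key_mu, key_sigma, key_obs = random.split(random.PRNGKey(0), 3)

# Stand-ins for posterior draws of mu and sigma (assumption: in practice
# these would come from the fitted model, e.g. mcmc.get_samples()).
mu = 0.1 * random.normal(key_mu, (n,))
sigma = 2.0 + 0.1 * jnp.abs(random.normal(key_sigma, (n,)))

# Row i holds m samples from Normal(mu[i], sigma[i]); broadcasting the
# (n, 1) parameters against the (n, m) standard-normal noise does the work.
eps = random.normal(key_obs, (n, m))
datasets = mu[:, None] + sigma[:, None] * eps
print(datasets.shape)
```

Each row of datasets is one synthetic dataset, so per-dataset statistics can be computed along axis 1.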

Model call
I think very standard:

rng_key = PRNGKey(0)
rng_key, rng_key_ = jax.random.split(rng_key)

# Run NUTS.
kernel = NUTS(model)
num_samples = 2000
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
mcmc.run(
    rng_key_, y=samples
)
mcmc.print_summary()
samples_1 = mcmc.get_samples()

I then tried to generate some posterior predictives using:

predictive = Predictive(model, samples_1)
predictions = predictive(rng_key_)

This did not work, because len(y) raises an error when y is None. I then deleted the numpyro.plate statement (I had read that using it is good practice for optimisation purposes in the background; other than that, I have no particular reason for using it). This way I was able to generate a maximum of 2000 samples (num_samples). I also do not know how these samples are generated: from 1 draw of the parameters, and therefore 1 model, or from 2000?

Thank you very much!

P.S.: I read other questions and tried to work through the Divorce rate example, but I could not transfer them to my problem.

You can change the plate statement in your model to numpyro.plate('data', len(y) if y is not None else 1000), so that the plate still has a size when y is None.

Hey @fehiepsi!
Thank you very much. That helped me a lot. In my actual model it still did not work, because I had this at the beginning:

def model(A=None, B=None, C=None):
    obs = jnp.column_stack([A, B, C])
    if A == None:
        num_samples = 1000

This leads to obs being [nan, nan, nan] and therefore not being None.
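
One way around this, sketched under the assumption that all three inputs are either given together or omitted together, is to test for None before stacking. The helper name stack_observations is hypothetical and only serves to illustrate the check.

```python
import jax.numpy as jnp


def stack_observations(A=None, B=None, C=None):
    """Return an (N, 3) array of observations, or None if any input is missing.

    Hypothetical helper: checking for None *before* stacking means the
    result stays None in prediction mode and can be passed to obs= as is.
    """
    if A is None or B is None or C is None:
        return None
    return jnp.column_stack([A, B, C])


print(stack_observations())  # None
print(stack_observations(jnp.ones(4), jnp.zeros(4), jnp.ones(4)).shape)  # (4, 3)
```

Using `is None` rather than `== None` is the idiomatic identity check, and it gives the same answer whether or not the argument is an array.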

I have it working now, though.
Thanks again!