How can I do forward (ancestral) sampling from the unconditional joint ("prior")?

I’ve implemented the simple 1D Gaussian example from the Intro to Pyro tutorial as follows:

import numpyro
import numpyro.distributions as dist

def model(hparams, y=None):
    prior_mean, prior_sd, obs_sd = hparams
    theta = numpyro.sample("theta", dist.Normal(prior_mean, prior_sd))
    return numpyro.sample("y", dist.Normal(theta, obs_sd), obs=y)
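For reference, this model encodes the standard Normal-Normal generative process below, writing μ = prior_mean, τ = prior_sd and σ = obs_sd (dist.Normal takes a standard deviation, not a variance). Ancestral sampling means drawing θ from the first factor of the joint and then y from the second. The marginal of y and the conjugate posterior of θ are textbook results, not something stated in the thread:

$$
\theta \sim \mathcal{N}(\mu, \tau^2), \qquad y \mid \theta \sim \mathcal{N}(\theta, \sigma^2), \qquad p(\theta, y) = \mathcal{N}(\theta; \mu, \tau^2)\,\mathcal{N}(y; \theta, \sigma^2)
$$

$$
y \sim \mathcal{N}(\mu, \tau^2 + \sigma^2), \qquad \theta \mid y \sim \mathcal{N}\!\left(s^2\left(\frac{\mu}{\tau^2} + \frac{y}{\sigma^2}\right),\; s^2\right), \qquad s^2 = \left(\frac{1}{\tau^2} + \frac{1}{\sigma^2}\right)^{-1}
$$

With the hyperparameters used in this thread (μ = 8.5, τ = 1, σ = 0.75), unconditional draws of y should therefore have mean 8.5 and standard deviation 1.25, and the posterior standard deviation of θ is 0.6 regardless of the observed y.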

I can do posterior inference over theta using MCMC as follows:

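from jax import random
from numpyro.infer import MCMC, NUTS

mu, tau, sigma = 8.5, 1.0, 0.75  # prior mean, prior sd, observation sd
hparams = (mu, tau, sigma)
rng_key_ = random.PRNGKey(0)
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=1000)
mcmc.run(rng_key_, hparams, y)  # y: the observed data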

Now I want to generate new samples from the unconditional joint. I tried this:

data = model(hparams)
print(data)

but I get this error:

What does this mean?

I also tried model(rng_key, hparams), model(hparams, None), and model(rng_key, hparams, None), but none of these work.

Hello @murphyk,
you’re missing an RNG key. Please see the FAQ.

I tried data = model(rng_key_, hparams), but I get the error "not enough values to unpack (expected 3, got 2)".

So then I tried data = model(rng_key_, hparams, None), but I get the error "model() takes from 1 to 2 positional arguments but 3 were given".
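Both errors come from the model's own signature rather than from NumPyro: model accepts only (hparams, y=None), so model(rng_key_, hparams) binds the key to hparams, and model(rng_key_, hparams, None) passes one positional argument too many. With JAX's default key type, random.PRNGKey(0) is a length-2 uint32 array, so unpacking it into three hyperparameters fails. A small sketch reproducing the first error without calling the model:

```python
from jax import random

# What `model(rng_key_, hparams)` ends up binding to the `hparams` parameter.
hparams_by_mistake = random.PRNGKey(0)
print(hparams_by_mistake.shape)

try:
    # The first line of `model` tries to unpack three values.
    prior_mean, prior_sd, obs_sd = hparams_by_mistake
    error_message = None
except ValueError as err:
    error_message = str(err)

print(error_message)
```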

What I would like to do is define the model without specifying the observations (so model(hparams), not model(hparams, obs)) and then condition it afterwards, as pyro.condition does in Pyro, but I don’t see how to do that in NumPyro. (I assume obs=None is equivalent to not conditioning?)

Please see this code snippet:

import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC


def model(hparams, y=None):
    prior_mean, prior_sd, obs_sd = hparams
    theta = numpyro.sample("theta", dist.Normal(prior_mean, prior_sd))
    return numpyro.sample("y", dist.Normal(theta, obs_sd), obs=y)

mu, tau, sigma = 8.5, 1.0, 0.75
hparams = (mu, tau, sigma)

# Ancestral sampling: the seed handler supplies the RNG key that each
# numpyro.sample statement needs; with y=None, "y" is sampled, not observed.
with numpyro.handlers.seed(rng_seed=0):
    y = model(hparams)

# Posterior inference: pass the observation through the `y` argument.
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=1000)
mcmc.run(random.PRNGKey(0), hparams, y)

# Alternatively, condition the model on y with a handler,
# which has the same effect as passing obs=y.
conditioned_model = numpyro.handlers.condition(model, {'y': y})
nuts_kernel = NUTS(conditioned_model)
mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=1000)
mcmc.run(random.PRNGKey(0), hparams)

Thanks, that all works! :)

How can I best draw multiple ancestral samples?
Currently I do this:

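nsamples = 5
data_list = []
with numpyro.handlers.seed(rng_seed=0):
    for _ in range(nsamples):
        # The seed handler splits its key at every sample statement,
        # so each call produces a different draw.
        out = model(hparams)
        data_list.append(out)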

The simplest way is to wrap the model call in a plate:

num_samples = 10

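# Every sample site inside the plate (both theta and y) gets a batch
# dimension of size num_samples, so y has shape (num_samples,).
with numpyro.handlers.seed(rng_seed=0):
    with numpyro.plate("multiple_samples", num_samples):
        y = model(hparams)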