Samples from prior distribution

Is there an elegant way to extract samples from the prior distributions in a Pyro model? All the examples I've seen show posterior samples obtained with MCMC().get_samples(), and posterior predictive and prior predictive samples obtained with the Predictive class.

If your model has return values, you can call it directly and get those:

import pyro

def model():
    pyro.sample("z", ...)
    x = pyro.sample("x", ...)
    return x

x = model()

Of course, this only gets you the returned site(s). If you want all the sites, you can use pyro.poutine.trace:

# model_args is a tuple of the model's positional arguments; unpack it
model_trace = pyro.poutine.trace(model).get_trace(*model_args)

# inspect the structure of model_trace to pull out what you want, e.g.
for name, site in model_trace.nodes.items():
    if site["type"] == "sample":
        print(name, site["value"])

You can also use Predictive as a convenience utility to draw samples from the prior by passing an empty dict as the posterior_samples argument, which essentially does what @martinjankowiak's snippet above does. An additional advantage is that if all the batch dimensions are correctly annotated with pyro.plate, you can pass parallel=True to draw all samples in a single vectorized pass, which might be faster for more complex models.

import pyro
import pyro.distributions as dist
from pyro.infer import Predictive

def model(x, y=None):
  ...
  pyro.sample('y', dist.Normal(0., 1.), obs=y)


# draw 100 samples from the prior
prior_samples = Predictive(model, {}, num_samples=100)(x)
print(prior_samples)