Samples from prior distribution

You can also use Predictive as a convenience utility to draw samples from the prior by passing an empty dict to the posterior_samples argument; this does essentially what @martinjankowiak’s snippet above does. An additional advantage is that, if all batch dimensions are correctly annotated with pyro.plate, you can pass parallel=True to draw all samples in a single vectorized run of the model, which may be faster for more complex models.

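import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Predictive

def model(x, y=None):
  ...  # rest of the model (priors, likelihood, etc.) elided
  pyro.sample('y', dist.Normal(0., 1.), obs=y)


x = torch.randn(10)  # hypothetical input data, for illustration only

# draw 100 samples from the prior; the empty dict means no site is conditioned
prior_samples = Predictive(model, {}, num_samples=100)(x)
print(prior_samples)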