You can also use `Predictive` as a convenience utility to draw samples from the prior: pass an empty dict as the `posterior_samples` argument, which essentially does what @martinjankowiak’s snippet above is doing. An additional advantage is that if all the batch dimensions are annotated correctly with `pyro.plate`, you can pass `parallel=True` to draw all the samples in a single vectorized run of the model, which might be faster for more complex models.
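```python
import torch
import pyro.distributions as dist
import pyro
from pyro.infer import Predictive

def model(x, y=None):
    ...  # rest of the model (elided)
    pyro.sample('y', dist.Normal(0., 1.), obs=y)

x = torch.randn(10)  # placeholder input
# draw 100 samples from the prior (empty dict = no posterior samples to condition on)
prior_samples = Predictive(model, posterior_samples={}, num_samples=100)(x)
print(prior_samples)
```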