This is likely an issue with Predictive; see this post and this associated issue. In particular, when I attempt to use Predictive, I get an error like
/opt/anaconda3/lib/python3.7/site-packages/pyro/infer/predictive.py in _predictive_sequential(model, posterior_samples, model_args, model_kwargs, num_samples, return_site_shapes, return_trace)
46 else:
47 return {site: torch.stack([s[site] for s in collected]).reshape(shape)
---> 48 for site, shape in return_site_shapes.items()}
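which makes sense because the shape is itself a random variable, so the samples collected for a site differ in size from draw to draw and cannot be stacked into one tensor.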
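To make the failure concrete, here is a minimal sketch in plain PyTorch. It is not Pyro's actual code path, only the same stacking step in isolation, and the lengths 3, 5 and 2 are made-up stand-ins for three draws of N.

```python
import torch

torch.manual_seed(0)

# Stand-ins for three collected draws of a site whose length depends on N.
# The lengths (3, 5, 2) are arbitrary illustrative values, not real output.
collected = [torch.randn(n) for n in (3, 5, 2)]

try:
    torch.stack(collected)  # requires every tensor to have the same shape
    stacked_ok = True
except RuntimeError:
    stacked_ok = False

# A scalar summary of each draw has a fixed shape, so it stacks fine.
rand_sums = torch.stack([x.sum() for x in collected])

print(stacked_ok)       # whether the ragged stack succeeded
print(rand_sums.shape)  # shape of the stacked scalar summaries
```

So a fixed-shape quantity such as the sum can still be collected across draws, while the variable-length site itself cannot be stacked.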
If your only goal is simulation, a model like
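import torch
import pyro
import pyro.distributions as dist

def model():
    # N is random, so the size of "noise" changes from call to call
    N = pyro.sample("N", dist.Poisson(10.))
    noise = pyro.sample("noise", dist.Normal(0, 1).expand((int(N),)))
    randsum = pyro.deterministic("rand_sum", noise.sum())
    return randsum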
should be fine. As for sampling from the prior predictive, that amounts to calling model() repeatedly, which we can do with
samples = torch.stack(list(map(lambda x: model(), range(100))))
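then replacing map with a process pool would parallelize this, with two caveats. First, pool.map pickles the function it is given, so a lambda will not work and a module-level helper is needed (called draw here, a name I made up). Second, worker processes may all start from the same RNG state, so each call should be reseeded. Something like
import multiprocessing
def draw(seed):
    pyro.set_rng_seed(seed)  # reseed per task so workers do not repeat each other's draws
    return model()
...
with multiprocessing.Pool(None) as pool:
    samples = torch.stack(pool.map(draw, range(100)))
would be okay. On platforms that start workers with spawn rather than fork, the pool code also needs to sit under an if __name__ == "__main__": guard.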