Dear Pyro experts,
Is it possible to use subsampling with infer.Predictive? In the toy example below I get an error, whereas trace/replay behaves as I expect.
Thank you!
import torch, pyro

def model():
    with pyro.plate('foo', 10):
        pyro.sample('bar', pyro.distributions.Bernoulli(torch.tensor(0.0)))
    return

def guide():
    with pyro.plate('foo', 10, subsample_size=5):
        pyro.sample('bar', pyro.distributions.Bernoulli(torch.tensor(1.0)))
    return
Trace / replay does what I expect:
guide_trace = pyro.poutine.trace(guide).get_trace()
replayed_trace = pyro.poutine.trace(pyro.poutine.replay(model, trace=guide_trace)).get_trace()
print(replayed_trace.format_shapes())
print(replayed_trace.nodes['bar']['value'])
 Trace Shapes:
  Param Sites:
 Sample Sites:
      foo dist   |
         value 5 |
      bar dist 5 |
         value 5 |
tensor([1., 1., 1., 1., 1.])
Whereas infer.Predictive gives an error (which goes away if there is no subsampling):
print(pyro.infer.Predictive(model, guide=guide, num_samples=3)())
RuntimeError: shape '[3, 10]' is invalid for input of size 15
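The numbers in the error are consistent with the 3 × 5 = 15 subsampled guide values being reshaped to the model's full plate shape [3, 10], which would need 30 elements. Assuming that is what Predictive does internally (I have not checked its source), the same message can be reproduced with plain torch; `guide_bar` is only a hypothetical stand-in for the guide samples.

```python
import torch

num_samples, plate_size, subsample_size = 3, 10, 5
# Hypothetical stand-in for the guide output: 3 draws of the 5 subsampled "bar" values.
guide_bar = torch.ones(num_samples, subsample_size)
message = ""
try:
    # The full plate needs 3 * 10 = 30 elements, but only 15 are available.
    guide_bar.reshape(num_samples, plate_size)
except RuntimeError as err:
    message = str(err)
print(message)
```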