Subsampling in infer.Predictive

Dear Pyro experts,

Is it possible to use subsampling with infer.Predictive? In the toy example below, Predictive raises an error, whereas tracing the guide and replaying the model against it behaves as I expect.

Thank you! :smile:

```python
import torch, pyro

def model():
    with pyro.plate('foo', 10):
        pyro.sample('bar', pyro.distributions.Bernoulli(torch.tensor(0.0)))
    return

def guide():
    with pyro.plate('foo', 10, subsample_size=5):
        pyro.sample('bar', pyro.distributions.Bernoulli(torch.tensor(1.0)))
    return
```

Tracing the guide and replaying the model against it (trace/replay) does what I expect:

```python
guide_trace = pyro.poutine.trace(guide).get_trace()
replayed_trace = pyro.poutine.trace(pyro.poutine.replay(model, trace=guide_trace)).get_trace()
print(replayed_trace.format_shapes())
print(replayed_trace.nodes['bar']['value'])
```

Output:

```
Trace Shapes:    
 Param Sites:    
Sample Sites:    
     foo dist   |
        value 5 |
     bar dist 5 |
        value 5 |
tensor([1., 1., 1., 1., 1.])
```

infer.Predictive, on the other hand, raises an error (which goes away if I remove the subsampling):

```python
print(pyro.infer.Predictive(model, guide=guide, num_samples=3)())
```

```
RuntimeError: shape '[3, 10]' is invalid for input of size 15
```
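For what it's worth, the numbers in the message line up with the subsample size: 3 draws × 5 subsampled entries = 15 values, whereas a [3, 10] tensor needs 3 × 10 = 30. So it looks as if the guide's subsampled values are being reshaped to the model's full plate size of 10. That is my reading of the message, not something I've confirmed in Pyro's source. This plain-PyTorch snippet reproduces the same error text:

```python
import torch

num_samples, plate_size, subsample_size = 3, 10, 5

# One value per subsampled index per draw, as the guide produces: 3 * 5 = 15 elements.
guide_values = torch.ones(num_samples, subsample_size)

message = None
try:
    # Viewing them at the full plate size would need 3 * 10 = 30 elements.
    guide_values.reshape(num_samples, plate_size)
except RuntimeError as err:
    message = str(err)

print(message)  # shape '[3, 10]' is invalid for input of size 15
```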