Error in `Predictive` with conditioned model

I’m trying to use pyro.infer.Predictive with a very simple model similar to the one in the intro_part_ii tutorial.

Set up the model and perform inference:

In [1]: import pyro
   ...: import torch

In [2]: def model():
   ...:     a = pyro.sample('a', pyro.distributions.Normal(0, 100))
   ...:     return pyro.sample('b', pyro.distributions.Normal(a, 1))
   ...:

In [3]: conditioned_b = pyro.poutine.condition(model, data={'b': 50.0})

In [4]: def guide():
   ...:     loc = pyro.param("loc", torch.tensor(0.0))
   ...:     scale = pyro.param("scale", torch.tensor(1.1))
   ...:     pyro.sample("a", pyro.distributions.Normal(loc, scale))
   ...:

In [5]: from pyro.infer import SVI, Trace_ELBO
   ...: import pyro.optim as optim
   ...:
   ...: pyro.clear_param_store()
   ...:
   ...: svi = SVI(conditioned_b,
   ...:           guide,
   ...:           optim.Adam({"lr": .05}),
   ...:           loss=Trace_ELBO())

In [6]: num_steps = 2500
   ...: for t in range(num_steps):
   ...:     svi.step()
   ...:

Then I can get samples from the posterior like this:

In [7]: from pyro.infer import Predictive
   ...:
   ...: Predictive(model, guide=guide, num_samples=10)()
Out[7]:
{'a': tensor([49.1881, 50.0768, 50.8119, 50.4589, 48.5206, 48.7229, 50.0590, 48.4745,
         49.9926, 49.1318], grad_fn=<AsStridedBackward>),
 'b': tensor([50.5658, 48.7698, 52.9576, 49.2227, 49.4116, 46.4217, 50.6092, 47.5770,
         49.3628, 49.5913], grad_fn=<AsStridedBackward>)}

That seems reasonable.

I would expect the same result, minus the samples for b, when I use the conditioned model, which is the one I actually used for inference.

Instead, I get an error:

In [8]: Predictive(conditioned_b, guide=guide, num_samples=10)()
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-8-de39137ec330> in <module>
----> 1 Predictive(conditioned_b, guide=guide, num_samples=10)()

~/.virtualenvs/pyro/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    539             result = self._slow_forward(*input, **kwargs)
    540         else:
--> 541             result = self.forward(*input, **kwargs)
    542         for hook in self._forward_hooks.values():
    543             hook_result = hook(self, input, result)

~/.virtualenvs/pyro/lib/python3.6/site-packages/pyro/infer/predictive.py in forward(self, *args, **kwargs)
    196                                             parallel=self.parallel, model_args=args, model_kwargs=kwargs)
    197         return _predictive(self.model, posterior_samples, self.num_samples, return_sites=return_sites,
--> 198                            parallel=self.parallel, model_args=args, model_kwargs=kwargs)
    199
    200     def get_samples(self, *args, **kwargs):

~/.virtualenvs/pyro/lib/python3.6/site-packages/pyro/infer/predictive.py in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
     66     for site in model_trace.stochastic_nodes + model_trace.observation_nodes:
     67         append_ndim = max_plate_nesting - len(model_trace.nodes[site]["fn"].batch_shape)
---> 68         site_shape = (num_samples,) + (1,) * append_ndim + model_trace.nodes[site]['value'].shape
     69         # non-empty return-sites
     70         if return_sites:

AttributeError: 'float' object has no attribute 'shape'

Can anyone explain what I’m missing?

Thanks!

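Hi @DavidC, I think you can fix the issue by passing a PyTorch tensor, torch.tensor(50.0), as the observed value instead of the Python scalar 50.0. The traceback shows that Predictive reads `.shape` from the value of every sample and observation site, and a plain Python float has no `.shape` attribute.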

Oops. Thanks!