I’m trying to use `pyro.infer.Predictive` with a very simple model, similar to the one in the `intro_part_ii` tutorial.
First, I set up the model and perform inference:
In [1]: import pyro
...: import torch
In [2]: def model():
...: a = pyro.sample('a', pyro.distributions.Normal(0, 100))
...: return pyro.sample('b', pyro.distributions.Normal(a, 1))
...:
In [3]: conditioned_b = pyro.poutine.condition(model, data={'b': 50.0})
In [4]: def guide():
...: loc = pyro.param("loc", torch.tensor(0.0))
...: scale = pyro.param("scale", torch.tensor(1.1))
...: pyro.sample("a", pyro.distributions.Normal(loc, scale))
...:
In [5]: from pyro.infer import SVI, Trace_ELBO
...: import pyro.optim as optim
...:
...: pyro.clear_param_store()
...:
...: svi = SVI(conditioned_b,
...: guide,
...: optim.Adam({"lr": .05}),
...: loss=Trace_ELBO())
In [6]: num_steps = 2500
...: for t in range(num_steps):
...: svi.step()
...:
Then I can get samples from the posterior like this:
In [7]: from pyro.infer import Predictive
...:
...: Predictive(model, guide=guide, num_samples=10)()
Out[7]:
{'a': tensor([49.1881, 50.0768, 50.8119, 50.4589, 48.5206, 48.7229, 50.0590, 48.4745,
49.9926, 49.1318], grad_fn=<AsStridedBackward>),
'b': tensor([50.5658, 48.7698, 52.9576, 49.2227, 49.4116, 46.4217, 50.6092, 47.5770,
49.3628, 49.5913], grad_fn=<AsStridedBackward>)}
That seems reasonable.
I would expect the same result (without any samples for `b`) if I pass the conditioned model (the one I actually used for inference) to `Predictive`.
Instead, that raises an error:
In [8]: Predictive(conditioned_b, guide=guide, num_samples=10)()
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-8-de39137ec330> in <module>
----> 1 Predictive(conditioned_b, guide=guide, num_samples=10)()
~/.virtualenvs/pyro/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
539 result = self._slow_forward(*input, **kwargs)
540 else:
--> 541 result = self.forward(*input, **kwargs)
542 for hook in self._forward_hooks.values():
543 hook_result = hook(self, input, result)
~/.virtualenvs/pyro/lib/python3.6/site-packages/pyro/infer/predictive.py in forward(self, *args, **kwargs)
196 parallel=self.parallel, model_args=args, model_kwargs=kwargs)
197 return _predictive(self.model, posterior_samples, self.num_samples, return_sites=return_sites,
--> 198 parallel=self.parallel, model_args=args, model_kwargs=kwargs)
199
200 def get_samples(self, *args, **kwargs):
~/.virtualenvs/pyro/lib/python3.6/site-packages/pyro/infer/predictive.py in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
66 for site in model_trace.stochastic_nodes + model_trace.observation_nodes:
67 append_ndim = max_plate_nesting - len(model_trace.nodes[site]["fn"].batch_shape)
---> 68 site_shape = (num_samples,) + (1,) * append_ndim + model_trace.nodes[site]['value'].shape
69 # non-empty return-sites
70 if return_sites:
AttributeError: 'float' object has no attribute 'shape'
Can anyone explain what I’m missing?
Thanks!