It seems the samples dimension and observation dimension are mixed up. Pyro version 1.1.0 and Arviz 0.6.1, the code below is from the pyro examples.
import arviz as az
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import MCMC, NUTS
J = 8
y = torch.tensor([28, 8, -3, 7, -1, 1, 18, 12]).type(torch.Tensor)
sigma = torch.tensor([15, 10, 16, 11, 9, 11, 10, 18]).type(torch.Tensor)
def model(sigma):
eta = pyro.sample('eta', dist.Normal(torch.zeros(J), torch.ones(J)))
mu = pyro.sample('mu', dist.Normal(torch.zeros(1), 10 * torch.ones(1)))
tau = pyro.sample('tau', dist.HalfCauchy(scale=25 * torch.ones(1)))
theta = mu + tau * eta
return pyro.sample("obs", dist.Normal(theta, sigma))
def conditioned_model(model, sigma, y):
return poutine.condition(model, data={"obs": y})(sigma)
nuts_kernel = NUTS(conditioned_model)
mcmc = MCMC(nuts_kernel,
num_samples=200,
warmup_steps=100,
num_chains=4)
mcmc.run(model, sigma, y)
mcmc_az = az.from_pyro(mcmc)
The error message:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
156 try:
--> 157 ret = self.fn(*args, **kwargs)
158 except (ValueError, RuntimeError):
~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
7 with context:
----> 8 return fn(*args, **kwargs)
9
~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
7 with context:
----> 8 return fn(*args, **kwargs)
9
~/git/pyro/examples/eight_schools/az.py in conditioned_model(model, sigma, y)
25 def conditioned_model(model, sigma, y):
---> 26 return poutine.condition(model, data={"obs": y})(sigma)
27
~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
7 with context:
----> 8 return fn(*args, **kwargs)
9
~/git/pyro/examples/eight_schools/az.py in model(sigma)
15 def model(sigma):
---> 16 eta = pyro.sample('eta', dist.Normal(torch.zeros(J), torch.ones(J)))
17 mu = pyro.sample('mu', dist.Normal(torch.zeros(1), 10 * torch.ones(1)))
~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
109 # apply the stack and return its return value
--> 110 apply_stack(msg)
111 return msg["value"]
~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
189
--> 190 frame._process_message(msg)
191
~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/plate_messenger.py in _process_message(self, msg)
11 super(PlateMessenger, self)._process_message(msg)
---> 12 return BroadcastMessenger._pyro_sample(msg)
13
~/miniconda3/envs/bayes/lib/python3.7/contextlib.py in inner(*args, **kwds)
73 with self._recreate_cm():
---> 74 return func(*args, **kwds)
75 return inner
~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/broadcast_messenger.py in _pyro_sample(msg)
55 raise ValueError("Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format(
---> 56 f.name, msg['name'], f.dim, f.size, target_batch_shape[f.dim]))
57 target_batch_shape[f.dim] = f.size
ValueError: Shape mismatch inside plate('_num_predictive_samples') at site eta dim -1, 800 vs 8
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
~/git/pyro/examples/eight_schools/az.py in <module>
36 print(pyro.__version__)
37 print(az.__version__)
---> 38 mcmc_az = az.from_pyro(mcmc)
~/miniconda3/envs/bayes/lib/python3.7/site-packages/arviz/data/io_pyro.py in from_pyro(posterior, prior, posterior_predictive, coords, dims)
203 posterior_predictive=posterior_predictive,
204 coords=coords,
--> 205 dims=dims,
206 ).to_inference_data()
~/miniconda3/envs/bayes/lib/python3.7/site-packages/arviz/data/io_pyro.py in to_inference_data(self)
174 **{
175 "posterior": self.posterior_to_xarray(),
--> 176 "sample_stats": self.sample_stats_to_xarray(),
177 "posterior_predictive": self.posterior_predictive_to_xarray(),
178 **self.priors_to_xarray(),
~/miniconda3/envs/bayes/lib/python3.7/site-packages/arviz/data/base.py in wrapped(cls, *args, **kwargs)
34 if all([getattr(cls, prop_i) is None for prop_i in prop]):
35 return None
---> 36 return func(cls, *args, **kwargs)
37
38 return wrapped
~/miniconda3/envs/bayes/lib/python3.7/site-packages/arviz/data/io_pyro.py in sample_stats_to_xarray(self)
93 samples = self.posterior.get_samples(group_by_chain=False)
94 predictive = self.pyro.infer.Predictive(self.model, samples)
---> 95 obs_site = predictive.get_vectorized_trace(*self._args, **self._kwargs).nodes[obs_name]
96 log_likelihood = obs_site["fn"].log_prob(obs_site["value"]).detach().cpu().numpy()
97 if self.dims is not None:
~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/infer/predictive.py in get_vectorized_trace(self, *args, **kwargs)
216 parallel=self.parallel, model_args=args, model_kwargs=kwargs)
217 return _predictive(self.model, posterior_samples, self.num_samples,
--> 218 return_trace=True, model_args=args, model_kwargs=kwargs)
~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/infer/predictive.py in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
60 if return_trace:
61 trace = poutine.trace(poutine.condition(vectorize(model), reshaped_samples))\
---> 62 .get_trace(*model_args, **model_kwargs)
63 return trace
64
~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
175 Calls this poutine and returns its trace instead of the function's return value.
176 """
--> 177 self(*args, **kwargs)
178 return self.msngr.get_trace()
~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
159 exc_type, exc_value, traceback = sys.exc_info()
160 shapes = self.msngr.trace.format_shapes()
--> 161 raise exc_type(u"{}\n{}".format(exc_value, shapes)).with_traceback(traceback)
162 self.msngr.trace.add_node("_RETURN", name="_RETURN", type="return", value=ret)
163 return ret
~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
155 args=args, kwargs=kwargs)
156 try:
--> 157 ret = self.fn(*args, **kwargs)
158 except (ValueError, RuntimeError):
159 exc_type, exc_value, traceback = sys.exc_info()
~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
6 def _context_wrap(context, fn, *args, **kwargs):
7 with context:
----> 8 return fn(*args, **kwargs)
9
10
~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
6 def _context_wrap(context, fn, *args, **kwargs):
7 with context:
----> 8 return fn(*args, **kwargs)
9
10
~/git/pyro/examples/eight_schools/az.py in conditioned_model(model, sigma, y)
24
25 def conditioned_model(model, sigma, y):
---> 26 return poutine.condition(model, data={"obs": y})(sigma)
27
28
~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
6 def _context_wrap(context, fn, *args, **kwargs):
7 with context:
----> 8 return fn(*args, **kwargs)
9
10
~/git/pyro/examples/eight_schools/az.py in model(sigma)
14
15 def model(sigma):
---> 16 eta = pyro.sample('eta', dist.Normal(torch.zeros(J), torch.ones(J)))
17 mu = pyro.sample('mu', dist.Normal(torch.zeros(1), 10 * torch.ones(1)))
18 tau = pyro.sample('tau', dist.HalfCauchy(scale=25 * torch.ones(1)))
~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
108 msg["is_observed"] = True
109 # apply the stack and return its return value
--> 110 apply_stack(msg)
111 return msg["value"]
112
~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
188 pointer = pointer + 1
189
--> 190 frame._process_message(msg)
191
192 if msg["stop"]:
~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/plate_messenger.py in _process_message(self, msg)
10 def _process_message(self, msg):
11 super(PlateMessenger, self)._process_message(msg)
---> 12 return BroadcastMessenger._pyro_sample(msg)
13
14 def __enter__(self):
~/miniconda3/envs/bayes/lib/python3.7/contextlib.py in inner(*args, **kwds)
72 def inner(*args, **kwds):
73 with self._recreate_cm():
---> 74 return func(*args, **kwds)
75 return inner
76
~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/broadcast_messenger.py in _pyro_sample(msg)
54 if target_batch_shape[f.dim] is not None and target_batch_shape[f.dim] != f.size:
55 raise ValueError("Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format(
---> 56 f.name, msg['name'], f.dim, f.size, target_batch_shape[f.dim]))
57 target_batch_shape[f.dim] = f.size
58 # Starting from the right, if expected size is None at an index,
ValueError: Shape mismatch inside plate('_num_predictive_samples') at site eta dim -1, 800 vs 8
Trace Shapes:
Param Sites:
Sample Sites:
By the way, is it possible to have more user friendly error messages? In Pyro the error messages are usually several pages long and invovle all the inner workings of Pyro…