The ArviZ function `from_pyro` fails with mismatched dimensions

It seems the sample dimension and the observation dimension are mixed up. I am using Pyro 1.1.0 and ArviZ 0.6.1; the code below is from the Pyro examples.

import arviz as az
import torch

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import MCMC, NUTS

# Number of groups in the model and the data the "obs" site is conditioned on.
J = 8
# Observed values for the "obs" site and the per-element observation scale
# passed to ``model(sigma)``; float literals give tensors of the default dtype.
y = torch.tensor([28., 8., -3., 7., -1., 1., 18., 12.])
sigma = torch.tensor([15., 10., 16., 11., 9., 11., 10., 18.])

def model(sigma):
    """Hierarchical eight-schools model, written WITHOUT ``pyro.plate``.

    Samples ``eta`` from ``Normal(zeros(J), ones(J))``, a length-1 ``mu`` from
    ``Normal(0, 10)`` and a length-1 ``tau`` from ``HalfCauchy(scale=25)``,
    forms ``theta = mu + tau * eta`` and returns the value of the ``"obs"``
    site drawn from ``Normal(theta, sigma)``.

    NOTE: ``eta`` and ``obs`` get their length-J shape only from the tensor
    arguments; no plate declares that rightmost dimension as a batch
    dimension. This is what triggers the error reported below. When
    ``az.from_pyro`` vectorizes the model through ``Predictive``, the
    ``_num_predictive_samples`` plate lands on dim -1. That clashes with the
    length-J dimension ("800 vs 8").
    """
    # Length-J site with no enclosing plate (see NOTE above).
    eta = pyro.sample('eta', dist.Normal(torch.zeros(J), torch.ones(J)))
    mu = pyro.sample('mu', dist.Normal(torch.zeros(1), 10 * torch.ones(1)))
    tau = pyro.sample('tau', dist.HalfCauchy(scale=25 * torch.ones(1)))

    theta = mu + tau * eta

    # "obs" is the site that conditioned_model fixes to the observed data.
    return pyro.sample("obs", dist.Normal(theta, sigma))


def conditioned_model(model, sigma, y):
    """Run ``model(sigma)`` with its ``"obs"`` site fixed to the observed ``y``.

    Returns whatever ``model`` returns under the conditioning handler.
    """
    observed = {"obs": y}
    conditioned = poutine.condition(model, data=observed)
    return conditioned(sigma)


# Sample the conditioned model with NUTS. The trailing arguments of
# ``run`` are forwarded to ``conditioned_model(model, sigma, y)``.
nuts_kernel = NUTS(conditioned_model)
mcmc = MCMC(
    nuts_kernel,
    num_samples=200,
    warmup_steps=100,
    num_chains=4,
)
mcmc.run(model, sigma, y)
# Convert the MCMC results to ArviZ InferenceData. With Pyro 1.1.0 and
# ArviZ 0.6.1 this call raises the shape-mismatch ValueError in this report.
mcmc_az = az.from_pyro(mcmc)

The error message:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    156             try:
--> 157                 ret = self.fn(*args, **kwargs)
    158             except (ValueError, RuntimeError):

~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      7     with context:
----> 8         return fn(*args, **kwargs)
      9 

~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      7     with context:
----> 8         return fn(*args, **kwargs)
      9 

~/git/pyro/examples/eight_schools/az.py in conditioned_model(model, sigma, y)
    25 def conditioned_model(model, sigma, y):
---> 26     return poutine.condition(model, data={"obs": y})(sigma)
    27 

~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      7     with context:
----> 8         return fn(*args, **kwargs)
      9 

~/git/pyro/examples/eight_schools/az.py in model(sigma)
    15 def model(sigma):
---> 16     eta = pyro.sample('eta', dist.Normal(torch.zeros(J), torch.ones(J)))
    17     mu = pyro.sample('mu', dist.Normal(torch.zeros(1), 10 * torch.ones(1)))

~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
    109         # apply the stack and return its return value
--> 110         apply_stack(msg)
    111         return msg["value"]

~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
    189 
--> 190         frame._process_message(msg)
    191 

~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/plate_messenger.py in _process_message(self, msg)
    11         super(PlateMessenger, self)._process_message(msg)
---> 12         return BroadcastMessenger._pyro_sample(msg)
    13 

~/miniconda3/envs/bayes/lib/python3.7/contextlib.py in inner(*args, **kwds)
    73             with self._recreate_cm():
---> 74                 return func(*args, **kwds)
    75         return inner

~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/broadcast_messenger.py in _pyro_sample(msg)
    55                     raise ValueError("Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format(
---> 56                         f.name, msg['name'], f.dim, f.size, target_batch_shape[f.dim]))
    57                 target_batch_shape[f.dim] = f.size

ValueError: Shape mismatch inside plate('_num_predictive_samples') at site eta dim -1, 800 vs 8

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
~/git/pyro/examples/eight_schools/az.py in <module>
    36 print(pyro.__version__)
    37 print(az.__version__)
---> 38 mcmc_az = az.from_pyro(mcmc)

~/miniconda3/envs/bayes/lib/python3.7/site-packages/arviz/data/io_pyro.py in from_pyro(posterior, prior, posterior_predictive, coords, dims)
    203         posterior_predictive=posterior_predictive,
    204         coords=coords,
--> 205         dims=dims,
    206     ).to_inference_data()

~/miniconda3/envs/bayes/lib/python3.7/site-packages/arviz/data/io_pyro.py in to_inference_data(self)
    174             **{
    175                 "posterior": self.posterior_to_xarray(),
--> 176                 "sample_stats": self.sample_stats_to_xarray(),
    177                 "posterior_predictive": self.posterior_predictive_to_xarray(),
    178                 **self.priors_to_xarray(),

~/miniconda3/envs/bayes/lib/python3.7/site-packages/arviz/data/base.py in wrapped(cls, *args, **kwargs)
    34                 if all([getattr(cls, prop_i) is None for prop_i in prop]):
    35                     return None
---> 36             return func(cls, *args, **kwargs)
    37 
    38         return wrapped

~/miniconda3/envs/bayes/lib/python3.7/site-packages/arviz/data/io_pyro.py in sample_stats_to_xarray(self)
    93             samples = self.posterior.get_samples(group_by_chain=False)
    94             predictive = self.pyro.infer.Predictive(self.model, samples)
---> 95             obs_site = predictive.get_vectorized_trace(*self._args, **self._kwargs).nodes[obs_name]
    96             log_likelihood = obs_site["fn"].log_prob(obs_site["value"]).detach().cpu().numpy()
    97             if self.dims is not None:

~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/infer/predictive.py in get_vectorized_trace(self, *args, **kwargs)
    216                                             parallel=self.parallel, model_args=args, model_kwargs=kwargs)
    217         return _predictive(self.model, posterior_samples, self.num_samples,
--> 218                            return_trace=True, model_args=args, model_kwargs=kwargs)

~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/infer/predictive.py in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
    60     if return_trace:
    61         trace = poutine.trace(poutine.condition(vectorize(model), reshaped_samples))\
---> 62             .get_trace(*model_args, **model_kwargs)
    63         return trace
    64 

~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
    175         Calls this poutine and returns its trace instead of the function's return value.
    176         """
--> 177         self(*args, **kwargs)
    178         return self.msngr.get_trace()

~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    159                 exc_type, exc_value, traceback = sys.exc_info()
    160                 shapes = self.msngr.trace.format_shapes()
--> 161                 raise exc_type(u"{}\n{}".format(exc_value, shapes)).with_traceback(traceback)
    162             self.msngr.trace.add_node("_RETURN", name="_RETURN", type="return", value=ret)
    163         return ret

~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    155                                       args=args, kwargs=kwargs)
    156             try:
--> 157                 ret = self.fn(*args, **kwargs)
    158             except (ValueError, RuntimeError):
    159                 exc_type, exc_value, traceback = sys.exc_info()

~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      6 def _context_wrap(context, fn, *args, **kwargs):
      7     with context:
----> 8         return fn(*args, **kwargs)
      9 
    10 

~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      6 def _context_wrap(context, fn, *args, **kwargs):
      7     with context:
----> 8         return fn(*args, **kwargs)
      9 
    10 

~/git/pyro/examples/eight_schools/az.py in conditioned_model(model, sigma, y)
    24 
    25 def conditioned_model(model, sigma, y):
---> 26     return poutine.condition(model, data={"obs": y})(sigma)
    27 
    28 

~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      6 def _context_wrap(context, fn, *args, **kwargs):
      7     with context:
----> 8         return fn(*args, **kwargs)
      9 
    10 

~/git/pyro/examples/eight_schools/az.py in model(sigma)
    14 
    15 def model(sigma):
---> 16     eta = pyro.sample('eta', dist.Normal(torch.zeros(J), torch.ones(J)))
    17     mu = pyro.sample('mu', dist.Normal(torch.zeros(1), 10 * torch.ones(1)))
    18     tau = pyro.sample('tau', dist.HalfCauchy(scale=25 * torch.ones(1)))

~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
    108             msg["is_observed"] = True
    109         # apply the stack and return its return value
--> 110         apply_stack(msg)
    111         return msg["value"]
    112 

~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
    188         pointer = pointer + 1
    189 
--> 190         frame._process_message(msg)
    191 
    192         if msg["stop"]:

~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/plate_messenger.py in _process_message(self, msg)
    10     def _process_message(self, msg):
    11         super(PlateMessenger, self)._process_message(msg)
---> 12         return BroadcastMessenger._pyro_sample(msg)
    13 
    14     def __enter__(self):

~/miniconda3/envs/bayes/lib/python3.7/contextlib.py in inner(*args, **kwds)
    72         def inner(*args, **kwds):
    73             with self._recreate_cm():
---> 74                 return func(*args, **kwds)
    75         return inner
    76 

~/miniconda3/envs/bayes/lib/python3.7/site-packages/pyro/poutine/broadcast_messenger.py in _pyro_sample(msg)
    54                 if target_batch_shape[f.dim] is not None and target_batch_shape[f.dim] != f.size:
    55                     raise ValueError("Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format(
---> 56                         f.name, msg['name'], f.dim, f.size, target_batch_shape[f.dim]))
    57                 target_batch_shape[f.dim] = f.size
    58             # Starting from the right, if expected size is None at an index,

ValueError: Shape mismatch inside plate('_num_predictive_samples') at site eta dim -1, 800 vs 8
Trace Shapes:
Param Sites:
Sample Sites:

By the way, would it be possible to have more user-friendly error messages? In Pyro, error messages are usually several pages long and involve all the inner workings of Pyro…

Hi @olivierma, in Pyro it is important to declare the batch dimensions of your tensors explicitly. For example:

def model(sigma):
    """Eight-schools model with the per-group dimension declared by a plate.

    ``mu`` and ``tau`` are scalar global parameters. ``eta`` and ``obs`` are
    sampled inside ``pyro.plate("J", J)``, so their rightmost (length-J)
    dimension is declared as a batch dimension. Vectorized tools such as
    ``Predictive`` then place their own sample dimension to the left of it
    instead of clashing with it.

    ``theta = mu + tau * eta`` with standard-normal ``eta`` is a
    non-centered parameterization of the group effects.

    Args:
        sigma: observation scales for the J groups (assumed to be shape
            ``(J,)`` or broadcastable to it — confirm against the caller).

    Returns:
        The value of the ``"obs"`` sample site.
    """
    mu = pyro.sample('mu', dist.Normal(0, 10))
    tau = pyro.sample('tau', dist.HalfCauchy(scale=25))

    # An explicit plate size lets Pyro check that every site inside the plate
    # really has J elements on the plate dimension.
    with pyro.plate("J", J):
        eta = pyro.sample('eta', dist.Normal(torch.zeros(J), torch.ones(J)))
        theta = mu + tau * eta
        return pyro.sample("obs", dist.Normal(theta, sigma))

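So `eta` and `obs` should be declared inside a `pyro.plate`. This marks their rightmost (length-J) dimension as a batch dimension, so the sample dimension that `Predictive` adds when vectorizing the model no longer collides with it (the "800 vs 8" mismatch above).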

I think that, rather than raising an error, ArviZ could simply skip the log_likelihood computation if it can neither vectorize your model nor compute log_likelihood sequentially. I'll provide a fix for this in ArviZ.