Weird error about "packing tensors" when sampling with enumeration

Hi Pyro Devs!

I am getting a funny error message about a problem in packing tensors (see below), when I use NUTS sampling with parallel sampling, for a mixture-model based logistic regression :smile:.

Do you know why I get this error, and how I can potentially fix it?

Thank you very much in advance!

Minimal Example:

import numpy as np
import pyro
from pyro import distributions as dist
from pyro.infer.mcmc import NUTS, MCMC
import torch

data = torch.tensor([
    [1., 2., 3.],
    [4., -2., 5.],
    [-100., -300., -500.]])
corr = torch.tensor([1., 1., 0.])

def model(data, corr=None):
    ws = pyro.sample('ws', dist.MultivariateNormal(torch.zeros(data.size(-1)), torch.eye(data.size(-1))))
    err = pyro.sample('err', dist.HalfNormal(1.))
    offset_ws = pyro.sample('offset_ws', dist.Dirichlet(torch.tensor([1., 1.])))
    offset_means = pyro.sample('offset_means', dist.Normal(torch.zeros(2), torch.ones(2)).expand_by([data.size(-1)]))
    with pyro.plate('data', data.size(-2), dim=-2):
        offset_ch = pyro.sample('offset_ch', dist.Categorical(offset_ws), infer=dict(enumerate='parallel'))
        xs = data - offset_means[torch.arange(data.size(-1)), offset_ch]
        if corr is not None:
            corr = corr.unsqueeze(-1)
        return pyro.sample('corr', dist.Normal(torch.sigmoid(ws @ xs).unsqueeze(-1), err), obs=corr)

Error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/ops/packed.py in pack(value, dim_to_symbol)
     24                 dims = ''.join(dim_to_symbol[dim - shift]
---> 25                                for dim, size in enumerate(shape)
     26                                if size > 1)

~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/ops/packed.py in <genexpr>(.0)
     25                                for dim, size in enumerate(shape)
---> 26                                if size > 1)
     27         except KeyError:

KeyError: -2

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/poutine/trace_struct.py in pack_tensors(self, plate_to_symbol)
    318                 elif "log_prob" in site:
--> 319                     packed["log_prob"] = pack(site["log_prob"], dim_to_symbol)
    320                     packed["unscaled_log_prob"] = pack(site["unscaled_log_prob"], dim_to_symbol)

~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/ops/packed.py in pack(value, dim_to_symbol)
     31                 'Actual shape: {}'.format(tuple(value.shape)),
---> 32                 "Try adding shape assertions for your model's sample values and distribution parameters."]))
     33         value = value.squeeze()

ValueError: Invalid tensor shape.
  Allowed dims: 
  Actual shape: (3, 2)
  Try adding shape assertions for your model's sample values and distribution parameters.

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-21-f229bd5183d7> in <module>()
      3 nuts_kernel = NUTS(model, max_plate_nesting=2)
      4 mcmc = MCMC(nuts_kernel, num_samples=10, warmup_steps=5)
----> 5 posterior = mcmc.run(data, corr=corr)

~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/infer/abstract_infer.py in run(self, *args, **kwargs)
    198         self._reset()
    199         with poutine.block():
--> 200             for i, vals in enumerate(self._traces(*args, **kwargs)):
    201                 if len(vals) == 2:
    202                     chain_id = 0

~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/infer/mcmc/mcmc.py in _traces(self, *args, **kwargs)
    262 
    263     def _traces(self, *args, **kwargs):
--> 264         for sample in self.sampler._traces(*args, **kwargs):
    265             yield sample
    266 

~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/infer/mcmc/mcmc.py in _traces(self, *args, **kwargs)
    200             progress_bar = initialize_progbar(self.warmup_steps, self.num_samples, disable=self.disable_progbar)
    201         self.logger = initialize_logger(self.logger, logger_id, progress_bar, log_queue)
--> 202         self.kernel.setup(self.warmup_steps, *args, **kwargs)
    203         trace = self.kernel.initial_trace
    204         with optional(progress_bar, not is_multiprocessing):

~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/infer/mcmc/hmc.py in setup(self, warmup_steps, *args, **kwargs)
    354         self._args = args
    355         self._kwargs = kwargs
--> 356         self._initialize_model_properties()
    357         self._configure_adaptation()
    358 

~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/infer/mcmc/hmc.py in _initialize_model_properties(self)
    342         self._trace_prob_evaluator = TraceEinsumEvaluator(trace,
    343                                                           self._has_enumerable_sites,
--> 344                                                           self.max_plate_nesting)
    345         mass_matrix_size = sum(self._r_numels.values())
    346         if self.full_mass:

~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/infer/mcmc/util.py in __init__(self, model_trace, has_enumerable_sites, max_plate_nesting)
    158         self._enum_dims = set()
    159         self.ordering = {}
--> 160         self._populate_cache(model_trace)
    161 
    162     def _populate_cache(self, model_trace):

~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/infer/mcmc/util.py in _populate_cache(self, model_trace)
    171                              "has discrete (enumerable) sites.")
    172         model_trace.compute_log_prob()
--> 173         model_trace.pack_tensors()
    174         for name, site in model_trace.nodes.items():
    175             if site["type"] == "sample" and not isinstance(site["fn"], _Subsample):

~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/poutine/trace_struct.py in pack_tensors(self, plate_to_symbol)
    325                             ValueError("Error while packing tensors at site '{}':\n  {}\n{}"
    326                                        .format(site["name"], exc_value, shapes)),
--> 327                             traceback)
    328 
    329     def format_shapes(self, title='Trace Shapes:', last_site=None):

~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/six.py in reraise(tp, value, tb)
    690                 value = tp()
    691             if value.__traceback__ is not tb:
--> 692                 raise value.with_traceback(tb)
    693             raise value
    694         finally:

~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/poutine/trace_struct.py in pack_tensors(self, plate_to_symbol)
    317                     packed["unscaled_log_prob"] = pack(site["unscaled_log_prob"], dim_to_symbol)
    318                 elif "log_prob" in site:
--> 319                     packed["log_prob"] = pack(site["log_prob"], dim_to_symbol)
    320                     packed["unscaled_log_prob"] = pack(site["unscaled_log_prob"], dim_to_symbol)
    321             except ValueError:

~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/ops/packed.py in pack(value, dim_to_symbol)
     30                 'Allowed dims: {}'.format(', '.join(map(str, sorted(dim_to_symbol)))),
     31                 'Actual shape: {}'.format(tuple(value.shape)),
---> 32                 "Try adding shape assertions for your model's sample values and distribution parameters."]))
     33         value = value.squeeze()
     34         value._pyro_dims = dims

ValueError: Error while packing tensors at site 'offset_means':
  Invalid tensor shape.
  Allowed dims: 
  Actual shape: (3, 2)
  Try adding shape assertions for your model's sample values and distribution parameters.
    Trace Shapes:        
     Param Sites:        
    Sample Sites:        
          ws dist     | 3
            value     | 3
         log_prob     |  
         err dist     |  
            value     |  
         log_prob     |  
   offset_ws dist     | 2
            value     | 2
         log_prob     |  
offset_means dist 3 2 |  
            value 3 2 |  
         log_prob 3 2 |  

Hi @ahmadsalim it looks like you should be calling .to_event(1) on your offset_means distribution:

offset_means = pyro.sample('offset_means',
                           dist.Normal(torch.zeros(2), torch.ones(2))
                               .expand_by([data.size(-1)])
                               .to_event(1))

You can see the shape error in the shape diagram: what you see is

offset_means dist 3 2 |  
            value 3 2 |  
         log_prob 3 2 |

but adding .to_event(1) will shift the shape one column to the right:

offset_means dist 3 | 2  
            value 3 | 2  
         log_prob 3 | 2
2 Likes

Awesome, thank you very much! :smile: