Hi Pyro Devs!
I am getting a funny error message about a problem in packing tensors (see below), when I use NUTS sampling with parallel sampling, for a mixture-model based logistic regression .
Do you know why I get this error, and how I can potentially fix it?
Thank you very much in advance!
Minimal Example:
import numpy as np
import pyro
from pyro import distributions as dist
from pyro.infer.mcmc import NUTS, MCMC
import torch
data = torch.tensor([
[1., 2., 3.],
[4., -2., 5.],
[-100., -300., -500.]])
corr = torch.tensor([1., 1., 0.])
def model(data, corr=None):
ws = pyro.sample('ws', dist.MultivariateNormal(torch.zeros(data.size(-1)), torch.eye(data.size(-1))))
err = pyro.sample('err', dist.HalfNormal(1.))
offset_ws = pyro.sample('offset_ws', dist.Dirichlet(torch.tensor([1., 1.])))
offset_means = pyro.sample('offset_means', dist.Normal(torch.zeros(2), torch.ones(2)).expand_by([data.size(-1)]))
with pyro.plate('data', data.size(-2), dim=-2):
offset_ch = pyro.sample('offset_ch', dist.Categorical(offset_ws), infer=dict(enumerate='parallel'))
xs = data - offset_means[torch.arange(data.size(-1)), offset_ch]
if corr is not None:
corr = corr.unsqueeze(-1)
return pyro.sample('corr', dist.Normal(torch.sigmoid(ws @ xs).unsqueeze(-1), err), obs=corr)
Error:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/ops/packed.py in pack(value, dim_to_symbol)
24 dims = ''.join(dim_to_symbol[dim - shift]
---> 25 for dim, size in enumerate(shape)
26 if size > 1)
~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/ops/packed.py in <genexpr>(.0)
25 for dim, size in enumerate(shape)
---> 26 if size > 1)
27 except KeyError:
KeyError: -2
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/poutine/trace_struct.py in pack_tensors(self, plate_to_symbol)
318 elif "log_prob" in site:
--> 319 packed["log_prob"] = pack(site["log_prob"], dim_to_symbol)
320 packed["unscaled_log_prob"] = pack(site["unscaled_log_prob"], dim_to_symbol)
~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/ops/packed.py in pack(value, dim_to_symbol)
31 'Actual shape: {}'.format(tuple(value.shape)),
---> 32 "Try adding shape assertions for your model's sample values and distribution parameters."]))
33 value = value.squeeze()
ValueError: Invalid tensor shape.
Allowed dims:
Actual shape: (3, 2)
Try adding shape assertions for your model's sample values and distribution parameters.
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
<ipython-input-21-f229bd5183d7> in <module>()
3 nuts_kernel = NUTS(model, max_plate_nesting=2)
4 mcmc = MCMC(nuts_kernel, num_samples=10, warmup_steps=5)
----> 5 posterior = mcmc.run(data, corr=corr)
~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/infer/abstract_infer.py in run(self, *args, **kwargs)
198 self._reset()
199 with poutine.block():
--> 200 for i, vals in enumerate(self._traces(*args, **kwargs)):
201 if len(vals) == 2:
202 chain_id = 0
~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/infer/mcmc/mcmc.py in _traces(self, *args, **kwargs)
262
263 def _traces(self, *args, **kwargs):
--> 264 for sample in self.sampler._traces(*args, **kwargs):
265 yield sample
266
~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/infer/mcmc/mcmc.py in _traces(self, *args, **kwargs)
200 progress_bar = initialize_progbar(self.warmup_steps, self.num_samples, disable=self.disable_progbar)
201 self.logger = initialize_logger(self.logger, logger_id, progress_bar, log_queue)
--> 202 self.kernel.setup(self.warmup_steps, *args, **kwargs)
203 trace = self.kernel.initial_trace
204 with optional(progress_bar, not is_multiprocessing):
~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/infer/mcmc/hmc.py in setup(self, warmup_steps, *args, **kwargs)
354 self._args = args
355 self._kwargs = kwargs
--> 356 self._initialize_model_properties()
357 self._configure_adaptation()
358
~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/infer/mcmc/hmc.py in _initialize_model_properties(self)
342 self._trace_prob_evaluator = TraceEinsumEvaluator(trace,
343 self._has_enumerable_sites,
--> 344 self.max_plate_nesting)
345 mass_matrix_size = sum(self._r_numels.values())
346 if self.full_mass:
~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/infer/mcmc/util.py in __init__(self, model_trace, has_enumerable_sites, max_plate_nesting)
158 self._enum_dims = set()
159 self.ordering = {}
--> 160 self._populate_cache(model_trace)
161
162 def _populate_cache(self, model_trace):
~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/infer/mcmc/util.py in _populate_cache(self, model_trace)
171 "has discrete (enumerable) sites.")
172 model_trace.compute_log_prob()
--> 173 model_trace.pack_tensors()
174 for name, site in model_trace.nodes.items():
175 if site["type"] == "sample" and not isinstance(site["fn"], _Subsample):
~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/poutine/trace_struct.py in pack_tensors(self, plate_to_symbol)
325 ValueError("Error while packing tensors at site '{}':\n {}\n{}"
326 .format(site["name"], exc_value, shapes)),
--> 327 traceback)
328
329 def format_shapes(self, title='Trace Shapes:', last_site=None):
~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/six.py in reraise(tp, value, tb)
690 value = tp()
691 if value.__traceback__ is not tb:
--> 692 raise value.with_traceback(tb)
693 raise value
694 finally:
~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/poutine/trace_struct.py in pack_tensors(self, plate_to_symbol)
317 packed["unscaled_log_prob"] = pack(site["unscaled_log_prob"], dim_to_symbol)
318 elif "log_prob" in site:
--> 319 packed["log_prob"] = pack(site["log_prob"], dim_to_symbol)
320 packed["unscaled_log_prob"] = pack(site["unscaled_log_prob"], dim_to_symbol)
321 except ValueError:
~/Documents/SourceControl/probprog-sandbox/venv/lib/python3.7/site-packages/pyro/ops/packed.py in pack(value, dim_to_symbol)
30 'Allowed dims: {}'.format(', '.join(map(str, sorted(dim_to_symbol)))),
31 'Actual shape: {}'.format(tuple(value.shape)),
---> 32 "Try adding shape assertions for your model's sample values and distribution parameters."]))
33 value = value.squeeze()
34 value._pyro_dims = dims
ValueError: Error while packing tensors at site 'offset_means':
Invalid tensor shape.
Allowed dims:
Actual shape: (3, 2)
Try adding shape assertions for your model's sample values and distribution parameters.
Trace Shapes:
Param Sites:
Sample Sites:
ws dist | 3
value | 3
log_prob |
err dist |
value |
log_prob |
offset_ws dist | 2
value | 2
log_prob |
offset_means dist 3 2 |
value 3 2 |
log_prob 3 2 |