HI, I tried to extend the HMM example (at the end on the tutorial on enumeration) by adding additional independent dimensions (e.g data now has a shape num_steps x new_dim x data_dim). Here is the code
indent preformatted text by 4 spaces
import pyro
from pyro.infer import Trace_ELBO, TraceEnum_ELBO
from pyro.contrib.autoguide import AutoDiagonalNormal
from pyro import poutine
data_dim = 4
num_steps = 10
new_dim = 5
data = dist.Categorical(torch.ones(num_steps, new_dim, data_dim)).sample()
def hmm_model(data, data_dim, new_dim, hidden_dim=10):
with pyro.plate('new_dim', new_dim):
with pyro.plate("hidden_state", hidden_dim):
transition = pyro.sample("transition", dist.Dirichlet(0.5 * torch.ones(hidden_dim)))
emission = pyro.sample("emission", dist.Dirichlet(0.5 * torch.ones(data_dim)))
x = torch.zeros(new_dim, dtype=torch.long) # initial state
for t, y in pyro.markov(enumerate(data)):
x = pyro.sample("x_{}".format(t), dist.Categorical(transition[x, range(new_dim)]),
infer={"enumerate": "parallel"})
pyro.sample("y_{}".format(t), dist.Categorical(emission[x, range(new_dim)]), obs=y)
print("x_{}.shape = {}".format(t, x.shape))
hmm_guide = AutoDiagonalNormal(poutine.block(hmm_model, expose=["transition", "emission"]))
pyro.clear_param_store()
elbo = TraceEnum_ELBO(max_plate_nesting=2)
elbo.loss(hmm_model, hmm_guide, data, data_dim=data_dim, new_dim=new_dim);
However I end up with the following error
x_0.shape = torch.Size([5])
x_1.shape = torch.Size([5])
x_2.shape = torch.Size([5])
x_3.shape = torch.Size([5])
x_4.shape = torch.Size([5])
x_5.shape = torch.Size([5])
x_6.shape = torch.Size([5])
x_7.shape = torch.Size([5])
x_8.shape = torch.Size([5])
x_9.shape = torch.Size([5])
x_0.shape = torch.Size([10, 1, 1])
x_1.shape = torch.Size([10, 1, 1, 1])
x_2.shape = torch.Size([10, 1, 1])
x_3.shape = torch.Size([10, 1, 1, 1])
x_4.shape = torch.Size([10, 1, 1])
x_5.shape = torch.Size([10, 1, 1, 1])
x_6.shape = torch.Size([10, 1, 1])
x_7.shape = torch.Size([10, 1, 1, 1])
x_8.shape = torch.Size([10, 1, 1])
x_9.shape = torch.Size([10, 1, 1, 1])
File “”, line 30, in
elbo.loss(hmm_model, hmm_guide, data, data_dim=data_dim, nsub=nsub);
File “/mnt/data/miniconda/envs/pyro-ppl/lib/python3.7/site-packages/pyro/infer/traceenum_elbo.py”, line 318, in loss
for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
File “/mnt/data/miniconda/envs/pyro-ppl/lib/python3.7/site-packages/pyro/infer/traceenum_elbo.py”, line 308, in _get_traces
yield self._get_trace(model, guide, *args, **kwargs)
File “/mnt/data/miniconda/envs/pyro-ppl/lib/python3.7/site-packages/pyro/infer/traceenum_elbo.py”, line 262, in _get_trace
“flat”, self.max_plate_nesting, model, guide, *args, **kwargs)
File “/mnt/data/miniconda/envs/pyro-ppl/lib/python3.7/site-packages/pyro/infer/enum.py”, line 56, in get_importance_trace
check_site_shape(site, max_plate_nesting)
File “/mnt/data/miniconda/envs/pyro-ppl/lib/python3.7/site-packages/pyro/util.py”, line 262, in check_site_shape
‘- .permute() data dimensions’]))
ValueError: at site “x_0”, invalid log_prob shape
Expected , actual [1, 5]
Try one of the following fixes:
- enclose the batched tensor in a with plate(…): context
- .to_event(…) the distribution being sampled
- .permute() data dimensions
I do not quite understand what is going on here and I will highly appreciate an explanation of what I am doing wrong.
Is it the case that one cannot enumerate over multidimensional variables (e.g. x, and y have to be of shape == ())?