Extending the HMM example with an additional independent dimension

Hi, I tried to extend the HMM example (at the end of the tutorial on enumeration) by adding an additional independent dimension, so that the observations now have shape num_steps x new_dim (with data_dim categories at each position). Here is the code:
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Trace_ELBO, TraceEnum_ELBO
from pyro.contrib.autoguide import AutoDiagonalNormal
from pyro import poutine

data_dim = 4
num_steps = 10
new_dim = 5
data = dist.Categorical(torch.ones(num_steps, new_dim, data_dim)).sample()  # shape (num_steps, new_dim), values in {0, ..., data_dim - 1}

def hmm_model(data, data_dim, new_dim, hidden_dim=10):
    with pyro.plate('new_dim', new_dim):
        with pyro.plate("hidden_state", hidden_dim):
            transition = pyro.sample("transition", dist.Dirichlet(0.5 * torch.ones(hidden_dim)))
            emission = pyro.sample("emission", dist.Dirichlet(0.5 * torch.ones(data_dim)))
    
    x = torch.zeros(new_dim, dtype=torch.long)  # initial state
    for t, y in pyro.markov(enumerate(data)):
        x = pyro.sample("x_{}".format(t), dist.Categorical(transition[x, range(new_dim)]),
                         infer={"enumerate": "parallel"})
        pyro.sample("y_{}".format(t), dist.Categorical(emission[x, range(new_dim)]), obs=y)
        print("x_{}.shape = {}".format(t, x.shape))

hmm_guide = AutoDiagonalNormal(poutine.block(hmm_model, expose=["transition", "emission"]))
pyro.clear_param_store()
elbo = TraceEnum_ELBO(max_plate_nesting=2)
elbo.loss(hmm_model, hmm_guide, data, data_dim=data_dim, new_dim=new_dim);

However, I end up with the following output and error:

x_0.shape = torch.Size([5])
x_1.shape = torch.Size([5])
x_2.shape = torch.Size([5])
x_3.shape = torch.Size([5])
x_4.shape = torch.Size([5])
x_5.shape = torch.Size([5])
x_6.shape = torch.Size([5])
x_7.shape = torch.Size([5])
x_8.shape = torch.Size([5])
x_9.shape = torch.Size([5])
x_0.shape = torch.Size([10, 1, 1])
x_1.shape = torch.Size([10, 1, 1, 1])
x_2.shape = torch.Size([10, 1, 1])
x_3.shape = torch.Size([10, 1, 1, 1])
x_4.shape = torch.Size([10, 1, 1])
x_5.shape = torch.Size([10, 1, 1, 1])
x_6.shape = torch.Size([10, 1, 1])
x_7.shape = torch.Size([10, 1, 1, 1])
x_8.shape = torch.Size([10, 1, 1])
x_9.shape = torch.Size([10, 1, 1, 1])

File "", line 30, in
elbo.loss(hmm_model, hmm_guide, data, data_dim=data_dim, new_dim=new_dim);

File "/mnt/data/miniconda/envs/pyro-ppl/lib/python3.7/site-packages/pyro/infer/traceenum_elbo.py", line 318, in loss
for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):

File "/mnt/data/miniconda/envs/pyro-ppl/lib/python3.7/site-packages/pyro/infer/traceenum_elbo.py", line 308, in _get_traces
yield self._get_trace(model, guide, *args, **kwargs)

File "/mnt/data/miniconda/envs/pyro-ppl/lib/python3.7/site-packages/pyro/infer/traceenum_elbo.py", line 262, in _get_trace
"flat", self.max_plate_nesting, model, guide, *args, **kwargs)

File "/mnt/data/miniconda/envs/pyro-ppl/lib/python3.7/site-packages/pyro/infer/enum.py", line 56, in get_importance_trace
check_site_shape(site, max_plate_nesting)

File "/mnt/data/miniconda/envs/pyro-ppl/lib/python3.7/site-packages/pyro/util.py", line 262, in check_site_shape
'- .permute() data dimensions']))

ValueError: at site "x_0", invalid log_prob shape
Expected , actual [1, 5]
Try one of the following fixes:

  • enclose the batched tensor in a with plate(…): context
  • .to_event(…) the distribution being sampled
  • .permute() data dimensions

I do not quite understand what is going on here and would highly appreciate an explanation of what I am doing wrong.
Is it the case that one cannot enumerate over multidimensional variables, i.e. do x and y have to have shape ()?
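For comparison, the tutorial version without new_dim works for me; condensed, it looks like the sketch below (the names data_1d and hmm_model_1d are just for this post), and there each x_t has shape () before enumeration:

# original tutorial setup: observations have shape (num_steps,)
data_1d = dist.Categorical(torch.ones(num_steps, data_dim)).sample()

def hmm_model_1d(data, data_dim, hidden_dim=10):
    with pyro.plate("hidden_state", hidden_dim):
        transition = pyro.sample("transition", dist.Dirichlet(0.5 * torch.ones(hidden_dim)))
        emission = pyro.sample("emission", dist.Dirichlet(0.5 * torch.ones(data_dim)))

    x = 0  # initial state is a plain Python int, so each x_t starts out with shape ()
    for t, y in pyro.markov(enumerate(data)):
        x = pyro.sample("x_{}".format(t), dist.Categorical(transition[x]),
                        infer={"enumerate": "parallel"})
        pyro.sample("y_{}".format(t), dist.Categorical(emission[x]), obs=y)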

I found that adding a plate over new_dim inside the time loop resolves the ValueError:

def hmm_model(data, data_dim, new_dim, hidden_dim=10):
    with pyro.plate('sub', new_dim):
        with pyro.plate("hidden_state", hidden_dim):
            transition = pyro.sample("transition", dist.Dirichlet(0.5 * torch.ones(hidden_dim)))
            emission = pyro.sample("emission", dist.Dirichlet(0.5 * torch.ones(data_dim)))

    x = [0]*new_dim  # initial state
    for t, y in pyro.markov(enumerate(data)):
        with pyro.plate('new_dim_{}'.format(t), new_dim) as n:
            x = pyro.sample("x_{}".format(t), dist.Categorical(transition[x, n]),
                    infer={"enumerate": "parallel"})
            pyro.sample("y_{}".format(t), dist.Categorical(emission[x, n]), obs=y[n])
            print("x_{}.shape = {}".format(t, x.shape))
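I run this version with the same driver code as before (same guide and the same max_plate_nesting):

hmm_guide = AutoDiagonalNormal(poutine.block(hmm_model, expose=["transition", "emission"]))
pyro.clear_param_store()
elbo = TraceEnum_ELBO(max_plate_nesting=2)
elbo.loss(hmm_model, hmm_guide, data, data_dim=data_dim, new_dim=new_dim)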

Still, I am a bit uncertain whether this is a correct solution, since x at each step now depends on the previous x inside a plate. Does this break the conditional independence that the plate declares?
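In case it is useful, this is how I have been inspecting the shapes, following the tensor shapes tutorial; I am assuming first_available_dim=-3 is the right choice here, so that enumeration dims stay to the left of the two plate dimensions:

trace = poutine.trace(
    poutine.enum(hmm_model, first_available_dim=-3)
).get_trace(data, data_dim=data_dim, new_dim=new_dim)
trace.compute_log_prob()
print(trace.format_shapes())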