Hi,

Assume that I have 45 subjects, each with 68 features. I want to assign each subject to one of K = 3 clusters by creating a categorical variable called *assignments* whose size equals the number of subjects (45).

In my code, `xt_collection` is a tensor of size [45, 68, 3] = [nsubs, nfeatures, class K]. For each of the 45 subjects with 68 features, there are three possible assignments. Without enumeration, `xt_collection[torch.arange(45), :, assignments]` produces an output of size [45, 68]. For each subject i = 0, …, 44, it picks the 68 elements `[:, assignments[i]]` from that subject's `[68, 3]` matrix. So in this toy example I would expect the estimated assignments to equal `torch.zeros(45)`, since my observations are identical to column 0 of `xt_collection` (all subjects come from cluster 0).
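The extra dimension comes from how advanced indexing orders its output. Below is a small NumPy reproduction. PyTorch follows the same advanced-indexing rules, and the resulting shape matches the (3, 1, 45, 68) in the traceback. When advanced indices are separated by a slice, their broadcast shape goes first in the result. The enumerated value of shape (3, 1, 1) broadcasts with the 45-element subject index to (3, 1, 45), so the enumeration dim lands at -4 instead of -3. The variable names are illustrative.

```python
import numpy as np

n_subs, n_features, k = 45, 68, 3
x0 = 0.01 * np.ones((n_subs, n_features))
x1 = np.ones((n_subs, n_features))
x2 = 10 * np.ones((n_subs, n_features))
xt_collection = np.stack([x0, x1, x2], axis=2)  # shape (45, 68, 3)

# Without enumeration: one class index per subject, shape (45,).
assignments = np.zeros(n_subs, dtype=int)
xt_keep = xt_collection[np.arange(n_subs), :, assignments]
print(xt_keep.shape)  # (45, 68)

# With parallel enumeration the 'assignments' value has shape (3, 1, 1):
# the candidate classes 0, 1, 2 along dim -3, as reported in the trace shapes.
enum_assignments = np.arange(k).reshape(k, 1, 1)
enum_keep = xt_collection[np.arange(n_subs), :, enum_assignments]
# The two index arrays broadcast to (3, 1, 45). Because they are separated by
# a slice, that broadcast shape comes first and the sliced feature axis last.
print(enum_keep.shape)  # (3, 1, 45, 68)
```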

Without enumeration, sampling works fine. With enumeration, however, I get a dimension error. One workaround I came up with is to replace the nested `pyro.plate` with a nested for loop, but this makes inference very slow.
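A vectorized alternative to the for loop is to put the subject plate for `assignments` at `dim=-2`, for example `pyro.plate('nsubs', 45, dim=-2)`. This is an assumed fix, not something from the original code. With that change, the samples have shape (45, 1), and every axis can be indexed with broadcastable index arrays instead of mixing a slice with advanced indices. The NumPy sketch below shows only the shape arithmetic. NumPy and PyTorch share these indexing rules, and inside a Pyro model `Vindex` is the usual helper for this. In the enumerated case the result is (3, 45, 68), so the enumeration dim sits at -3, which is one of the allowed dims in the error message. `subj_idx`, `feat_idx` and `select` are hypothetical names.

```python
import numpy as np

n_subs, n_features, k = 45, 68, 3
# Toy data: class c holds a constant value (0.01, 1, 10) in every entry.
xt_collection = np.stack(
    [
        0.01 * np.ones((n_subs, n_features)),
        np.ones((n_subs, n_features)),
        10 * np.ones((n_subs, n_features)),
    ],
    axis=2,
)  # shape (45, 68, 3)

# Broadcastable index arrays: subjects on dim -2, features on dim -1,
# mirroring plates placed at dim=-2 (subjects) and dim=-1 (features).
subj_idx = np.arange(n_subs).reshape(n_subs, 1)  # (45, 1)
feat_idx = np.arange(n_features)  # (68,)


def select(assignments):
    """Pick one class per (subject, feature) cell by broadcasting indices.

    All three indices are advanced indices with no slice between them,
    so the result shape is simply their broadcast shape.
    """
    return xt_collection[subj_idx, feat_idx, assignments]


# Sampled (not enumerated) inside a subject plate at dim=-2: shape (45, 1).
plain = select(np.zeros((n_subs, 1), dtype=int))
print(plain.shape)  # (45, 68)

# Enumerated at dim -3: candidate classes 0, 1, 2 with shape (3, 1, 1).
enum = select(np.arange(k).reshape(k, 1, 1))
print(enum.shape)  # (3, 45, 68)
```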

Could you please give me any suggestions?

Thank you so much for your help!

```
import torch
import os
import pyro
import numpy as np
from pyro.optim import Adam
from pyro.infer.autoguide import AutoNormal,init_to_value
from pyro.ops.indexing import Vindex
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO,TraceGraph_ELBO, config_enumerate, infer_discrete
from pyro.ops.special import safe_log
from pyro.ops.tensor_utils import convolve
from pyro.util import warn_if_nan
smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.8.1')
x0_obs = 0.01* torch.ones([45,68])
x1_obs = torch.ones([45,68])
x2_obs = 10*torch.ones([45,68])
@config_enumerate
def model(xobs):
    K = 3  # number of classes to assign
    sigma = pyro.sample("sigma", dist.HalfNormal(1))
    weights = pyro.sample('weights', dist.Dirichlet(torch.ones(K)/K))
    with pyro.plate('nsubs', 45):
        assignments = pyro.sample("assignments", dist.Categorical(weights))
    xt_collection = torch.stack([x0_obs, x1_obs, x2_obs], dim=2)  # torch.Size([45, 68, 3]) = [nsubs, nfeatures, class K]
    xt_keep = xt_collection[torch.arange(45), :, assignments]
    with pyro.plate('Nfeatures', 68):
        with pyro.plate('Nsubs', 45):
            # condition on the observations passed to model() (x0_obs in the SVI loop)
            pyro.sample('obs', dist.Normal(xt_keep, sigma), obs=xobs)
```

```
global global_guide, svi
pyro.set_rng_seed(31)
pyro.clear_param_store()
optim = pyro.optim.Adam({'lr': 0.01})
global_guide = AutoNormal(poutine.block(model, hide=['assignments']))
elbo = TraceEnum_ELBO()
svi = SVI(model, global_guide, optim, loss=elbo)
losses = []
for i in range(1001 if not smoke_test else 2):
    loss = svi.step(x0_obs)
    losses.append(loss)
    if i % 10 == 0:
        print("ELBO at iter i = "+str(i),loss)
```

```
KeyError: -4
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
ValueError: Invalid tensor shape.
Allowed dims: -3, -2, -1
Actual shape: (3, 1, 45, 68)
Try adding shape assertions for your model's sample values and distribution parameters.
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/pyro/ops/packed.py in pack(value, dim_to_symbol)
42 ]
43 )
---> 44 ) from e
45 value = value.squeeze()
46 value._pyro_dims = dims
ValueError: Error while packing tensors at site 'obs':
Invalid tensor shape.
Allowed dims: -3, -2, -1
Actual shape: (3, 1, 45, 68)
Try adding shape assertions for your model's sample values and distribution parameters.
   Trace Shapes:
    Param Sites:
   Sample Sites:
      sigma dist           |
           value           |
        log_prob           |
    weights dist           | 3
           value           | 3
        log_prob           |
assignments dist        45 |
           value   3  1  1 |
        log_prob   3  1 45 |
        obs dist 3 1 45 68 |
           value     45 68 |
        log_prob 3 1 45 68 |
```