Dear Pyro developers,

I have been trying to use discrete site enumeration for a model with multiple interactions between discrete variables.

When running SVI on the model, I get an "Invalid index" error (see below), raised inside the `log_prob` computation of the `Categorical` distribution. I am wondering whether this is because I misspecified my model, or because I am making a wrong assumption about how discrete site enumeration works.

I have tried debugging for some time, but I could not find an immediate solution.

Could you kindly assist me with some suggestions on how to fix this when you have time?

Thank you very much in advance!

Debugging information:

I am running the latest PyTorch 1.0 preview and the latest commit on the `pytorch-1.0` branch of Pyro.

Relevant model/guide code:

```python
import pyro
import pyro.distributions as dist
from torch.distributions import constraints


def mg_gen(is_guide=False):
    def mg(data):
        clust, corr, vals, lengths = data
        clust = clust.float()
        vals = vals.t().float()
        corr = corr.float()
        n_clusts = clust.size(1)
        n_corr = 2
        n_class = 4
        # Pseudo counts for each cluster and correctness choice, specifying Dirichlet arguments
        # for each class of output
        pseudocounts = clust.new_ones(n_clusts, n_corr, n_class).float() * 0.5
        if is_guide:
            # Dirichlet concentrations must stay positive during optimization
            pseudocounts = pyro.param('pc_q', pseudocounts, constraint=constraints.positive)
        # Dirichlet prior over output classes (0-3)
        obs_dir = pyro.sample('obs_dir', dist.Dirichlet(pseudocounts).independent(2))
        feature_plate = pyro.plate('feature', vals.size(0), dim=-2)
        with pyro.plate('data', vals.size(1), dim=-1) as idx:
            c0corr = vals.new_tensor(4.0).float()
            c1corr = vals.new_tensor(1.0).float()
            if is_guide:
                c0corr = pyro.param('c0corr', c0corr, constraint=constraints.greater_than(1e-3))
                c1corr = pyro.param('c1corr', c1corr, constraint=constraints.greater_than(1e-3))
            # Beta prior for classes that are correct
            corrpr = pyro.sample('corrpr', dist.Beta(concentration0=c0corr, concentration1=c1corr))
            corrobs = dict(obs=corr[idx]) if not is_guide else dict()
            # Labelled Bernoulli coin flip to determine whether output is "correct" or "incorrect"
            # according to our specification
            corrch = pyro.sample('corrch', dist.Bernoulli(corrpr),
                                 infer=dict(enumerate='parallel', is_auxiliary=True), **corrobs).long()
            # Choice of cluster for data, given prior distribution over relevant clustering
            clustch = pyro.sample('clustch', dist.Categorical(clust.index_select(0, idx)),
                                  infer=dict(enumerate='parallel'))
            with feature_plate as fidx:
                # Observed values given choice of cluster and possible correctness
                if not is_guide:
                    pyro.sample('obs', dist.Categorical(obs_dir[clustch, corrch]), obs=vals)
    return mg
```
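For reference, the broadcasting at the `obs` site can be reproduced with plain PyTorch. This is a sketch under assumptions: the sizes are made up, and it assumes that with two plates (dims -1 and -2) parallel enumeration places the enumerated `clustch` values in a new dimension to the left of the plates, i.e. shape `(n_clusts, 1, 1)`, while `corrch` (observed in the model) has one entry per data point.

```python
import torch

# Made-up sizes, chosen only for illustration.
n_clusts, n_corr, n_class = 3, 2, 4
n_features, n_data = 5, 6

# Stand-in for the 'obs_dir' site: one probability vector per (cluster, correctness) pair.
obs_dir = torch.distributions.Dirichlet(torch.ones(n_clusts, n_corr, n_class)).sample()

# Assumption: the enumerated 'clustch' values sit in a dim left of both plates.
clustch = torch.arange(n_clusts).reshape(n_clusts, 1, 1)

# In the model 'corrch' is observed, so it has one entry per data point.
corrch = torch.randint(0, n_corr, (n_data,))

# Advanced indexing broadcasts (n_clusts, 1, 1) with (n_data,) to (n_clusts, 1, n_data)
# and keeps the trailing class dim.
probs = obs_dir[clustch, corrch]
print(probs.shape)  # torch.Size([3, 1, 6, 4])

# Observed values laid out as (feature, data), matching vals after .t().
vals = torch.randint(0, n_class, (n_features, n_data))
log_p = torch.distributions.Categorical(probs).log_prob(vals)
print(log_p.shape)  # torch.Size([3, 5, 6])
```

With in-range values, the shapes line up and `log_prob` succeeds, so under these assumptions the indexing pattern itself does not trigger an out-of-range `gather`.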

Error:

```
Traceback (most recent call last):
  File "/home/asal/src/project/pyro/poutine/trace_struct.py", line 149, in compute_log_prob
    site["log_prob"]
KeyError: 'log_prob'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "pyro_debug_enum.py", line 128, in <module>
    main(sys.argv)
  File "pyro_debug_enum.py", line 124, in main
    infer(dataloader)
  File "pyro_debug_enum.py", line 75, in infer
    total_epoch_loss_train = _train(svi, dataloader)
  File "pyro_debug_enum.py", line 62, in _train
    epoch_loss += svi.step(data)
  File "/home/asal/src/project/pyro/infer/svi.py", line 96, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "/home/asal/src/project/pyro/infer/traceenum_elbo.py", line 336, in loss_and_grads
    for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
  File "/home/asal/src/project/pyro/infer/traceenum_elbo.py", line 279, in _get_traces
    yield self._get_trace(model, guide, *args, **kwargs)
  File "/home/asal/src/project/pyro/infer/traceenum_elbo.py", line 233, in _get_trace
    "flat", self.max_plate_nesting, model, guide, *args, **kwargs)
  File "/home/asal/src/project/pyro/infer/enum.py", line 51, in get_importance_trace
    model_trace.compute_log_prob()
  File "/home/asal/src/project/pyro/poutine/trace_struct.py", line 151, in compute_log_prob
    log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
  File "/home/asal/anaconda2/envs/pyro-env/lib/python3.7/site-packages/torch/distributions/categorical.py", line 119, in log_prob
    return log_pmf.gather(-1, value).squeeze(-1)
RuntimeError: Invalid index in gather at /opt/conda/conda-bld/pytorch-nightly-cpu_1540796783109/work/aten/src/TH/generic/THTensorEvenMoreMath.cpp:453
```
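Because the failure occurs in `gather` at the end of `Categorical.log_prob`, one quick check is whether every observed value reaching the `obs` site is a valid class index in `[0, n_class)`. The sketch below uses a hypothetical helper, `invalid_category_mask`, and made-up data in which `-1` stands in for a padding value; whether the real `vals` contain such values (for example, padding implied by `lengths`) is an assumption, not something established above.

```python
import torch


def invalid_category_mask(vals, n_class):
    """Hypothetical helper: True wherever vals is not a valid index into n_class categories."""
    vals = vals.long()
    return (vals < 0) | (vals >= n_class)


n_class = 4  # matches n_class in mg()

# Made-up observations; -1 is an assumed padding value, 4 is simply out of range.
vals = torch.tensor([[0.0, 3.0, 2.0],
                     [1.0, -1.0, 4.0]])

bad = invalid_category_mask(vals, n_class)
print(bad.nonzero().tolist())  # [[1, 1], [1, 2]]

# Categorical.log_prob finishes with a gather on the last dim of the log-probabilities;
# an index outside [0, n_class) makes gather raise an error instead of returning a value.
log_pmf = torch.full((1, n_class), 0.25).log()
gather_failed = False
try:
    log_pmf.gather(-1, torch.tensor([[n_class]]))
except (RuntimeError, IndexError) as err:
    gather_failed = True
    print("gather failed:", err)
```

The exact error message differs between PyTorch versions; the traceback above shows the 1.0 preview wording.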

You can also find the complete offending source file here: