TraceEnum_ELBO compute_marginals: KeyError for batch size 1

Not sure whether this is a bug or whether I am doing something unintended here. Can someone please clarify? In the following minimal example, a batch containing a single sample gives a cryptic error:

import pyro
import torch
import pyro.distributions as dist


print(pyro.__version__)

@pyro.infer.config_enumerate
def model(x: torch.Tensor, x_mask: torch.Tensor):
    with pyro.plate('batch', x.shape[0]):
        pyro.sample('x', dist.Categorical(probs=torch.as_tensor([.3, .7])), obs=x, obs_mask=x_mask)


print(pyro.infer.TraceEnum_ELBO().compute_marginals(
    model=model, guide=lambda *args, **kwargs: None,
    x=torch.as_tensor([0, 0]), x_mask=torch.as_tensor([False, False])  # batch size == 2 works
)['x_unobserved'].probs)

print(pyro.infer.TraceEnum_ELBO().compute_marginals(
    model=model, guide=lambda *args, **kwargs: None,
    x=torch.as_tensor([0]), x_mask=torch.as_tensor([False])  # batch size == 1
)['x_unobserved'].probs)

Output:

1.8.1
tensor([[0.3000, 0.7000],
        [0.3000, 0.7000]])
Traceback (most recent call last):
  File "/home/daniel/miniconda3/envs/py310/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 2737, in safe_execfile
    py3compat.execfile(
  File "/home/daniel/miniconda3/envs/py310/lib/python3.10/site-packages/IPython/utils/py3compat.py", line 55, in execfile
    exec(compiler(f.read(), fname, "exec"), glob, loc)
  File "/home/daniel/code/gnn/models/deep_belief_network/test_scripts/pyro_batchsize.py", line 21, in <module>
    print(pyro.infer.TraceEnum_ELBO().compute_marginals(
  File "/home/daniel/miniconda3/envs/py310/lib/python3.10/site-packages/pyro/infer/traceenum_elbo.py", line 493, in compute_marginals
    return _compute_marginals(model_trace, guide_trace)
  File "/home/daniel/miniconda3/envs/py310/lib/python3.10/site-packages/pyro/infer/traceenum_elbo.py", line 241, in _compute_marginals
    logits = contract_to_tensor(
  File "/home/daniel/miniconda3/envs/py310/lib/python3.10/site-packages/pyro/ops/contract.py", line 274, in contract_to_tensor
    return ring.broadcast(term, target_ordinal)
  File "/home/daniel/miniconda3/envs/py310/lib/python3.10/site-packages/pyro/ops/rings.py", line 80, in broadcast
    missing_shape = tuple(self._dim_to_size[dim] for dim in missing_dims)
  File "/home/daniel/miniconda3/envs/py310/lib/python3.10/site-packages/pyro/ops/rings.py", line 80, in <genexpr>
    missing_shape = tuple(self._dim_to_size[dim] for dim in missing_dims)
KeyError: 'a'
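One plausible reading of the traceback (an assumption, not confirmed in this thread): when Pyro packs tensors for contraction, dimensions of size 1 appear to be squeezed away. With a batch of size 1, no term would then ever record the size of the plate dimension `'a'`, and when `ring.broadcast` expands a term to a target ordinal that still contains that plate, the lookup `self._dim_to_size[dim]` finds nothing. The pure-Python toy below mimics only that lookup. `pack`, `broadcast_shape`, and the enumeration symbol `'b'` are hypothetical stand-ins, not Pyro's API.

```python
from typing import Dict, List, Tuple

# Toy model of the suspected failure; these are NOT Pyro's real functions.
Term = Tuple[str, Tuple[int, ...]]  # (dim symbols, sizes)


def pack(shape: Tuple[int, ...], symbols: str) -> Term:
    """Keep only dims of size > 1, mimicking a squeeze during packing."""
    kept = [(sym, size) for sym, size in zip(symbols, shape) if size > 1]
    return "".join(sym for sym, _ in kept), tuple(size for _, size in kept)


def broadcast_shape(term: Term, all_terms: List[Term], target: str) -> Tuple[int, ...]:
    """Return sizes of the target dims missing from `term`, like ring.broadcast."""
    # Sizes are only known for symbols that survived packing in some term.
    dim_to_size: Dict[str, int] = {}
    for dims, sizes in all_terms:
        dim_to_size.update(zip(dims, sizes))
    missing_dims = sorted(set(target) - set(term[0]))
    return tuple(dim_to_size[dim] for dim in missing_dims)  # KeyError if unknown


# Plate symbol 'a' (batch) and hypothetical enum symbol 'b' (2 categories).
batch2 = pack((2, 2), "ab")
print(broadcast_shape(batch2, [batch2], "a"))  # nothing missing, prints ()

batch1 = pack((1, 2), "ab")  # the size-1 plate dim 'a' is dropped
try:
    broadcast_shape(batch1, [batch1], "a")
except KeyError as err:
    print("KeyError:", err)  # prints KeyError: 'a'
```

With batch size 2 the plate symbol survives packing, so there is nothing to look up; with batch size 1 it is dropped, which would match the `KeyError: 'a'` above.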

Hi @dschneider, this might be related to Python 3.10 issues; you might try Python 3.9. Does the issue persist if you use a nontrivial batch size, i.e. >= 2? Feel free to file a bug report.

Hi, thanks for the quick reply. Batch sizes > 1 work fine.