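Not sure whether this is a bug or whether I am doing something unintended here — can someone please clarify? Please see the following minimal example: with a batch of two samples `compute_marginals` works, but with a batch containing a single sample it fails with a cryptic `KeyError: 'a'` raised from `pyro/ops/rings.py`.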
import pyro
import torch
import pyro.distributions as dist
# Report the installed Pyro version alongside the repro (the output shown is 1.8.1).
print(pyro.__version__)
@pyro.infer.config_enumerate
def model(x: torch.Tensor, x_mask: torch.Tensor):
    """Categorical model over a batch of partially observed values.

    Every element of ``x`` is modelled as a draw from a fixed Categorical
    prior with probabilities ``[0.3, 0.7]``, inside a plate named ``'batch'``
    whose size is ``x.shape[0]``. ``x_mask`` is forwarded as ``obs_mask``, so
    Pyro splits the site and exposes the masked-out part as ``x_unobserved``
    (the repro reads that site's marginal). The ``config_enumerate``
    decorator lets ``TraceEnum_ELBO`` enumerate those unobserved values.

    Args:
        x: observed category indices, one per batch element.
        x_mask: boolean mask passed as ``obs_mask``; per Pyro's ``obs_mask``
            semantics, False entries are treated as unobserved (the repro
            uses all-False masks).

    NOTE(review): with ``x.shape[0] == 1``, ``compute_marginals`` on this
    model raises ``KeyError: 'a'`` inside ``pyro/ops/rings.py`` (see the
    traceback). The size-1 plate appears to be the trigger — confirm
    against the Pyro issue tracker.
    """
    prior = dist.Categorical(probs=torch.as_tensor([.3, .7]))
    batch_size = x.shape[0]
    with pyro.plate('batch', batch_size):
        pyro.sample('x', prior, obs=x, obs_mask=x_mask)
def _empty_guide(*args, **kwargs):
    """No-op guide: latent values in ``model`` are enumerated, not guided."""
    return None


def _unobserved_marginal_probs(x, x_mask):
    """Return the marginal probabilities of the masked-out entries of ``x``.

    Runs ``TraceEnum_ELBO.compute_marginals`` on ``model`` with the no-op
    guide, then reads ``.probs`` from the marginal of the ``x_unobserved``
    site.
    """
    elbo = pyro.infer.TraceEnum_ELBO()
    marginals = elbo.compute_marginals(
        model=model,
        guide=_empty_guide,
        x=x,
        x_mask=x_mask,
    )
    return marginals['x_unobserved'].probs


# Batch size 2: works and prints the prior [0.3, 0.7] for each row.
print(_unobserved_marginal_probs(torch.as_tensor([0, 0]), torch.as_tensor([False, False])))

# Batch size 1: raises KeyError: 'a' from pyro/ops/rings.py (see traceback).
print(_unobserved_marginal_probs(torch.as_tensor([0]), torch.as_tensor([False])))
1.8.1
tensor([[0.3000, 0.7000],
[0.3000, 0.7000]])
Traceback (most recent call last):
File "/home/daniel/miniconda3/envs/py310/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 2737, in safe_execfile
py3compat.execfile(
File "/home/daniel/miniconda3/envs/py310/lib/python3.10/site-packages/IPython/utils/py3compat.py", line 55, in execfile
exec(compiler(f.read(), fname, "exec"), glob, loc)
File "/home/daniel/code/gnn/models/deep_belief_network/test_scripts/pyro_batchsize.py", line 21, in <module>
print(pyro.infer.TraceEnum_ELBO().compute_marginals(
File "/home/daniel/miniconda3/envs/py310/lib/python3.10/site-packages/pyro/infer/traceenum_elbo.py", line 493, in compute_marginals
return _compute_marginals(model_trace, guide_trace)
File "/home/daniel/miniconda3/envs/py310/lib/python3.10/site-packages/pyro/infer/traceenum_elbo.py", line 241, in _compute_marginals
logits = contract_to_tensor(
File "/home/daniel/miniconda3/envs/py310/lib/python3.10/site-packages/pyro/ops/contract.py", line 274, in contract_to_tensor
return ring.broadcast(term, target_ordinal)
File "/home/daniel/miniconda3/envs/py310/lib/python3.10/site-packages/pyro/ops/rings.py", line 80, in broadcast
missing_shape = tuple(self._dim_to_size[dim] for dim in missing_dims)
File "/home/daniel/miniconda3/envs/py310/lib/python3.10/site-packages/pyro/ops/rings.py", line 80, in <genexpr>
missing_shape = tuple(self._dim_to_size[dim] for dim in missing_dims)
KeyError: 'a'