I’m trying to infer a discrete site from a trained model conditioned on another variable:

```
@config_enumerate
def model(A_obs=None, B_obs=None):
# A_prior = pyro.sample('A_prior', dist.Beta(torch.tensor(1.), torch.tensor(3.)))
A_prior = pyro.param('A_prior', dist.Beta(torch.tensor(1.), torch.tensor(3.)).to_event())
B_prior = pyro.sample('B_prior', dist.Gamma(torch.tensor([[1., 4.],
[2., 4.]]), rate=torch.tensor([0.5, 0.5])).to_event())
N = 1 if B_obs is None else len(B_obs)
with pyro.plate('data', N):
A = pyro.sample('A', dist.Bernoulli(probs=A_prior), obs=A_obs, infer={'enumerate': 'parallel'}).long()
B = pyro.sample('B', dist.Beta(B_prior[A, 0], B_prior[A, 1]), obs=B_obs)
return A, B
pyro.clear_param_store()
params = {
'A_prior': torch.tensor([0.5]), # class probabilities for A
'B_prior': torch.tensor([[1., 4.], [7., 8.]]), # beta concetnrations corresponding with each class of A
}
conditioned_model = pyro.poutine.condition(model, data=params, )
conditioned_predictive = Predictive(conditioned_model, posterior_samples={}, num_samples=10000)
dummy_samples = conditioned_predictive()
pyro.clear_param_store()
auto_guide = pyro.infer.autoguide.AutoNormal(model)
adam = pyro.optim.Adam({"lr": 0.01}) # Consider decreasing learning rate.
elbo = pyro.infer.TraceEnum_ELBO()
svi = pyro.infer.SVI(model, auto_guide, adam, elbo)
losses = []
for step in range(1_000): # Consider running for more steps.
loss = svi.step(dummy_samples['A'].squeeze(), dummy_samples['B'].squeeze())
losses.append(loss)
if step % 100 == 0:
print("Elbo loss: {}".format(loss))
posterior_predictive_model = pyro.infer.Predictive(model, guide=auto_guide, num_samples=50)
serving_model = infer_discrete(posterior_predictive_model, first_available_dim=-1, temperature=1)
serving_model(B_obs=torch.tensor([0.5]))
```

Which results in the following error:

```
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
File ~/Library/Caches/pypoetry/virtualenvs/statistical-rethinking-64KwZK9C-py3.9/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
173 try:
--> 174 ret = self.fn(*args, **kwargs)
175 except (ValueError, RuntimeError) as e:
...
RuntimeError: Multiple sample sites named 'B_prior_unconstrained'
The above exception was the direct cause of the following exception:
...
RuntimeError: Multiple sample sites named 'B_prior_unconstrained'
Trace Shapes:
Param Sites:
AutoNormal.locs.B_prior 2 2
AutoNormal.scales.B_prior 2 2
Sample Sites:
...
RuntimeError: Multiple sample sites named 'B_prior_unconstrained'
Trace Shapes:
Param Sites:
AutoNormal.locs.B_prior 2 2
AutoNormal.scales.B_prior 2 2
Sample Sites:
Trace Shapes:
Param Sites:
AutoNormal.locs.B_prior 2 2
AutoNormal.scales.B_prior 2 2
Sample Sites:
_num_predictive_samples dist |
value 50 |
B_prior_unconstrained dist | 2 2
value | 2 2
B_prior dist | 2 2
value | 2 2
```

What am I getting wrong? Despite my reading the enumeration and shapes documentation, this remains difficult for me to parse.