Multiple Sample Sites / Infer Discrete

I’m trying to infer a discrete site from a trained model conditioned on another variable:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Predictive, config_enumerate, infer_discrete

@config_enumerate
def model(A_obs=None, B_obs=None):
    # A_prior = pyro.sample('A_prior', dist.Beta(torch.tensor(1.), torch.tensor(3.)))
    A_prior = pyro.param('A_prior', dist.Beta(torch.tensor(1.), torch.tensor(3.)).to_event())
    B_prior = pyro.sample('B_prior', dist.Gamma(torch.tensor([[1., 4.],
                                                              [2., 4.]]), rate=torch.tensor([0.5, 0.5])).to_event())
    N = 1 if B_obs is None else len(B_obs)

    with pyro.plate('data', N):
        A = pyro.sample('A', dist.Bernoulli(probs=A_prior), obs=A_obs, infer={'enumerate': 'parallel'}).long()
        B = pyro.sample('B', dist.Beta(B_prior[A, 0], B_prior[A, 1]), obs=B_obs)
        return A, B

pyro.clear_param_store()

params = {
    'A_prior': torch.tensor([0.5]),  # class probabilities for A
    'B_prior': torch.tensor([[1., 4.], [7., 8.]]),  # beta concentrations corresponding to each class of A
}
conditioned_model = pyro.poutine.condition(model, data=params)

conditioned_predictive = Predictive(conditioned_model, posterior_samples={}, num_samples=10000)
dummy_samples = conditioned_predictive()

pyro.clear_param_store()

auto_guide = pyro.infer.autoguide.AutoNormal(model)
adam = pyro.optim.Adam({"lr": 0.01})  # Consider decreasing learning rate.
elbo = pyro.infer.TraceEnum_ELBO()
svi = pyro.infer.SVI(model, auto_guide, adam, elbo)

losses = []
for step in range(1_000):  # Consider running for more steps.
    loss = svi.step(dummy_samples['A'].squeeze(), dummy_samples['B'].squeeze())
    losses.append(loss)
    if step % 100 == 0:
        print("Elbo loss: {}".format(loss))

posterior_predictive_model = pyro.infer.Predictive(model, guide=auto_guide, num_samples=50)
serving_model = infer_discrete(posterior_predictive_model, first_available_dim=-1, temperature=1)

serving_model(B_obs=torch.tensor([0.5]))

Which results in the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File ~/Library/Caches/pypoetry/virtualenvs/statistical-rethinking-64KwZK9C-py3.9/lib/python3.9/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
    173 try:
--> 174     ret = self.fn(*args, **kwargs)
    175 except (ValueError, RuntimeError) as e:
...
RuntimeError: Multiple sample sites named 'B_prior_unconstrained'

The above exception was the direct cause of the following exception:
...
RuntimeError: Multiple sample sites named 'B_prior_unconstrained'
            Trace Shapes:    
             Param Sites:    
  AutoNormal.locs.B_prior 2 2
AutoNormal.scales.B_prior 2 2
            Sample Sites:    
               Trace Shapes:         
                Param Sites:         
     AutoNormal.locs.B_prior  2 2    
   AutoNormal.scales.B_prior  2 2    
               Sample Sites:         
_num_predictive_samples dist    |    
                       value 50 |    
  B_prior_unconstrained dist    | 2 2
                       value    | 2 2
                B_prior dist    | 2 2
                       value    | 2 2

What am I getting wrong? Even after reading the enumeration and tensor shapes documentation, I’m finding this difficult to parse.

To infer discrete sites, I think you can use this pattern rather than using Predictive.

Thanks, @fehiepsi – following that pattern (trace and replay) has been working for me.

I do have one remaining problem. I’d like a class probability (and ideally a distribution), not just a class sample. To get a point estimate of the Bernoulli class probability, my plan was to sample the site many times and take the mean. However, plate notation doesn’t seem to work for me, apparently because of an indexed upstream variable; changing the dim never yields a tensor with the number of samples I’d expect:

guide_trace = poutine.trace(mv_guide).get_trace(ac=dummy_ac, an=dummy_an, label=dummy_labels)  # record the globals
trained_model = poutine.replay(model, trace=guide_trace)
inferred_model = infer_discrete(trained_model, temperature=1,
                                first_available_dim=-2)
trace = poutine.trace(inferred_model)

with pyro.plate(100, dim=-2):
    preds = trace.get_trace(ac=dummy_ac, an=dummy_an).nodes["label"]["value"]

print(preds.shape)
# torch.Size([100000])

I expect a tensor of shape [100000, 100]. Instead of plate notation, I can just use a for loop, but I’m guessing there is a better way to do this.
The GMM tutorial you pointed to does not demonstrate a pattern for getting a class probability.

I’m not sure. I think infer_discrete does not let you infer continuous variables. If you want to infer the probs, you can use SVI or MCMC. Otherwise, you can draw a lot of samples to calculate empirical probs.