I have the following model:
@config_enumerate  # enable enumeration due to categorical variable https://pyro.ai/examples/enumeration.html
def model(ac=None, an=None, label=None):
    """Two-component Beta-Binomial mixture with a Bernoulli class label.

    Args:
        ac: observed success counts, one per data point; ``None`` to sample them.
        an: total counts (trials) per data point. When ``None`` the model uses a
            single data point and BetaBinomial's default ``total_count=1``.
        label: observed 0/1 class labels; ``None`` to sample (or, under an
            enumeration handler, enumerate) them.

    Returns:
        ``(label, ac)`` as sampled or observed inside the ``data`` plate;
        ``label`` is cast to ``long`` so it can index ``b_concentrations``.
    """
    # Probability that label == 1.
    a_prior = pyro.sample('a_prior', Beta(torch.tensor(1.), torch.tensor(5.)))
    # Row k holds the two BetaBinomial concentrations used for class k; the
    # whole block is declared as a single event via to_event().
    beta_prior = pyro.sample(
        "b_concentrations",
        Gamma(torch.tensor([[1., 6.], [1., 8.]]), rate=torch.tensor([0.5, 0.5])).to_event(),
    )
    N = 1 if an is None else len(an)
    # BetaBinomial cannot broadcast total_count=None; fall back to its default
    # of 1 so the model can also be run without `an`.
    total_count = 1 if an is None else an
    with pyro.plate("data", N):
        label = pyro.sample(
            'label',
            Bernoulli(probs=a_prior),
            obs=label,
            infer={"enumerate": "parallel"},
        ).long()
        # Select each data point's concentrations by its class label.
        ac = pyro.sample(
            'ac',
            BetaBinomial(beta_prior[label, 0], beta_prior[label, 1], total_count=total_count),
            obs=ac,
        )
    return label, ac
which I used to generate dummy data like this:
# Ground-truth values used to simulate dummy data from the model.
params = {
    # 0-d tensor to match the 'a_prior' site, whose Beta prior has scalar parameters.
    'a_prior': torch.tensor(0.1),  # class probability for A, i.e. P(label == 1)
    'b_concentrations': torch.tensor([[1., 4.], [7., 8.]]),  # beta concentrations corresponding with each class of A
}
conditioned_model = pyro.poutine.condition(model, data=params)
...
I then clear the param store and fit the model to the dummy data with an `AutoNormal`
guide. I can recover the assigned parameters, so SVI seems to be working.
However, when I sample the posterior of `label`
over a grid of data values, I always get the same probability of `label=1`, regardless of the values of `ac` or `an`:
grid_an = torch.tensor([1000.] * 100)
grid_af = torch.linspace(0., 1., 100)
# Same values as `.int()` (all products are non-negative, so floor == trunc),
# but kept as floats to match the dtype of `grid_an`.
grid_ac = (grid_af * grid_an).floor()

# NOTE: Predictive only replays the sites the guide provides (a_prior and
# b_concentrations). 'label' is an enumerated discrete site with no guide
# counterpart and is not observed here, so Predictive draws it from its prior
# Bernoulli(a_prior). It never conditions on ac/an, which is why P(label=1)
# comes out the same at every grid point.
posterior_predictive_model = pyro.infer.Predictive(model, guide=auto_guide, num_samples=5000)
ppd = posterior_predictive_model(ac=grid_ac, an=grid_an)


def sample_label_posterior(ac, an, num_samples=1000, temperature=1):
    """Sample ``label`` from its posterior given ``ac``/``an`` under ``auto_guide``.

    Each draw samples the global latents from the fitted guide and fixes them
    in the model with ``replay``. ``infer_discrete`` then samples the
    enumerated ``label`` site conditioned on the observed ``ac``.

    Args:
        ac: observed counts per data point.
        an: total counts per data point.
        num_samples: number of guide draws to take.
        temperature: 1 samples from the posterior; 0 gives the MAP label.

    Returns:
        Tensor stacking one ``label`` sample (0./1. values) per guide draw
        along a new leading dimension.
    """
    draws = []
    with torch.no_grad():
        for _ in range(num_samples):
            guide_trace = pyro.poutine.trace(auto_guide).get_trace(ac=ac, an=an)
            replayed = pyro.poutine.replay(model, trace=guide_trace)
            # The "data" plate is allocated dim -1, so enumeration starts at -2.
            inferred = pyro.infer.infer_discrete(
                replayed, temperature=temperature, first_available_dim=-2
            )
            trace = pyro.poutine.trace(inferred).get_trace(ac=ac, an=an)
            draws.append(trace.nodes["label"]["value"])
    return torch.stack(draws)


label_draws = sample_label_posterior(grid_ac, grid_an, num_samples=1000)
# Per-grid-point posterior probability P(label = 1 | ac, an).
p_label1 = label_draws.float().mean(0)
What am I doing wrong here?