Hi. Lots of discrete enumeration posts on here and the Pyro docs site, but for relatively complicated models. I’m trying and failing to get it working for the toy model of just two nodes: $A$ is a hidden Bernoulli variable, and $B$ is its observed child, a mixture of two Bernoulli distributions chosen by $A$ being true or false. There’s a plate over $A$ and $B$ for $n$ independent observations of data.
MWE is below, after lots of small tweaks from reading posts/tutorials, but it doesn’t seem to be learning anything given the dummy data. [Note1]
This model is exactly where marginalizing out over the hidden variable should work, right? [Note2] Am I doing something with Pyro’s enumeration incorrectly? Any pointers would be appreciated.
Note1: I have stochastically simulated toy data in my full notebook.
Note2: I also can’t get it working in the reverse direction, where $A$ is observed and $B$ is hidden.
import torch
from torch.distributions import constraints
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate
import pyro.optim
import pyro.infer
from pyro.ops.indexing import Vindex
@config_enumerate
def model(obs, n):
    """Generative model: hidden Bernoulli A -> observed Bernoulli mixture B.

    obs: dict of tensors; only obs['B'] (assumed shape (n,) of 0./1. floats
         — confirm against caller) is conditioned on here.
    n:   number of i.i.d. observations, i.e. the size of the data plate.
    """
    # Uniform Beta(1, 1) prior on P(A = 1).
    p_A = pyro.sample('p_A', dist.Beta(1, 1))
    # p_B[i] is P(B = 1 | A = i); .to_event(1) declares the pair as one
    # joint 2-dimensional sample site rather than two batched sites.
    p_B = pyro.sample('p_B', dist.Beta(torch.ones(2), torch.ones(2)).to_event(1))
    with pyro.plate('data_plate', n):
        # A is marginalized by parallel enumeration: under TraceEnum_ELBO it
        # takes values {0, 1} along an extra broadcast dim, so the guide must
        # NOT sample 'A'. p_A.expand(n) broadcasts the scalar over the plate.
        A = pyro.sample('A', dist.Bernoulli(p_A.expand(n)), infer={"enumerate": "parallel"})
        # Vindex gathers p_B[A] elementwise, staying correct under the extra
        # enumeration dimension that plain fancy indexing would mishandle.
        B = pyro.sample('B', dist.Bernoulli(Vindex(p_B)[A.type(torch.long)]), obs=obs['B'])
def guide(obs, n):
    """Mean-field guide: independent Beta posteriors over p_A and p_B.

    The enumerated discrete site 'A' is deliberately absent: TraceEnum_ELBO
    marginalizes it out in the model, and a guide sample for it would break
    the enumeration.

    obs and n are accepted (SVI passes the model's arguments) but unused.
    """
    # Beta(a[0], a[1]) variational posterior for p_A.
    a = pyro.param('a', torch.ones(2), constraint=constraints.positive)
    pyro.sample('p_A', dist.Beta(a[0], a[1]))
    # Row-wise: Beta(b[i, 0], b[i, 1]) is the posterior for p_B given A == i;
    # .to_event(1) mirrors the model's event shape for the 'p_B' site.
    b = pyro.param('b', torch.ones(2, 2), constraint=constraints.positive)
    pyro.sample('p_B', dist.Beta(b[:, 0], b[:, 1]).to_event(1))
# --- Toy dataset: paired observations of (A, B). ---
data = {'A': torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  # p_A = .5
                           1, 1, 1, 1, 1, 1, 1, 1, 1, 1]).type(torch.float),
        'B': torch.tensor([0, 0, 0, 1, 1, 1, 1, 1, 1, 1,  # p_B given A false = .7
                           0, 0, 0, 0, 0, 0, 1, 1, 1, 1]).type(torch.float)}  # p_B given A true = .4
# Derive the plate size from the data rather than hard-coding 20, so the
# dataset can change without this line silently going stale.
n = data['B'].shape[0]

pyro.enable_validation(True)
pyro.clear_param_store()

# TraceEnum_ELBO sums out the enumerated site 'A'; max_plate_nesting=1
# matches the single 'data_plate' in the model.
loss_func = pyro.infer.TraceEnum_ELBO(max_plate_nesting=1)
optim = pyro.optim.Adam({"lr": .001})
svi = pyro.infer.SVI(model, guide, optim, loss=loss_func)
for step in range(10000):
    svi.step(data, n)

# Report posterior means: for Beta(a, b), E[p] = a / (a + b).
posterior_params = {k: v.data for k, v in pyro.get_param_store().items()}
# Lift 'a' from shape (2,) to (1, 2) so the print loop below treats both
# parameter tensors uniformly as rows of (alpha, beta) pairs.
posterior_params['a'] = posterior_params['a'][None, :]
for key, val in posterior_params.items():
    # torch's native keyword is dim= (axis= is only a numpy-compat alias).
    print(f'p_{key.upper()} {val[:, 0] / torch.sum(val, dim=1)}')