Discrete enumeration in super simple toy model

Hi. There are lots of discrete enumeration posts here and on the Pyro docs site, but they all cover relatively complicated models. I’m trying, and failing, to get enumeration working for a toy model of just two nodes: A is a hidden Bernoulli variable, and B is its observed child, a mixture of two Bernoulli distributions selected by whether A is true or false. A plate over A and B covers n independent observations.

An MWE is below, after lots of small tweaks from reading posts/tutorials, but it doesn’t seem to learn anything from the dummy data. [Note1]

This model is exactly the kind of case where marginalizing out the hidden variable should work, right? [Note2] Am I using Pyro’s enumeration incorrectly? Any pointers would be appreciated. (The sketch after the notes spells out the marginal I have in mind.)

Note1: I have stochastically simulated toy data in my full notebook.
Note2: I also can’t get it working when A is observed and B is hidden.
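
For concreteness, this is the per-observation quantity I expect TraceEnum_ELBO to compute when it marginalizes A, sketched by hand (just illustrative, not Pyro internals):

import torch

def marginal_log_prob(b, p_A, p_B):
    """log p(B=b | p_A, p_B) with the hidden A summed out; p_B[a] = P(B=1 | A=a)."""
    lik_a0 = (1 - p_A) * p_B[0] ** b * (1 - p_B[0]) ** (1 - b)  # A = 0 branch
    lik_a1 = p_A * p_B[1] ** b * (1 - p_B[1]) ** (1 - b)        # A = 1 branch
    return torch.log(lik_a0 + lik_a1)

print(marginal_log_prob(torch.tensor(1.), torch.tensor(0.5), torch.tensor([0.7, 0.4])))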

import torch
from torch.distributions import constraints
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate
import pyro.optim
import pyro.infer
from pyro.ops.indexing import Vindex


@config_enumerate
def model(obs, n):

    # Global latents: P(A = 1) and P(B = 1 | A = a) for a in {0, 1}.
    p_A = pyro.sample('p_A', dist.Beta(1, 1))

    p_B = pyro.sample('p_B', dist.Beta(torch.ones(2), torch.ones(2)).to_event(1))

    with pyro.plate('data_plate', n):
        # A is hidden, so enumerate it out in parallel instead of sampling it.
        A = pyro.sample('A', dist.Bernoulli(p_A.expand(n)), infer={"enumerate": "parallel"})

        # Vindex handles the extra batch dim that the enumerated A carries.
        pyro.sample('B', dist.Bernoulli(Vindex(p_B)[A.long()]), obs=obs['B'])

def guide(obs, n):

    # Mean-field Beta posteriors; params hold (concentration1, concentration0) pairs.
    a = pyro.param('a', torch.ones(2), constraint=constraints.positive)
    pyro.sample('p_A', dist.Beta(a[0], a[1]))

    b = pyro.param('b', torch.ones(2, 2), constraint=constraints.positive)
    pyro.sample('p_B', dist.Beta(b[:, 0], b[:, 1]).to_event(1))
        
n = 20
data = {'A': torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0,     # p_A = .5
                           1, 1, 1, 1, 1, 1, 1, 1, 1, 1]).float(),
        'B': torch.tensor([0, 0, 0, 1, 1, 1, 1, 1, 1, 1,     # p_B given A false = .7
                           0, 0, 0, 0, 0, 0, 1, 1, 1, 1]).float()}  # p_B given A true = .4
        
pyro.enable_validation(True)
pyro.clear_param_store()

loss_func = pyro.infer.TraceEnum_ELBO(max_plate_nesting=1)
optim = pyro.optim.Adam({"lr": .001})
svi = pyro.infer.SVI(model, guide, optim, loss=loss_func)

for step in range(10000):
    svi.step(data, n)

# Posterior mean of a Beta is concentration1 / (concentration1 + concentration0).
posterior_params = {k: v.data for k, v in pyro.get_param_store().items()}
posterior_params['a'] = posterior_params['a'][None, :]
for key, val in posterior_params.items():
    print(f'p_{key.upper()}  {val[:, 0] / torch.sum(val, axis=1)}')

OK, I’m not yet sure about the math behind it, but I think this Hidden -> Observed model is too simple for enumeration to overcome the label-switching problem: nothing in the data distinguishes the two labelings of A. I was able to use enumeration to recover the CPDs for the slightly bigger model Observed -> Hidden -> Observed.
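
To make the unidentifiability concrete, here’s a quick numerical check (illustrative numbers only) that swapping the labels of A leaves the marginal likelihood of B unchanged, so the ELBO has no way to prefer one labeling over the other:

import torch

p_A = torch.tensor(0.3)
p_B = torch.tensor([0.7, 0.4])        # P(B=1 | A=0), P(B=1 | A=1)

def marginal_p_B1(pa, pb):            # P(B=1) with A summed out
    return (1 - pa) * pb[0] + pa * pb[1]

# Relabel A' = 1 - A: flip p_A and swap the rows of p_B.
assert torch.isclose(marginal_p_B1(p_A, p_B),
                     marginal_p_B1(1 - p_A, p_B.flip(0)))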

I’ve attached the full notebook with this model. I think the Pyro example docs would benefit greatly from an example on a very simplified model like this; all the existing examples use models that are complicated in their own right, which makes it hard to isolate just the enumeration and mixture-model aspects.

Notebook
import numpy as np
import torch
from torch.distributions.bernoulli import Bernoulli
from torch.distributions.beta import Beta

n = 1000

# domain = [False, True]
prior = {'A': torch.tensor([2., 4.]),
         'B': torch.tensor([[5., 3.],
                            [2., 4.]]),
         'C': torch.tensor([[5., 3.],
                            [2., 4.]]),
        }

p = {'p_A': Beta(prior['A'][0], prior['A'][1]).sample(),
     'p_B': Beta(prior['B'][:, 0], prior['B'][:, 1]).sample(),
     'p_C': Beta(prior['C'][:, 0], prior['C'][:, 1]).sample(),
    }

# Ancestral sampling: A ~ Bern(p_A), B|A ~ Bern(p_B[A]), C|B ~ Bern(p_C[B]).
data = {}
data['A'] = Bernoulli(torch.ones(n) * p['p_A']).sample()
data['B'] = Bernoulli(torch.gather(p['p_B'], 0, data['A'].type(torch.long))).sample()
data['C'] = Bernoulli(torch.gather(p['p_C'], 0, data['B'].type(torch.long))).sample()

for name, val in p.items():
    print(name, val)

import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate
from pyro.ops.indexing import Vindex
from torch.distributions import constraints

@config_enumerate
def model(obs, n):

    p_A = pyro.sample('p_A', dist.Beta(1, 1))

    p_B = pyro.sample('p_B', dist.Beta(torch.ones(2), torch.ones(2)).to_event(1))

    p_C = pyro.sample('p_C', dist.Beta(torch.ones(2), torch.ones(2)).to_event(1))

    with pyro.plate('data_plate', n):
        A = pyro.sample('A', dist.Bernoulli(p_A.expand(n)), obs=obs['A'])

        # B is the hidden middle node; enumerate it out in parallel.
        B = pyro.sample('B', dist.Bernoulli(Vindex(p_B)[A.long()]), infer={"enumerate": "parallel"})

        pyro.sample('C', dist.Bernoulli(Vindex(p_C)[B.long()]), obs=obs['C'])

def guide(obs, n):

    # Mean-field Beta posteriors, initialized at the simulation priors.
    a = pyro.param('a', prior['A'], constraint=constraints.positive)
    pyro.sample('p_A', dist.Beta(a[0], a[1]))

    b = pyro.param('b', prior['B'], constraint=constraints.positive)
    pyro.sample('p_B', dist.Beta(b[:, 0], b[:, 1]).to_event(1))

    c = pyro.param('c', prior['C'], constraint=constraints.positive)
    pyro.sample('p_C', dist.Beta(c[:, 0], c[:, 1]).to_event(1))

import pyro.optim
import pyro.infer
import time
import matplotlib.pyplot as plt

pyro.enable_validation(True)
pyro.clear_param_store()

# setup svi object
loss_func = pyro.infer.TraceEnum_ELBO(max_plate_nesting=1)
optim = pyro.optim.Adam({"lr": .001})
svi = pyro.infer.SVI(model, guide, optim, loss=loss_func)

# perform svi
num_steps = 30000
losses = []
start = time.time()
for step in range(num_steps):
    loss = svi.step(data, n)
    losses.append(loss)

    if step % int(num_steps * .33) == 0:  # float modulo almost never hits 0 exactly
        print(step, f'({(time.time() - start)/60:.1f} min.)')
print(step+1, f'({(time.time() - start)/60:.1f} min.)\n\n')

fig = plt.figure()
plt.plot(losses)
plt.show()

# Compare the true simulated probabilities with the Beta posterior means a / (a + b).
posterior_params = {k: np.array(v.data) for k, v in pyro.get_param_store().items()}
posterior_params['a'] = posterior_params['a'][None, :]
for key, val in posterior_params.items():
    true_p = p[f'p_{key.upper()}'].numpy()
    print(f'p_{key.upper()}  (true/pred): ')
    print('\t', np.round(true_p, 2))
    print('\t', np.round(val[:, 0]/(np.sum(val, axis=1)), 2))
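
If you also want point estimates of the hidden B values after training, something like the following should work (a sketch based on pyro.infer.infer_discrete; temperature=0 picks the most likely value at each enumerated site):

from pyro import poutine
from pyro.infer import infer_discrete

# Fill in p_A, p_B, p_C from a guide sample, then take MAP values of B.
guide_trace = poutine.trace(guide).get_trace(data, n)
trained_model = poutine.replay(model, trace=guide_trace)
inferred_model = infer_discrete(trained_model, temperature=0, first_available_dim=-2)
trace = poutine.trace(inferred_model).get_trace(data, n)
B_map = trace.nodes['B']['value']  # one 0/1 assignment per datapoint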