Discrete enumeration in super simple toy model

OK, I’m not yet sure about the math behind it, but I think his Hidden -> Observed model is too simple for enumeration to overcome the label-switching problem. I was able to use enumeration to recover the conditional probability distributions (CPDs) for the slightly bigger Observed -> Hidden -> Observed (O -> H -> O) model.

I’ve attached the full notebook with this model. I think the Pyro example docs would greatly benefit from an example of a very simple model like this: all the existing examples use models that are complicated in their own right, which makes it hard to learn just the enumeration and mixture-model aspects.

Notebook
import numpy as np
import torch
from torch.distributions.bernoulli import Bernoulli
from torch.distributions.beta import Beta

n = 1000  # number of samples drawn from the ground-truth network

# Beta pseudo-counts (alpha, beta) for every CPD row: Beta(alpha, beta) is the
# prior over the probability that the node is True.  Row i of 'B' and 'C'
# belongs to parent value i, with domain = [False, True] encoded as [0, 1].
prior = {
    'A': torch.tensor([2., 4.]),
    'B': torch.tensor([[5., 3.], [2., 4.]]),
    'C': torch.tensor([[5., 3.], [2., 4.]]),
}

# Ground-truth CPD parameters, drawn once from the priors above in the order
# A, B, C.  `[..., 0]` / `[..., 1]` pick alpha / beta for both the 1-D prior
# of A (giving a scalar) and the per-row priors of B and C (giving 2 values).
p = {}
for node in 'ABC':
    alpha_beta = prior[node]
    p[f'p_{node}'] = Beta(alpha_beta[..., 0], alpha_beta[..., 1]).sample()

# Forward-sample the chain A -> B -> C.  Each child's success probability is
# picked from its CPD by using the sampled 0/1 parent value as an index.
data = {'A': Bernoulli(torch.ones(n) * p['p_A']).sample()}
for parent, child in (('A', 'B'), ('B', 'C')):
    row_index = data[parent].long()
    data[child] = Bernoulli(p[f'p_{child}'][row_index]).sample()

for param_name, param_value in p.items():
    print(param_name, param_value)

import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate
from torch.distributions import constraints
import torch
from torch.distributions import constraints
from pyro.ops.indexing import Vindex

@config_enumerate
def model(obs, n):
    """Generative model A -> B -> C with A and C observed and B latent.

    Every conditional distribution is Bernoulli.  A has a single success
    probability; B and C each have one probability per value of their
    parent, and the parent's 0/1 value selects the entry via ``Vindex``.
    All probabilities get uniform Beta(1, 1) priors.

    Args:
        obs: mapping holding the observed tensors under 'A' and 'C'.  'B' is
            never read, because B is marginalised out by parallel
            enumeration.
        n: number of data points (size of ``data_plate``).

    NOTE(review): the likelihood of (A, C) depends on p_B and p_C only through
    P(C=1 | A=a) = (1 - p_B[a]) * p_C[0] + p_B[a] * p_C[1] for a in {0, 1}.
    That is two quantities for four unknowns, and replacing p_B with 1 - p_B
    while swapping the entries of p_C (relabelling B) leaves it unchanged.
    So p_B and p_C are not uniquely identifiable from this data alone.
    """
    # Global CPD parameters (continuous sites, fitted by the guide).
    prob_a = pyro.sample('p_A', dist.Beta(1, 1))
    prob_b_given_a = pyro.sample(
        'p_B', dist.Beta(torch.ones(2), torch.ones(2)).to_event(1))
    prob_c_given_b = pyro.sample(
        'p_C', dist.Beta(torch.ones(2), torch.ones(2)).to_event(1))

    with pyro.plate('data_plate', n):
        a_obs = pyro.sample('A', dist.Bernoulli(prob_a.expand(n)), obs=obs['A'])

        # B is latent: its two values are enumerated in parallel rather than
        # sampled, so the guide does not need a site for it.
        b_latent = pyro.sample(
            'B',
            dist.Bernoulli(Vindex(prob_b_given_a)[a_obs.long()]),
            infer={"enumerate": "parallel"},
        )

        pyro.sample(
            'C',
            dist.Bernoulli(Vindex(prob_c_given_b)[b_latent.long()]),
            obs=obs['C'],
        )

def guide(obs, n):
    """Mean-field variational guide over the three CPD parameters.

    Each of p_A, p_B and p_C gets an independent Beta posterior.  Its
    (alpha, beta) pseudo-counts are learnable Pyro params named 'a', 'b' and
    'c', constrained to be positive.  The latent B has no site here: the model
    enumerates it out (parallel enumeration under ``TraceEnum_ELBO``), so only
    the continuous sites are guided.

    Args:
        obs: unused; accepted so the guide has the same signature as ``model``.
        n: unused; accepted so the guide has the same signature as ``model``.

    NOTE(review): the params are initialised from the module-level ``prior``,
    which is the same prior the ground-truth ``p`` was sampled from.  That
    gives the optimiser an informed, label-asymmetric starting point.
    Consider a neutral init when judging whether the CPDs are truly
    recoverable.
    """
    # Rows of 'b' / 'c' are (alpha, beta) per parent value; 'a' is one pair.
    a = pyro.param('a', prior['A'], constraint=constraints.positive)
    pyro.sample('p_A', dist.Beta(a[0], a[1]))

    b = pyro.param('b', prior['B'], constraint=constraints.positive)
    pyro.sample('p_B', dist.Beta(b[:, 0], b[:, 1]).to_event(1))

    c = pyro.param('c', prior['C'], constraint=constraints.positive)
    pyro.sample('p_C', dist.Beta(c[:, 0], c[:, 1]).to_event(1))

import pyro.optim
import pyro.infer
import time
import matplotlib.pyplot as plt

pyro.enable_validation(True)
pyro.clear_param_store()

# Set up SVI.  max_plate_nesting=1 matches the single `data_plate` in the
# model, so the enumerated values of B are placed in dims to its left.
loss_func = pyro.infer.TraceEnum_ELBO(max_plate_nesting=1)
optim = pyro.optim.Adam({"lr": .001})
svi = pyro.infer.SVI(model, guide, optim, loss=loss_func)

# Run SVI, reporting elapsed wall-clock time about three times along the way.
num_steps = 30000
# Integer reporting interval: a float modulus such as `num_steps * .33` only
# hits exactly 0 when that product happens to be integral.
report_every = max(1, num_steps // 3)
losses = []
start = time.time()
for step in range(num_steps):
    loss = svi.step(data, n)
    losses.append(loss)

    if step % report_every == 0:
        print(step, f'({(time.time() - start)/60:.1f} min.)')
# num_steps (not `step + 1`) so this line also works when no step ran.
print(num_steps, f'({(time.time() - start)/60:.1f} min.)\n\n')

plt.figure()
plt.plot(losses)
plt.show()

# Snapshot the learned Beta pseudo-counts as NumPy arrays.  Every param is
# viewed as rows of (alpha, beta); np.atleast_2d turns the single pair 'a'
# into one row, while the 2-D 'b' and 'c' are left unchanged.
posterior_params = {
    name: np.atleast_2d(np.array(value.detach().cpu()))
    for name, value in pyro.get_param_store().items()
}
for key, val in posterior_params.items():
    true_p = p[f'p_{key.upper()}'].numpy()
    # The posterior mean of Beta(alpha, beta) is alpha / (alpha + beta).
    posterior_mean = val[:, 0] / val.sum(axis=1)
    print(f'p_{key.upper()}  (true/pred): ')
    print('\t', np.round(true_p, 2))
    print('\t', np.round(posterior_mean, 2))