Following the post here, I’m trying to marginalize out a continuous latent variable in my model so that my guide contains only the “global” parameters.
I generate synthetic data as follows: choose cluster 1 with probability 0.7, otherwise choose cluster 2. Each cluster is defined by the two parameters of a Beta distribution; once a cluster is chosen, draw a sample p from its Beta distribution, use it as the parameter of a Bernoulli distribution, and generate a sequence of 50 Bernoulli observations. Repeat this whole process N = 4000 times to build the dataset.
My goals are to (i) infer the parameters of each Beta distribution, as well as their relative frequencies, and (ii) make inferences on new data. To do this, I understand I need to marginalize out the latent variables p so that they don’t appear in the guide; otherwise the guide is tied to the training dataset and can’t be used for inference on new data. Following the link above, I annotate the site ‘p’ with instructions for approximate marginalization via Monte Carlo sampling, i.e.
p = pyro.sample('p', dist.Beta(a[assignment], b[assignment]),
                infer={"enumerate": "parallel", "expand": True, "num_samples": 100})
Here is the code, both for generating the synthetic data and for running the model.
Synthetic data:
# imports and generate synthetic data
import torch
import numpy as np
from tqdm import tqdm
import pandas as pd
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete, Trace_ELBO
from pyro.ops.indexing import Vindex
a1, b1 = 5, 1
a2, b2 = 1, 5
pi = .7
data = []
ids = []
for _ in range(4000):
    choice = np.random.uniform()
    if choice < pi:       # cluster 1 with probability pi = 0.7
        choice = 1
        a, b = a1, b1
    else:                 # cluster 2 otherwise
        choice = 0
        a, b = a2, b2
    ids.append(choice)
    p = np.random.beta(a, b)
    data.append(np.random.binomial(1, p, size=50))
data = torch.tensor(data).float()
ids = torch.tensor(ids).unsqueeze(-1)
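As a quick sanity check on the generator (a standalone NumPy sketch, not part of the original setup), the per-cluster empirical success rates should land near the Beta means a/(a+b): roughly 5/6 for the (5, 1) cluster and 1/6 for the (1, 5) cluster:

```python
import numpy as np

rng = np.random.default_rng(0)
a1, b1 = 5, 1
a2, b2 = 1, 5
pi = .7

rows, labels = [], []
for _ in range(4000):
    if rng.uniform() < pi:           # cluster 1 with probability pi
        labels.append(1)
        p = rng.beta(a1, b1)
    else:                            # cluster 2 otherwise
        labels.append(0)
        p = rng.beta(a2, b2)
    rows.append(rng.binomial(1, p, size=50))

rows, labels = np.array(rows), np.array(labels)
# Empirical success rate per cluster vs. the Beta means a/(a+b).
rate1 = rows[labels == 1].mean()     # expect ~ 5/6
rate0 = rows[labels == 0].mean()     # expect ~ 1/6
print(f"cluster 1 rate: {rate1:.3f}, cluster 2 rate: {rate0:.3f}")
```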
Running the model:
K = 2
pyro.clear_param_store()

def model(data, ids=None):
    weights = pyro.sample('weights', dist.Dirichlet(1. / K * torch.ones(K)))
    with pyro.plate('components', K):
        a = pyro.sample('a', dist.InverseGamma(0.5, 1))
        b = pyro.sample('b', dist.InverseGamma(0.5, 1))
    with pyro.plate('data', len(data), dim=-2):
        assignment = pyro.sample('assignment', dist.Categorical(weights), obs=ids)
        p = pyro.sample('p', dist.Beta(a[assignment], b[assignment]),
                        infer={"enumerate": "parallel", "expand": True, "num_samples": 100})
        with pyro.plate('trials', 50, dim=-1):
            samples = pyro.sample('obs', dist.Bernoulli(p), obs=data)

def initialize(seed):
    global global_guide, svi
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()
    # Hide the local sites so the guide only covers the global parameters.
    global_guide = AutoDiagonalNormal(poutine.block(model, hide=['assignment', 'p']))
    svi = SVI(model, global_guide, optim, loss=elbo)
    return svi.loss(model, global_guide, data, ids)

optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO()
# Choose the best among 10 random initializations.
loss, seed = min((initialize(seed), seed) for seed in tqdm(range(10)))
initialize(seed)
print('seed = {}, initial_loss = {}'.format(seed, loss))
The issue is that I’m getting a warning whose source I can’t identify:
RuntimeWarning: Site p is multiply sampled in model,
expect incorrect gradient estimates from TraceEnum_ELBO.
Consider using exact enumeration or guide sampling if possible.
Indeed, the results are pretty mediocre. Note that assignment is not enumerated, because it is observed, so ‘p’ is the only site where anything interesting is going on.