# Getting a warning when trying to marginalize a continuous latent variable

Following the post here, I’m trying to marginalize out a continuous latent variable in my model so my guide is only for “global” parameters.

I’m generating synthetic data in the following way: choose cluster 1 w/ probability .7, otherwise choose cluster 2. Each cluster is defined by a pair of parameters of a beta distribution; once you have chosen a cluster, then draw a sample p from its beta distribution and use it as the parameter of a Bernoulli distribution. Then generate a sequence of 50 Bernoulli observations. Repeat this whole process N=4000 times to get the dataset.

My goal is to (i) infer the parameters of each beta distribution, as well as their relative frequencies, and (ii) make inferences on new data. To do this, I understand I need to marginalize out the latent variables ‘p’ so that they don’t appear in the guide; otherwise the guide is tied to the training dataset and I can’t use it for inference on new data. Following the link above, I annotate the site ‘p’ with instructions to approximate the marginalization by sampling:

```python
p = pyro.sample('p', dist.Beta(a[assignment], b[assignment]), infer={
    "enumerate": "parallel", "expand": True, "num_samples": 100})
```


Here is the code, both for generating the synthetic data and for running the model.

Synthetic data:

```python
# imports and generate synthetic data
import torch
import numpy as np
from tqdm import tqdm

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete, Trace_ELBO
from pyro.ops.indexing import Vindex

a1, b1 = 5, 1
a2, b2 = 1, 5

pi = .7
data = []
ids = []
for _ in range(4000):
    # with probability 1 - pi pick cluster 1 (Beta(a1, b1)), otherwise cluster 0 (Beta(a2, b2))
    choice = np.random.uniform()
    if choice > pi:
        choice = 1
        a, b = a1, b1
    else:
        choice = 0
        a, b = a2, b2
    ids.append(choice)

    # one success probability per sequence, shared by its 50 Bernoulli trials
    p = np.random.beta(a, b)
    data.append(np.random.binomial(1, p, size=50))

data = torch.tensor(np.array(data)).float()  # shape (4000, 50)
ids = torch.tensor(ids).unsqueeze(-1)        # shape (4000, 1), matches the dim=-2 data plate
```


Running the model:

```python
K = 2
pyro.clear_param_store()

def model(data, ids=None):
    weights = pyro.sample(
        'weights', dist.Dirichlet(1./K * torch.ones(K)))

    with pyro.plate('components', K):
        a = pyro.sample('a', dist.InverseGamma(0.5, 1))
        b = pyro.sample('b', dist.InverseGamma(0.5, 1))

    with pyro.plate('data', len(data), dim=-2):
        assignment = pyro.sample('assignment', dist.Categorical(weights), obs=ids)

        p = pyro.sample('p', dist.Beta(a[assignment], b[assignment]), infer={
            "enumerate": "parallel", "expand": True, "num_samples": 100})

        with pyro.plate('trials', 50, dim=-1):
            samples = pyro.sample('obs', dist.Bernoulli(p), obs=data)

def initialize(seed):
    global global_guide, svi
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()
    global_guide = AutoDiagonalNormal(poutine.block(model, hide=['assignment', 'p']))  #, init_loc_fn=init_loc_fn)
    svi = SVI(model, global_guide, optim, loss=elbo)
    return svi.loss(model, global_guide, data, ids)

optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO()

# Choose the best among 10 random initializations.
loss, seed = min((initialize(seed), seed) for seed in tqdm(list(range(0, 10))))
initialize(seed)
print('seed = {}, initial_loss = {}'.format(seed, loss))
```


The issue is that I’m getting a warning whose source I don’t understand:

```
RuntimeWarning: Site p is multiply sampled in model,
expect incorrect gradient estimates from TraceEnum_ELBO.
Consider using exact enumeration or guide sampling if possible.
```


Indeed, the results are pretty mediocre. Note that assignment is not enumerated because it is observed, so ‘p’ is the only site where anything interesting is going on.

Hi @ethansargent, you might consider replacing sites p and obs with a single pyro.distributions.BetaBinomial distribution for your observations, integrating out p analytically. You could also discretize the distribution you’re using for p via quadrature (i.e. rewrite p as p = ps[i_p] where ps are the quadrature points and i_p is drawn from a Categorical distribution whose probabilities are the quadrature weights) and use exact enumeration to marginalize i_p, which should work well since each p is one-dimensional.

The warning seems pretty clear, so I’m not sure what else you’d like to know. If you’re asking about the underlying cause, it’s that TraceEnum_ELBO is missing full support for marginalization via importance sampling from the prior, and we have not felt the need to add it so far, since it’s generally not very effective in high dimensions relative to variational inference. The num_samples enumeration argument is primarily useful for extra ELBO variance reduction when sampling from a guide.

Thank you again for a detailed reply. I have one follow-up question.

It seems like the utility of marginalizing latent variables in the model is that they then do not appear in the guide. In the GMM tutorial, for example, the local variable assignment is marginalized out via enumeration, and the guide is a distribution over the global cluster parameters only. So, later, when we plug the guide into infer_discrete, we’re not passing it a guide that has some arbitrary collection of training data latents memorized.

On the other hand, if I don’t marginalize out p above and pass the guide to infer_discrete, it presumably pairs the guide (which has an arbitrary collection of $n_{\text{train}}$ latents memorized) with the joint on the new data ($n_{\text{predict}} \neq n_{\text{train}}$ data points), and a shape error bubbles up at some point; I’m not sure how infer_discrete works under the hood.

Where I am going with this is, if marginalizing out a local latent variable is inadvisable, like above, how in general do you do inference on new data when your guide is “specific to” your training dataset?

Conceivably you could manually take the MAP estimates for your global parameters and use them in a separate predictive model, but I’m wondering if there’s a best practice I’m missing.

My point isn’t that marginalizing local variables is bad, it’s that attempting to do so by importance sampling from the prior will not be very effective except in particularly simple cases. If you can do the integrals exactly using e.g. conjugacy (as in my BetaBinomial suggestion) or enumeration (as in the case of infer_discrete or my quadrature suggestion) then you should. Otherwise, the best approach depends on your problem; one strategy is amortized variational inference, in which the parameters of the approximate posterior distributions for local latent variables are functions of the data.
