Marginalization over local sites during MAP

I’m trying to implement a model that marginalizes over a local variable when making predictions of global variables. I’ve tried to implement the suggestion given in Getting a warning when trying to marginalize a continuous latent variable - #4 by eb8680_2, where the local variable is drawn from a discrete distribution that can be enumerated over using TraceEnum_ELBO. I then combined this idea with the approach shown in the Gaussian mixture model example, giving the code below:

import pyro
import pyro.distributions as dist
import torch
from matplotlib import pyplot as plt
from pyro import poutine
from pyro.infer import SVI, Predictive, config_enumerate, \
    TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta


@config_enumerate
def simple_model(k, y=None, n_components=100, ground_truth=False):
    if ground_truth:
        loc = pyro.sample(
            "loc", dist.Delta(torch.ones(2)).expand([2]).to_event(1)
        )
        # scale parameter
        scale = pyro.sample(
            "scale", dist.Delta(torch.ones(2) * 0.3).expand([2]).to_event(1)
        )
        # declare the 2x2 matrix as a single event so its shape matches LKJCholesky below
        L_omega = pyro.sample("L_omega", dist.Delta(torch.eye(2)).to_event(2))

    else:
        loc = pyro.sample(
            "loc", dist.Normal(1.0, 0.5).expand([2]).to_event(1)
        )
        # scale parameter
        scale = pyro.sample(
            "scale", dist.LogNormal(-2.0, 1.0).expand([2]).to_event(1)
        )
        L_omega = pyro.sample("L_omega", dist.LKJCholesky(2, 1.0))
        # sigma = pyro.param("sigma", torch.ones(1), constraint=constraints.positive)
    sigma = 0.01

    with pyro.plate("components", n_components):
        theta = pyro.sample(
            "theta", dist.LogNormal(0.0, 0.25).expand([2]).to_event(1)
        )

    with pyro.plate("data", len(k)):
        z = pyro.sample(
            "z", dist.Categorical(torch.ones(n_components))
        )
        lambda_ = loc + theta[z] * scale

        # construct beam matrix
        L = torch.diag_embed(lambda_) @ L_omega
        beam_matrix = L.transpose(-2, -1) @ L

        f = (1.0 + k) ** 2 * beam_matrix[..., 0, 0] + \
            2.0 * (1.0 + k) * beam_matrix[..., 0, 1] + \
            beam_matrix[..., 1, 1]

        assert torch.all(f > 0.0)
        f = torch.log(f)

        if y is not None:
            assert torch.all(y > 0.0), y
            y = torch.log(y)

        return pyro.sample(
            "obs", dist.Normal(f, sigma),
            obs=y)

# set number of components for marginalization
n_components = 100

# generate training data from ground truth model - note: outputs are in log space
test_k = torch.linspace(-2, 2, 100)
predictive = Predictive(simple_model, guide=None, num_samples=100)
gt_samples = predictive(test_k, ground_truth=True, n_components=n_components)
gt_samples["obs"] = torch.exp(gt_samples["obs"])

# get training data from ground truth samples
sample_indices = torch.arange(-2, 100, 10)
train_K = test_k[sample_indices]
train_Y = gt_samples["obs"][0, sample_indices]

# visualize training data
fig, ax = plt.subplots()
ax.plot(train_K, train_Y, 'o')

# create guide function - marginalize over theta/z
guide = AutoDelta(poutine.block(simple_model, expose=["loc", "scale", "L_omega"]))

# train with TraceEnum_ELBO
num_steps = 1500
initial_lr = 0.01
gamma = 0.1  # final learning rate will be gamma * initial_lr
lrd = gamma ** (1 / num_steps)
optim = pyro.optim.ClippedAdam({'lr': initial_lr, 'lrd': lrd})
elbo = TraceEnum_ELBO(max_plate_nesting=1, num_particles=1)

svi = SVI(simple_model, guide, optim, loss=elbo)
losses = []
for i in range(num_steps):
    loss = svi.step(train_K, train_Y)
    losses.append(loss)
fig2, ax2 = plt.subplots()
ax2.plot(losses)

# examine guide
for name in pyro.get_param_store():
    print(f"{name}:{pyro.param(name)}")

# visualize results
predictive = Predictive(simple_model, guide=guide, num_samples=800)
posterior_samples = predictive(test_k)
posterior_samples["obs"] = torch.exp(posterior_samples["obs"])

c = ["C0", "C1", "C2"]
labels = ['posterior', 'ground_truth']
for idx, samples in enumerate([posterior_samples, gt_samples]):
    mean = torch.mean(samples["obs"], dim=0)
    l = torch.quantile(samples["obs"], 0.05, dim=0)
    u = torch.quantile(samples["obs"], 0.95, dim=0)
    w = u - l

    ax.plot(test_k, mean, c=c[idx])
    ax.plot(test_k, w, c=c[idx], ls='--')

    ax.fill_between(test_k, l, u, alpha=0.25, fc=c[idx], label=labels[idx])

ax.legend()
plt.show()

In this case, since I’m marginalizing out the local parameter theta, this variable should not show up in the guide used for SVI. I made sure that this site is not exposed in the AutoDelta guide, but as expected this produces a warning, since the guide and model functions do not match. It runs and gives imperfect results, but I’m not sure this is the right way to implement it. How do I specify this correctly?
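For reference, here is a rough diagnostic sketch (not part of the original post, and assuming the trained guide and training data defined in the code above) that traces the guide and the model and compares their unobserved sample sites:

from pyro import poutine
from pyro.poutine.util import prune_subsample_sites

# trace the trained guide, then replay the model against it
guide_trace = prune_subsample_sites(
    poutine.trace(guide).get_trace(train_K, train_Y)
)
model_trace = prune_subsample_sites(
    poutine.trace(
        poutine.replay(simple_model, trace=guide_trace)
    ).get_trace(train_K, train_Y)
)

model_latents = {
    name for name, site in model_trace.nodes.items()
    if site["type"] == "sample" and not site["is_observed"]
}
guide_latents = {
    name for name, site in guide_trace.nodes.items()
    if site["type"] == "sample"
}

# expect {"theta", "z"} here; TraceEnum_ELBO enumerates "z" in the model itself,
# so the guide/model mismatch warning only mentions the continuous site "theta"
print("in model but not in guide:", model_latents - guide_latents)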

Can you clarify the behavior you are expecting and the exact error or warning you are actually seeing?

If I understand correctly, a simpler and likely more effective alternative to your current approach would be just using a non-Delta autoguide for theta (via e.g. AutoGuideList), then drawing multiple posterior predictive samples and discarding the values of theta in those samples.
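For what it’s worth, a minimal sketch of that suggestion might look like the following, assuming the simple_model and optim defined in the original post, and picking AutoNormal as one possible non-Delta guide for theta:

from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta, AutoGuideList, AutoNormal

# MAP (Delta) guide for the global sites, non-Delta guide for the local site "theta"
guide = AutoGuideList(simple_model)
guide.append(
    AutoDelta(poutine.block(simple_model, expose=["loc", "scale", "L_omega"]))
)
guide.append(
    AutoNormal(poutine.block(simple_model, expose=["theta"]))
)

# the discrete site "z" stays out of the guide; TraceEnum_ELBO enumerates it in the model
svi = SVI(simple_model, guide, optim, loss=TraceEnum_ELBO(max_plate_nesting=1))

# ...train as before, then draw posterior predictive samples with Predictive and
# simply ignore the sampled "theta" values when summarizing the global predictions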

The warning was: UserWarning: Found vars in model but not guide: {'theta'} (raised by warnings.warn(f"Found vars in model but not guide: {bad_sites}")). I think you understand the problem correctly. I tried your solution and it seemed to produce similar results, but to be safe I’ll go with what you suggested. What happens if there are sites in the model that are not in the guide?