Posterior Predictive Class API Question

Hi all, I’m very new to Pyro and variational inference. I tried creating a trivial example to get familiar with the APIs and verify my understanding of what’s going on. In the example, I generate data from a normal distribution and use VI to infer the mean and standard deviation (I know there is a simple analytical solution here):

import torch
import torch.distributions.constraints as constraints
import pyro
import pyro.distributions as dist
import matplotlib.pyplot as plt
plt.style.use('fivethirtyeight')

def model(data):
    """Normal likelihood with a Normal prior on its mean and a Gamma prior on its scale.

    Samples the sites "latent_mu" and "latent_sigma", then passes `data`
    as `obs=` to the "obs" site inside the 'observe_data' plate.

    Returns whatever `pyro.sample` returns for the "obs" site.
    """
    # Priors
    mu_prior = dist.Normal(loc=0.0, scale=1.0)
    sigma_prior = dist.Gamma(concentration=1.0, rate=1.0)
    loc = pyro.sample("latent_mu", mu_prior)
    scale = pyro.sample("latent_sigma", sigma_prior)

    # Likelihood
    likelihood = dist.Normal(loc=loc, scale=scale)
    with pyro.plate('observe_data'):
        observed = pyro.sample("obs", likelihood, obs=data)
    return observed

def guide(data):
    """Variational distribution over "latent_sigma" and "latent_mu".

    `data` is accepted so the signature matches `model`; it is not read here.
    Registers three trainable parameters in the param store: "a" and "b"
    (both constrained positive, initialised to 3.0) and "mu" (initialised
    to 0.0).
    """
    # Trainable parameters of the variational distribution
    concentration = pyro.param(
        "a", torch.tensor(3.0), constraint=constraints.positive
    )
    rate = pyro.param(
        "b", torch.tensor(3.0), constraint=constraints.positive
    )
    loc = pyro.param("mu", torch.tensor(0.0))

    # Sigma is drawn first because the distribution over mu uses it as its scale
    sigma_posterior = dist.InverseGamma(concentration=concentration, rate=rate)
    sampled_sigma = pyro.sample("latent_sigma", sigma_posterior)

    mu_posterior = dist.Normal(loc, sampled_sigma)
    pyro.sample("latent_mu", mu_posterior)

# Draw "true" parameters, then simulate 1000 observations from them
true_sigma_dist = dist.InverseGamma(concentration=10.0, rate=10.0)
global_sigma = pyro.sample("true_sigma", true_sigma_dist)

# The true mean is centred at 5.0 and uses the sampled sigma as its scale
true_mu_dist = dist.Normal(loc=5.0, scale=global_sigma)
global_mu = pyro.sample("true_mu", true_mu_dist)
print(global_mu, global_sigma)

data_dist = dist.Normal(loc=global_mu, scale=global_sigma)
normal_data = pyro.sample("data", data_dist, sample_shape=[1000])

# Histogram of the simulated data
fig, ax = plt.subplots(nrows=1, ncols=1)
ax.hist(normal_data, bins=100)
plt.show()

# Run VI to fit the variational parameters.
# Clear the param store first so "a", "b" and "mu" start from the initial
# values given in the guide.
pyro.clear_param_store()

adam = pyro.optim.Adam({"lr": 0.003})
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(model=model, guide=guide, optim=adam, loss=elbo)


# Per-step history of the loss and of each variational parameter, for plotting
losses, mus, a_qs, b_qs = [], [], [], []
num_steps = 20000
for t in range(num_steps):
    # One optimisation step on the full data set; svi.step returns the loss
    losses.append(svi.step(normal_data))
    mus.append(pyro.param("mu").item())
    # Each history list must read its own param name: a_qs tracks "a",
    # b_qs tracks "b". Reading "b" for both makes the two traces identical.
    a_qs.append(pyro.param("a").item())
    b_qs.append(pyro.param("b").item())

# One stacked panel per tracked series: (values, title, y-axis label)
fig, ax = plt.subplots(4, 1)
panels = (
    (losses, "ELBO", "loss"),
    (mus, "Mu", "Mu"),
    (a_qs, "a", "a"),
    (b_qs, "b", "b"),
)
for axis, (series, title, ylabel) in zip(ax, panels):
    axis.plot(series)
    axis.set_title(title)
    axis.set_xlabel("step")
    axis.set_ylabel(ylabel)

# Final fitted value of the "mu" parameter
print('mu = ', pyro.param("mu").item())

When I go to get posterior predictions, I can get posterior samples of the latent parameters, but whatever I pass in to an instance of Predictive just gets echoed back at the obs site. There doesn’t appear to be any sampling being done at all for the posterior predictive:

from pyro.infer import Predictive

# Build a Predictive from the model and guide, then call it on the training data
num_samples = 100
predictive = Predictive(model=model, guide=guide, num_samples=num_samples)
posterior_samples = predictive(normal_data)

# Pull each returned site out as a NumPy array
mu, sigma, x = (
    posterior_samples[site].detach().cpu().numpy()
    for site in ("latent_mu", "latent_sigma", "obs")
)

# x has shape [100, 1, 1000], and every batch entry is an identical copy of `normal_data`.
# If you pass torch.tensor(1) to `predictive`, you get a tensor of all ones,
# not a sample from the posterior predictive distribution.

Hi @kenleejr92, the way you’ve written your model, obs will always be observed. Predictive does not remove observations from a model; you must do this yourself.

One way you could accomplish this is to remove the data argument to model and instead use pyro.condition at training time to pass observations to your model:

def model():  # no data argument
    ...
        return pyro.sample("obs", Normal(loc, scale)) # no obs= argument

def guide():  # no data argument
    ...  # same as before

At training time:

conditioned_model = pyro.condition(model, data={"obs": data})
svi = pyro.infer.SVI(conditioned_model, guide, ...)
...
svi.step()  # model takes no arguments 

Then when testing, just use the unconditioned model:

predictive = Predictive(model=model, guide=guide, num_samples=num_samples)
posterior_samples = predictive()  # model takes no arguments
1 Like