# Posterior Predictive Class API Question

Hi all, I'm very new to Pyro and variational inference. I tried creating a trivial example to get familiar with the APIs and verify my understanding of what’s going on. In the example, I generate data from a normal distribution and use VI to infer its mean and scale (standard deviation). I know there is a simple analytical solution here:

``````python
import torch
import torch.distributions.constraints as constraints
import pyro
import pyro.distributions as dist
import matplotlib.pyplot as plt
plt.style.use('fivethirtyeight')

def model(data):
    """Normal likelihood with a Normal prior on the mean and a Gamma prior on the scale.

    The ``"obs"`` site is always observed with *data*, so the value returned
    here is *data* itself.
    """
    # Prior over the location of the data-generating Normal.
    loc = pyro.sample("latent_mu", dist.Normal(loc=0.0, scale=1.0))
    # Prior over its scale (standard deviation); a Gamma keeps it positive.
    scale = pyro.sample("latent_sigma", dist.Gamma(concentration=1.0, rate=1.0))

    # Observations are conditionally independent given (loc, scale).
    likelihood = dist.Normal(loc=loc, scale=scale)
    with pyro.plate("observe_data"):
        return pyro.sample("obs", likelihood, obs=data)

def guide(data):
    """Variational family matching the latent sites of ``model``.

    ``latent_sigma`` ~ InverseGamma(a, b) and, conditioned on that draw,
    ``latent_mu`` ~ Normal(mu, latent_sigma). *data* is not used. It is
    accepted because SVI calls the guide with the same arguments as the model.
    """
    # Trainable variational parameters; "a" and "b" are constrained positive.
    positive = constraints.positive
    a_q = pyro.param("a", torch.tensor(3.0), constraint=positive)
    b_q = pyro.param("b", torch.tensor(3.0), constraint=positive)
    mu_q = pyro.param("mu", torch.tensor(0.0))

    # Sigma is drawn first because the mean's variational scale is tied to it.
    sigma_q = pyro.sample(
        "latent_sigma", dist.InverseGamma(concentration=a_q, rate=b_q)
    )
    pyro.sample("latent_mu", dist.Normal(mu_q, sigma_q))

# Generate synthetic data: draw a "true" scale and mean, then 1000 observations.
global_sigma = pyro.sample("true_sigma", dist.InverseGamma(concentration=10.0, rate=10.0))
global_mu = pyro.sample("true_mu", dist.Normal(loc=5.0, scale=global_sigma))
print(global_mu, global_sigma)
normal_data = pyro.sample("data", dist.Normal(loc=global_mu, scale=global_sigma), sample_shape=[1000])

fig, ax = plt.subplots(1, 1)
ax.hist(normal_data, bins=100)
plt.show()

# Run VI to fit the variational parameters.
pyro.clear_param_store()
svi = pyro.infer.SVI(model=model, guide=guide, loss=pyro.infer.Trace_ELBO())

losses, mus, a_qs, b_qs = [], [], [], []
num_steps = 20000
for _ in range(num_steps):
    # svi.step returns the loss, i.e. an estimate of the *negative* ELBO.
    losses.append(svi.step(normal_data))
    mus.append(pyro.param("mu").item())
    a_qs.append(pyro.param("a").item())  # trace of "a" (must not read "b" here)
    b_qs.append(pyro.param("b").item())

# One panel per traced quantity: (series, title, y-axis label).
traces = [
    (losses, "Negative ELBO", "loss"),
    (mus, "Mu", "Mu"),
    (a_qs, "a", "a"),
    (b_qs, "b", "b"),
]
fig, ax = plt.subplots(len(traces), 1)
for panel, (series, title, ylabel) in zip(ax, traces):
    panel.plot(series)
    panel.set_title(title)
    panel.set_xlabel("step")
    panel.set_ylabel(ylabel)
# Stacked panels otherwise overlap their titles/labels.
fig.tight_layout()

print('mu = ', pyro.param("mu").item())
plt.show()
``````

When I try to get posterior predictions, I do get posterior samples of the latent parameters. However, whatever I pass to an instance of `Predictive` is simply echoed back at the `obs` site. No sampling appears to be done at all for the posterior predictive:

``````python
from pyro.infer import Predictive

num_samples = 100
predictive = Predictive(model=model, guide=guide, num_samples=num_samples)
posterior_samples = predictive(normal_data)
# Pull each site's draws off the autograd graph and into host NumPy arrays.
mu, sigma, x = (
    posterior_samples[site].detach().cpu().numpy()
    for site in ("latent_mu", "latent_sigma", "obs")
)

# x has shape [100, 1, 1000], and every batch entry is an identical copy of `normal_data`.
# If you pass torch.tensor(1) to `predictive`, you get a tensor of all ones, not a
# sample from the posterior predictive distribution.
``````

Hi @kenleejr92, the way you’ve written your model, `obs` will always be observed. `Predictive` does not remove observations from a model; you must do this yourself.

One way to accomplish this is to remove the `data` argument from `model`. Then use `pyro.condition` at training time to pass the observations to your model:

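``````python
def model():  # no data argument
    ...  # priors as before
    # With no observed data to infer the number of observations from,
    # give the plate an explicit size.
    with pyro.plate("observe_data", size=1000):
        return pyro.sample("obs", dist.Normal(loc, scale))  # no obs= argument

def guide():  # no data argument
    ...  # same as before
``````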

At training time:

``````python
conditioned_model = pyro.condition(model, data={"obs": data})
svi = pyro.infer.SVI(conditioned_model, guide, ...)
...
svi.step()  # model takes no arguments
``````

Then when testing, just use the unconditioned model:

``````python
predictive = Predictive(model=model, guide=guide, num_samples=num_samples)
posterior_samples = predictive()  # model takes no arguments
``````
1 Like