Hi all, I'm very new to Pyro and variational inference. I tried creating a trivial example to get familiar with the APIs and to verify my understanding of what's going on. In the example, I generate data from a normal distribution and use VI to infer its mean and variance (I know there is a simple analytical solution here):
import torch
import torch.distributions.constraints as constraints
import pyro
import pyro.distributions as dist
import matplotlib.pyplot as plt
plt.style.use('fivethirtyeight')
def model(data):
    """Generative model: Normal likelihood with unknown location and scale.

    Samples ``latent_mu`` from a Normal(0, 1) prior and ``latent_sigma`` from
    a Gamma(1, 1) prior, then registers the ``obs`` site under
    Normal(latent_mu, latent_sigma) with ``obs=data``.

    Args:
        data: Values passed as ``obs`` to the ``obs`` sample site.

    Returns:
        Whatever ``pyro.sample`` returns for the ``obs`` site.
    """
    # Priors over the two latent quantities.
    loc = pyro.sample("latent_mu", dist.Normal(loc=0.0, scale=1.0))
    scale = pyro.sample(
        "latent_sigma", dist.Gamma(concentration=1.0, rate=1.0)
    )

    # Likelihood, declared inside the plate as in the original.
    with pyro.plate('observe_data'):
        observed = pyro.sample(
            "obs", dist.Normal(loc=loc, scale=scale), obs=data
        )
    return observed
def guide(data):
    """Variational distribution over ``latent_sigma`` and ``latent_mu``.

    ``latent_sigma`` is drawn from an InverseGamma whose concentration and
    rate are the positive-constrained parameters ``"a"`` and ``"b"`` (both
    initialised to 3.0). ``latent_mu`` is then drawn from a Normal centred on
    the parameter ``"mu"`` (initialised to 0.0) whose scale is the sampled
    ``latent_sigma``.

    Args:
        data: Not used in the body; present so the signature matches the
            model's.
    """
    # Trainable parameters of the variational distribution.
    concentration_q = pyro.param(
        "a", torch.tensor(3.0), constraint=constraints.positive
    )
    rate_q = pyro.param(
        "b", torch.tensor(3.0), constraint=constraints.positive
    )
    loc_q = pyro.param("mu", torch.tensor(0.0))

    # Variational distribution: sigma first, because mu's scale depends on it.
    sigma_draw = pyro.sample(
        "latent_sigma",
        dist.InverseGamma(concentration=concentration_q, rate=rate_q),
    )
    pyro.sample("latent_mu", dist.Normal(loc_q, sigma_draw))
# Generate some random data.
# Ground truth: draw a scale first, then a mean whose spread uses that scale.
_true_sigma_dist = dist.InverseGamma(concentration=10.0, rate=10.0)
global_sigma = pyro.sample("true_sigma", _true_sigma_dist)

_true_mu_dist = dist.Normal(loc=5.0, scale=global_sigma)
global_mu = pyro.sample("true_mu", _true_mu_dist)
print(global_mu, global_sigma)

# 1000 draws from the ground-truth Normal form the data set.
_data_dist = dist.Normal(loc=global_mu, scale=global_sigma)
normal_data = pyro.sample("data", _data_dist, sample_shape=[1000])

# Quick visual check of the generated data.
fig, ax = plt.subplots(nrows=1, ncols=1)
ax.hist(normal_data, bins=100)
plt.show()
# Run VI to fit the variational parameters.
# Start from an empty parameter store so "a", "b" and "mu" are re-initialised
# by the guide on the first step.
pyro.clear_param_store()
svi = pyro.infer.SVI(
    model=model,
    guide=guide,
    optim=pyro.optim.Adam({"lr": 0.003}),
    loss=pyro.infer.Trace_ELBO(),
)

# Per-step history of the loss and of each variational parameter.
losses, mus, a_qs, b_qs = [], [], [], []
num_steps = 20000
for t in range(num_steps):
    # svi.step takes one optimisation step and returns the loss estimate.
    losses.append(svi.step(normal_data))
    mus.append(pyro.param("mu").item())
    # Fix: record parameter "a" here. The original read "b" for both lists,
    # so the "a" trace was just a copy of the "b" trace.
    a_qs.append(pyro.param("a").item())
    b_qs.append(pyro.param("b").item())
# One stacked panel per tracked quantity: (series, title, y-axis label).
fig, ax = plt.subplots(4, 1)
_panels = (
    (losses, "ELBO", "loss"),
    (mus, "Mu", "Mu"),
    (a_qs, "a", "a"),
    (b_qs, "b", "b"),
)
for _axis, (_series, _title, _ylabel) in zip(ax, _panels):
    _axis.plot(_series)
    _axis.set_title(_title)
    _axis.set_xlabel("step")
    _axis.set_ylabel(_ylabel)

# Final fitted value of the variational mean parameter.
print('mu = ', pyro.param("mu").item())
When I go to get posterior predictions, I can get posterior samples of the latent parameters, but whatever I pass in to an instance of `Predictive` just gets echoed back in the `obs` site. There doesn't appear to be any sampling being done at all for the posterior predictive:
from pyro.infer import Predictive

num_samples = 100
predictive = Predictive(model=model, guide=guide, num_samples=num_samples)
posterior_samples = predictive(normal_data)

# Pull each returned site out as a NumPy array, in the same order as before.
mu, sigma, x = (
    posterior_samples[site].detach().cpu().numpy()
    for site in ("latent_mu", "latent_sigma", "obs")
)
# x is a tensor of shape [100,1,1000] with each batch entry an identical copy
# of `normal_data`: the model passes its argument as `obs=data`, so the "obs"
# site is conditioned on whatever is supplied here.
# If you pass in torch.tensor(1) to `predictive` you get a tensor of all ones,
# not a sample from the posterior predictive distribution.