Hi everyone I’ve followed the tutorial “SVI Part I” with no trouble.
I tried to take a step further and add a bit of complexity.
That’s the set up of the problem:
I have one Bernoulli distribution with parameter “phi” from which I sample a latent variable “z”.
Then I have 2 Gaussians parameterized respectively with (mu-0, std-0) and (mu-1, std-1).
The generative process sample z from the Bernoulli and than if (z == 0) it samples x from the Gaussian-0 else it samples x from the Gaussian-1.
I want to infer phi, (mu-0, std-0) and (mu-1, std-1) from the observed x.
This is just a simple GMM that I’ve easily solved with the EM algorithm but now I’m a bit confused on how to implement this in Pyro and solve it with Variational Inference.
My current implementation doesn’t seem to work.
My code:
import pyro
import pyro.optim as optim
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
import torch
import torch.distributions.constraints as constraints
def model(data):
# priors
phi = pyro.sample('phi', dist.Beta(torch.tensor(10.0), torch.tensor(10.0)))
mu_0 = pyro.sample('mu-0', dist.Normal(torch.tensor(2.0), torch.tensor(1.0)))
mu_1 = pyro.sample('mu-1', dist.Normal(torch.tensor(5.0), torch.tensor(1.0)))
std_0 = pyro.sample('std-0', dist.HalfCauchy(scale=torch.tensor(1.)))
std_1 = pyro.sample('std-1', dist.HalfCauchy(scale=torch.tensor(1.)))
for i in range(len(data)):
z = pyro.sample(f'z-{i}', dist.Bernoulli(phi))
if z.long().item() == 0:
pyro.sample(f'obs-{i}', dist.Normal(mu_0, std_0), obs=data[i])
else:
pyro.sample(f'obs-{i}', dist.Normal(mu_1, std_1), obs=data[i])
def guide(data):
phi_alpha = pyro.param('phi-alpha', torch.tensor(10.0))
phi_beta = pyro.param('phi-beta', torch.tensor(10.0))
mu_0_mu = pyro.param('mu-0-mu', torch.tensor(2.0))
mu_0_std = pyro.param('mu-0-std', torch.tensor(1.0), constraint=constraints.positive)
mu_1_mu = pyro.param('mu-1-mu', torch.tensor(5.0))
mu_1_std = pyro.param('mu-1-std', torch.tensor(1.0), constraint=constraints.positive)
std_0_std = pyro.param('std-0-std', torch.tensor(1.0), constraint=constraints.positive)
std_1_std = pyro.param('std-1-std', torch.tensor(1.0), constraint=constraints.positive)
pyro.sample('phi', dist.Beta(phi_alpha, phi_beta))
pyro.sample('mu-0', dist.Normal(mu_0_mu, mu_0_std))
pyro.sample('mu-1', dist.Normal(mu_1_mu, mu_1_std))
pyro.sample('std-0', dist.HalfCauchy(scale=std_0_std))
pyro.sample('std-1', dist.HalfCauchy(scale=std_1_std))
# Data Generating Process #
z_dist = torch.distributions.Bernoulli(torch.tensor(0.75))
x_dists = [torch.distributions.Normal(torch.tensor(2.), torch.tensor(1.)),
torch.distributions.Normal(torch.tensor(5.), torch.tensor(1.8)),]
z_sample = z_dist.sample((100,))
x_sample = [x_dists[int(z)].sample() for z in z_sample]
pyro.clear_param_store()
pyro.enable_validation(True)
svi = SVI(model, guide,
optim=optim.ClippedAdam({'lr': 0.01}),
loss=Trace_ELBO())
c = 0
for step in range(1000):
c += 1
loss = svi.step(x_sample)
if step % 100 == 0:
phi_alpha = pyro.param('phi-alpha').item()
phi_beta = pyro.param('phi-beta').item()
mu_0_mu = pyro.param('mu-0-mu').item()
mu_1_mu = pyro.param('mu-1-mu').item()
phi = phi_alpha / (phi_alpha + phi_beta)
print("[iteration {:>4}] loss: {:.4f} | phi: {:.2f}, mu-0: {:.2f}, mu-1: {:.2f}".format(c, loss, phi, mu_0_mu, mu_1_mu))
Also I get this warning:
/home/fabio/miniconda3/envs/Variational-Inference_pytorch/lib/python3.9/site-packages/pyro/util.py:244: UserWarning: Found vars in model but not guide: {'z-96', 'z-43', 'z-33', 'z-60', 'z-42', 'z-7', 'z-91', 'z-51', 'z-47', 'z-49', 'z-24', 'z-88', 'z-10', 'z-2', 'z-67', 'z-77', 'z-83', 'z-23', 'z-62', 'z-97', 'z-35', 'z-20', 'z-4', 'z-8', 'z-27', 'z-30', 'z-34', 'z-85', 'z-82', 'z-39', 'z-45', 'z-32', 'z-66', 'z-74', 'z-78', 'z-71', 'z-89', 'z-29', 'z-11', 'z-76', 'z-18', 'z-81', 'z-5', 'z-99', 'z-61', 'z-57', 'z-95', 'z-68', 'z-50', 'z-25', 'z-37', 'z-93', 'z-16', 'z-48', 'z-9', 'z-53', 'z-65', 'z-40', 'z-44', 'z-13', 'z-14', 'z-12', 'z-72', 'z-75', 'z-64', 'z-31', 'z-87', 'z-19', 'z-98', 'z-94', 'z-46', 'z-63', 'z-90', 'z-6', 'z-52', 'z-3', 'z-22', 'z-17', 'z-59', 'z-21', 'z-84', 'z-92', 'z-28', 'z-79', 'z-26', 'z-36', 'z-70', 'z-38', 'z-55', 'z-69', 'z-1', 'z-80', 'z-58', 'z-73', 'z-0', 'z-56', 'z-86', 'z-54', 'z-41', 'z-15'}
warnings.warn("Found vars in model but not guide: {}".format(model_vars - guide_vars - enum_vars))