Trouble fitting simple normal model

Hello,

I’m having trouble getting a simple Normal model to work. I simulated data from a normal distribution, and put a Normal(0, 10) prior on the mean and a Gamma(1, 1) prior on the standard deviation. However, after fitting, the variational parameters seem stuck close to their initial values.

from torch.distributions.normal import Normal
from torch import tensor
from torch.distributions.constraints import positive

from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro import distributions as dist
from pyro import plate

import pyro


# Global Pyro state: validate distribution arguments, start from a clean
# parameter store, and fix the RNG seed for reproducibility.
pyro.enable_validation(True)
pyro.clear_param_store()
pyro.set_rng_seed(1)

# Ground-truth parameters and 1000 simulated observations to fit against.
mu = 3.4
sigma = 1.8
data = Normal(mu, sigma).sample((1000,))

def model(data):
    """Generative model: Normal likelihood with Normal/Gamma priors.

    Latent sites: "mu" (mean) and "sigma" (standard deviation);
    observed site "obs" conditions on `data`.
    """
    # prior hyperparameters
    prior_loc, prior_scale = tensor(0.0), tensor(10.0)
    prior_shape, prior_rate = tensor(1.0), tensor(1.0)

    # draw the latent mean and standard deviation from their priors
    loc = pyro.sample("mu", dist.Normal(prior_loc, prior_scale))
    scale = pyro.sample("sigma", dist.Gamma(prior_shape, prior_rate))

    # observations are conditionally independent given (mu, sigma)
    with plate('observe_data'):
        pyro.sample("obs", dist.Normal(loc, scale), obs=data)

def guide(data):
    """Variational family: Normal over "mu", Gamma over "sigma".

    Parameter and sample-site names must match the model's latent sites.
    """
    # variational parameters, registered in Pyro's global param store
    loc_q = pyro.param("mu_mu_q", tensor(0.0))
    scale_q = pyro.param("mu_sigma_q", tensor(10.0), constraint=positive)
    shape_q = pyro.param("alpha_q", tensor(1.0), constraint=positive)
    rate_q = pyro.param("beta_q", tensor(1.0), constraint=positive)

    # approximate posteriors for the two latent sites
    pyro.sample("mu", dist.Normal(loc_q, scale_q))
    pyro.sample("sigma", dist.Gamma(shape_q, rate_q))

# Optimizer and SVI driver (single-sample ELBO estimate per step).
adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

# 5000 gradient steps on the full data set.
for _ in range(5000):
    svi.step(data)

# Report the fitted variational parameters, in registration order.
for name in ("mu_mu_q", "mu_sigma_q", "alpha_q", "beta_q"):
    print(pyro.param(name).item())

The model yields the following final estimates:

0.024461915716528893
9.180091857910156
1.105475664138794
0.8988308310508728

Hi,

you have a very weak prior on mu (mu_sigma_q is the standard deviation, so the prior variance is 100), a very small learning rate, and only one Monte Carlo sample to estimate the ELBO and its gradients. Increasing those and running more iterations should help. It’s always a good idea to plot the loss 🙂

from torch.distributions.normal import Normal
from torch import tensor
from torch.distributions.constraints import positive

from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro import distributions as dist
from pyro import plate

import pyro


# Global Pyro state: validate distribution arguments, start from a clean
# parameter store, and fix the RNG seed for reproducibility.
pyro.enable_validation(True)
pyro.clear_param_store()
pyro.set_rng_seed(1)

# Ground-truth parameters and 1000 simulated observations to fit against.
mu = 3.4
sigma = 1.8
data = Normal(mu, sigma).sample((1000,))

def model(data):
    """Generative model: Normal likelihood with a tighter Normal(0, 1) prior
    on the mean and a Gamma(1, 1) prior on the standard deviation.
    """
    # prior hyperparameters (note the narrower scale on the mean's prior)
    prior_loc, prior_scale = tensor(0.0), tensor(1.0)
    prior_shape, prior_rate = tensor(1.0), tensor(1.0)

    # draw the latent mean and standard deviation from their priors
    loc = pyro.sample("mu", dist.Normal(prior_loc, prior_scale))
    scale = pyro.sample("sigma", dist.Gamma(prior_shape, prior_rate))

    # observations are conditionally independent given (mu, sigma)
    with plate('observe_data'):
        pyro.sample("obs", dist.Normal(loc, scale), obs=data)

def guide(data):
    """Variational family: Normal over "mu", Gamma over "sigma".

    Parameter and sample-site names must match the model's latent sites.
    """
    # variational parameters, registered in Pyro's global param store
    loc_q = pyro.param("mu_mu_q", tensor(0.0))
    scale_q = pyro.param("mu_sigma_q", tensor(1.0), constraint=positive)
    shape_q = pyro.param("alpha_q", tensor(1.0), constraint=positive)
    rate_q = pyro.param("beta_q", tensor(1.0), constraint=positive)

    # approximate posteriors for the two latent sites
    pyro.sample("mu", dist.Normal(loc_q, scale_q))
    pyro.sample("sigma", dist.Gamma(shape_q, rate_q))

# Larger learning rate plus multiple vectorized ELBO particles: the extra
# Monte Carlo samples reduce gradient noise compared with the single-sample
# default, letting the higher learning rate converge stably.
adam_params = {"lr": 0.1, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

svi = SVI(model, guide, optimizer, loss=Trace_ELBO(num_particles=10, vectorize_particles=True))

# Track the per-step ELBO loss so convergence can be inspected visually.
loss = []
for step in range(10000):
    loss.append(svi.step(data))

# `matplotlib.pylab` is a discouraged compatibility shim; the supported
# plotting interface is `matplotlib.pyplot`.
import matplotlib.pyplot as plt

plt.plot(loss)
plt.yscale('log')  # the ELBO drops by orders of magnitude early on
plt.show()  # render the loss curve when this file is run as a script

# Fitted variational parameters.
print(pyro.param("mu_mu_q").item())
print(pyro.param("mu_sigma_q").item())
print(pyro.param("alpha_q").item())
print(pyro.param("beta_q").item())
# Posterior mean of sigma implied by the fitted Gamma (alpha/beta).
print(dist.Gamma(pyro.param("alpha_q").item(), pyro.param("beta_q").item()).mean)

3.3997414112091064
0.07149690389633179
44.17299270629883
22.783056259155273
tensor(1.9389)

2 Likes

Thanks @deoxy. That worked.