Need some help with guide

Hi, I am in the process of learning Pyro. First of all, great extension of PyTorch; I love the API already.

I am trying to make a comparison using a simple example that I can also compute with NumPy. Suppose I have observed two data points, lengths = [195, 185].

In PyMC3, my model looks as follows:

import numpy as np
import pymc3 as pm

lengths = np.array([195., 185.])

with pm.Model():
    # priors
    mu = pm.Normal('mu', mu=200, sd=15)
    sigma = pm.HalfCauchy('sigma', 10)
    # likelihood
    observed = pm.Normal('observed', mu=mu, sd=sigma, observed=lengths)
    # sample
    trace = pm.sample(draws=1000, chains=1)
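Since the goal is to compare against something computable with NumPy, the posterior of this two-parameter model can be approximated on a grid. This is a minimal sketch that reuses the priors from the PyMC3 model above; the grid bounds and resolution are assumptions chosen to cover almost all of the posterior mass, and the helper names `log_normal_pdf` and `log_half_cauchy_pdf` are hypothetical.

```python
import numpy as np

# Observed data from the post
lengths = np.array([195.0, 185.0])

# Grid bounds and resolution are assumptions, wide enough to cover the posterior
mu_grid = np.linspace(120.0, 280.0, 801)
sigma_grid = np.linspace(0.05, 150.0, 1500)
MU, SIGMA = np.meshgrid(mu_grid, sigma_grid, indexing="ij")


def log_normal_pdf(x, loc, scale):
    """Log density of Normal(loc, scale) evaluated at x."""
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)


def log_half_cauchy_pdf(x, scale):
    """Log density of HalfCauchy(scale) for x >= 0."""
    return np.log(2.0 / (np.pi * scale)) - np.log1p((x / scale) ** 2)


# Unnormalized log posterior: log prior(mu) + log prior(sigma) + log-likelihood
log_post = (
    log_normal_pdf(MU, 200.0, 15.0)
    + log_half_cauchy_pdf(SIGMA, 10.0)
    + sum(log_normal_pdf(y, MU, SIGMA) for y in lengths)
)

# Normalize over the grid (subtract the max first for numerical stability)
post = np.exp(log_post - log_post.max())
post /= post.sum()

# Posterior means, to compare against MCMC samples or a fitted guide
mu_mean = (post * MU).sum()
sigma_mean = (post * SIGMA).sum()
print(f"E[mu] = {mu_mean:.2f}, E[sigma] = {sigma_mean:.2f}")
```

For any fixed sigma, the conditional posterior mean of mu lies between the data mean (190) and the prior mean (200), so the grid estimate of E[mu] should land in that interval.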

This model samples well in PyMC3. Now I want to compare it with variational inference.
My Pyro code:

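import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

def model(lengths):
    # priors
    mu = pyro.sample('mu', dist.Normal(loc=torch.tensor(200.), scale=torch.tensor(15.)))
    sigma = pyro.sample('sigma', dist.HalfCauchy(scale=torch.tensor(10.)))
    # likelihood: one observation site per data point (sequential plate)
    for i in pyro.plate('plate', size=len(lengths)):
        pyro.sample(f'obs_{i}', dist.Normal(loc=mu, scale=sigma), obs=lengths[i])

def guide(lengths):
    # posterior mu: Normal with a learnable loc and scale
    # (constraint=constraints.positive is an assumed completion of the cut-off line;
    #  it keeps each learnable scale > 0)
    mu_normal_mu = pyro.param('mu_normal_mu', torch.tensor(200.))
    mu_normal_sigma = pyro.param('mu_normal_sigma', torch.tensor(5.),
                                 constraint=constraints.positive)
    mu = pyro.sample('mu', dist.Normal(loc=mu_normal_mu, scale=mu_normal_sigma))
    # posterior sigma: HalfNormal with a learnable scale
    sigma_halfnormal_sigma = pyro.param('sigma_halfnormal_sigma', torch.tensor(5.),
                                        constraint=constraints.positive)
    sigma = pyro.sample('sigma', dist.HalfNormal(sigma_halfnormal_sigma))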

However, when I run inference with this model and guide, the sigma parameters keep increasing until a numerical overflow occurs. What am I doing wrong here?

For completeness, here is the inference snippet:


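import numpy as np
import torch
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

lengths = np.array([195., 185.])

# setup the inference algorithm
# (optimizer and learning rate are assumptions; the original arguments were cut off)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())

# min-max scale the data to [0, 1]; note that the priors in `model`
# are on the original scale (mu ~ Normal(200, 15))
data = torch.tensor((lengths - lengths.min()) / (lengths.max() - lengths.min()),
                    dtype=torch.float32)

# do gradient steps
for step in range(5000):
    loss = svi.step(data)
    if step % 100 == 0:
        print("[iteration %04d] loss: %.4f" % (step, loss))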
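In the inference snippet, the data are min-max scaled before being passed to `svi.step`, while the priors in `model` are still expressed on the original scale. The short NumPy check below only does that arithmetic; it is a sketch for sanity-checking scales, not a claim that the scaling caused the overflow. The variable names are illustrative.

```python
import numpy as np

lengths = np.array([195.0, 185.0])

# Min-max scaling, as in the inference snippet: 195 -> 1.0, 185 -> 0.0
lo, hi = lengths.min(), lengths.max()
scaled = (lengths - lo) / (hi - lo)

# Prior on mu taken from the model: Normal(loc=200, scale=15)
prior_loc, prior_scale = 200.0, 15.0

# Distance of the scaled data mean from the prior mean, in prior standard deviations
z = (scaled.mean() - prior_loc) / prior_scale  # (0.5 - 200) / 15 = -13.3

# Inverse transform back to the original units
restored = scaled * (hi - lo) + lo
print(scaled, z, restored)
```

A gap of roughly 13 prior standard deviations means the prior and the scaled data disagree strongly; keeping the data and priors on the same scale avoids that.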

It turns out my gradients were far too large. After applying gradient clipping, everything worked well.

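# optimizer arguments are assumptions (the original line was cut off); ClippedAdam clips gradients
svi = SVI(model, guide, pyro.optim.ClippedAdam({"lr": 0.01, "clip_norm": 10.0}), loss=Trace_ELBO())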
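Gradient clipping limits how far a single very large gradient can move the parameters in one step. This minimal NumPy sketch shows the two common variants; the helper `clip_by_norm` is hypothetical and only for illustration. Clipping by norm rescales the whole gradient vector and keeps its direction, while clipping by value clamps each component separately and can change the direction.

```python
import numpy as np


def clip_by_norm(grad, max_norm):
    """Hypothetical helper: rescale grad so its L2 norm is at most max_norm."""
    norm = np.linalg.norm(grad)
    if norm > max_norm:
        return grad * (max_norm / norm)
    return grad


g = np.array([300.0, -400.0])      # a large gradient with L2 norm 500
by_norm = clip_by_norm(g, 10.0)     # rescaled to norm 10: [6.0, -8.0]
by_value = np.clip(g, -10.0, 10.0)  # each entry clamped: [10.0, -10.0]
small = clip_by_norm(np.array([0.3, 0.4]), 10.0)  # norm 0.5 <= 10, unchanged
print(by_norm, by_value, small)
```

Either way, a blown-up gradient produces a bounded update, which is why clipping can stop the scale parameters from running away to overflow.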