Trying to do Variational Inference WITHOUT Pyro, using only PyTorch, to understand VI

Hi. Thanks for looking at this topic.

I am trying to write VI code without Pyro modules in order to understand it. I have gotten good results with Maximum Likelihood estimation and Maximum A Posteriori estimation, but I have not gotten the expected result with VI.

I think my understanding of VI may be slightly wrong.

My code consists of the sections below.
NOTE: torchdist = torch.distributions

Make toy data

I tried to infer the posterior distribution of the coefficients and the intercept.

def toy_poly():
    # 100 inputs in [0, 5); true parameters: w0 = -3, w1 = -4, w2 = 1
    x = 5 * torch.rand(100, 1)
    linear_op = -3 - 4*x + 1*x**2
    y = torchdist.Normal(linear_op, 1).sample()  # observation noise with scale 1
    return x, y

x_train, y_train = toy_poly()

Build the joint log prob

I built a joint log prob that consists of the prior and the likelihood.

def log_joint_prob(w0, w1, w2, x, y):
    
    prior_w0 = torchdist.Normal(torch.tensor(0.), 10*torch.tensor(1.))
    prior_w1 = torchdist.Normal(torch.tensor(0.), 10*torch.tensor(1.))
    prior_w2 = torchdist.Normal(torch.tensor(0.), 10*torch.tensor(1.))

    linear = w0 + w1*x + w2*x**2
    likelihood = torchdist.Normal(linear, torch.ones_like(linear))
    
    return (
        prior_w0.log_prob(w0) +
        prior_w1.log_prob(w1) +
        prior_w2.log_prob(w2) +
        likelihood.log_prob(y).mean()
    )

Set up the variational model

This is the variational model with its variational parameters. When updating the variational parameters, samples of the latent variables and their log probs are needed to compute the KL divergence, I think.

variational_params = {
    "w0_loc": torch.nn.Parameter(torch.tensor(0.)),
    "w0_scale_log": torch.nn.Parameter(torch.tensor(0.)),
    "w1_loc": torch.nn.Parameter(torch.tensor(0.)),
    "w1_scale_log": torch.nn.Parameter(torch.tensor(0.)),
    "w2_loc": torch.nn.Parameter(torch.tensor(0.)),
    "w2_scale_log": torch.nn.Parameter(torch.tensor(0.)),
}

def variational_model(variational_params):
    """
    Variational model q(w; eta)
    arg: variational parameters "eta"
    return: w ~ q(w; eta)
    """
    w0_q = torchdist.Normal(
        variational_params["w0_loc"],
        torch.exp(variational_params["w0_scale_log"]),
    )
    
    w1_q = torchdist.Normal(
        variational_params["w1_loc"],
        torch.exp(variational_params["w1_scale_log"]),
    )
    
    w2_q = torchdist.Normal(
        variational_params["w2_loc"],
        torch.exp(variational_params["w2_scale_log"]),
    )
    
    return w0_q, w1_q, w2_q

KL divergence

I suspect this part.

Samples of the latent variables are needed to compute log variational_model(latent_sample), log likelihood(y | x, latent_sample), and log prior(latent_sample), I think. The fact is that we need an expected value, but we approximate that expectation with only one sample. Is this a correct way to approximate the KL?


def kl_divergence(variational_params, x, y):
    w0_q, w1_q, w2_q = variational_model(variational_params)
    
    w0_sample = w0_q.sample()
    w1_sample = w1_q.sample()    
    w2_sample = w2_q.sample()
    
    log_joint_prob_value = log_joint_prob(w0_sample, w1_sample, w2_sample, x, y)
    log_variational_prob_value = (
        w0_q.log_prob(w0_sample) +
        w1_q.log_prob(w1_sample) +
        w2_q.log_prob(w2_sample)
    )
    
    return log_variational_prob_value - log_joint_prob_value
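
To make the expectation explicit: one way to check whether a single sample is enough is to average the same estimator over several draws, which only reduces the variance of the estimate. Below is a minimal sketch of that idea (not part of the original post; num_samples is an illustrative choice, and rsample is used so the estimate stays differentiable with respect to the variational parameters):

def kl_divergence_mc(variational_params, x, y, num_samples=10):
    # Monte Carlo estimate of E_q[log q(w) - log p(w, y | x)],
    # averaged over num_samples draws instead of a single one.
    w0_q, w1_q, w2_q = variational_model(variational_params)
    estimates = []
    for _ in range(num_samples):
        # rsample() is the reparameterized sampler, so gradients can flow
        # back to the loc / scale parameters of q.
        w0_s, w1_s, w2_s = w0_q.rsample(), w1_q.rsample(), w2_q.rsample()
        log_q = w0_q.log_prob(w0_s) + w1_q.log_prob(w1_s) + w2_q.log_prob(w2_s)
        log_p = log_joint_prob(w0_s, w1_s, w2_s, x, y)
        estimates.append(log_q - log_p)
    return torch.stack(estimates).mean()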

Training

Finally, this is the basic training code.

optimizer = torch.optim.SGD(params=variational_params.values(), lr=1e-8)

for i in range(9000):
    optimizer.zero_grad()
    loss_value = kl_divergence(variational_params, x_train, y_train)
    loss_value.backward()
    optimizer.step()
    
    if (i+1) % 300 == 0 or (i==0):
        print(loss_value.detach().numpy())

The loss_value was unstable. I think the reason is the single-sample KL estimate.

Do you know what I should modify?

@hellocyber I think that you did a pretty good job elaborating how SVI works. I can’t identify anything wrong with your code. In my opinion, KL with 1 sample is fine here, given that you have set a pretty high num_steps=9000. Maybe the problem lies in the data. Something like this trick can help:

  • Set x2 = x**2.
  • Center your variables: z = x - x.mean(), z2 = x2 - x2.mean(), and do inference with z and z2. When reporting results, your learned w1 and w2 with the new variables z and z2 should be the same as when you use x and x2; your learned w0 will need to be adjusted by learned_w1 * x.mean() + learned_w2 * x2.mean().

Thanks for providing detailed code. I’ll go over your code this weekend to see if I can help improve its performance. In the meantime, could you please try the above suggestion (a minimal sketch follows below)? That trick helps remove the correlation between x and x**2.
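
For concreteness, a minimal sketch of this centering suggestion (not from the original post; it assumes x_train, y_train from toy_poly above, and w0c_loc, w1_loc, w2_loc are illustrative names for the learned variational means):

x2_train = x_train ** 2
z1 = x_train - x_train.mean()      # centered x
z2 = x2_train - x2_train.mean()    # centered x**2

# Run VI on y ~ Normal(w0c + w1*z1 + w2*z2, 1) as before, then map the
# intercept back to the original (uncentered) parameterization:
def recover_original_intercept(w0c_loc, w1_loc, w2_loc):
    # y = w0c + w1*(x - mean(x)) + w2*(x2 - mean(x2))
    #   = (w0c - w1*mean(x) - w2*mean(x2)) + w1*x + w2*x2
    return w0c_loc - w1_loc * x_train.mean() - w2_loc * x2_train.mean()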

@fehiepsi

Thanks for the reply.
I tried to apply your proposal. However, the result did not change.

The code is below.

def log_joint_prob(w1, w2, x, y):
    
    z1 = x
    z2 = x**2
    
    z1 = z1 - z1.mean()
    z2 = z2 - z2.mean()
    
#      w0 = learned_w1 * x.mean() + learned_w2 * x2.mean()
#     prior_w0 = torchdist.Normal(torch.tensor(0.), 10*torch.tensor(1.))
    prior_w1 = torchdist.Normal(torch.tensor(0.), 10*torch.tensor(1.))
    prior_w2 = torchdist.Normal(torch.tensor(0.), 10*torch.tensor(1.))

    linear = w1*z1 + w2*z2
    likelihood = torchdist.Normal(linear, torch.ones_like(linear))
    
    return (
#         prior_w0.log_prob(w0).mean() +
        prior_w1.log_prob(w1).mean() +
        prior_w2.log_prob(w2).mean() +
        likelihood.log_prob(y).mean()
    )


variational_params = {
#     "w0_loc": torch.nn.Parameter(torch.tensor(0.)),
#     "w0_scale_log": torch.nn.Parameter(torch.tensor(0.)),
    "w1_loc": torch.nn.Parameter(torch.tensor(0.)),
    "w1_scale_log": torch.nn.Parameter(torch.tensor(0.)),
    "w2_loc": torch.nn.Parameter(torch.tensor(0.)),
    "w2_scale_log": torch.nn.Parameter(torch.tensor(0.)),
}

def variational_model(variational_params):
    """
    Variational model q(w; eta)
    arg: variational parameters "eta"
    return: w ~ q(w; eta)
    """
#     w0_q = torchdist.Normal(
#         variational_params["w0_loc"],
#         torch.exp(variational_params["w0_scale_log"]),
#     )
    
    w1_q = torchdist.Normal(
        variational_params["w1_loc"],
        torch.exp(variational_params["w1_scale_log"]),
    )
    
    w2_q = torchdist.Normal(
        variational_params["w2_loc"],
        torch.exp(variational_params["w2_scale_log"]),
    )
    
    return w1_q, w2_q

def kl_divergence(variational_params, x, y):
    w1_q, w2_q = variational_model(variational_params)
    
    w1_sample = w1_q.sample()    
    w2_sample = w2_q.sample()
    
    log_joint_prob_value = log_joint_prob(w1_sample, w2_sample, x, y)
    log_variational_prob_value = (
        w1_q.log_prob(w1_sample) +
        w2_q.log_prob(w2_sample)
    )
    
    return log_variational_prob_value - log_joint_prob_value

If you notice something, for example wrong code, please let me know.

Hi @hellocyber, I just noticed that you use wi_q.sample() instead of rsample(). PyTorch distributions’ sample method does not reparameterize. If you replace it with rsample, your original model will work (so there is no need to use the “centering” trick to remove correlation). In addition, I think that SVI uses likelihood.log_prob(y).sum() instead of .mean(). Using .mean() might be fine here (it is kind of applying a scale to the likelihood against the KL) but might not work for more complicated models.
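
A tiny standalone check of this point (a sketch, not from the thread): draws from sample() are detached from the computation graph, while rsample() keeps the path back to the variational parameters:

import torch
import torch.distributions as torchdist

loc = torch.nn.Parameter(torch.tensor(0.))
scale_log = torch.nn.Parameter(torch.tensor(0.))
q = torchdist.Normal(loc, torch.exp(scale_log))

s = q.sample()    # plain draw: no gradient path back to loc / scale_log
r = q.rsample()   # reparameterized draw: loc + exp(scale_log) * noise

print(s.requires_grad)  # False -- a loss built from s cannot update loc / scale_log
print(r.requires_grad)  # True  -- loss.backward() reaches loc / scale_log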

Hello.

I just noticed that you use wi_q.sample() instead of rsample(). PyTorch distributions’ sample method does not reparameterize.

That’s new to me! Thank you! However, the result was still not as expected.
Perhaps my understanding of VI is wrong…?

I ran your code with the change from sample to rsample and from mean to sum. Things work as expected. I’m not sure what’s going wrong on your side. How about changing x = 5 * torch.rand(100, 1) to x = 5 * torch.rand(100) and using a much larger learning rate?
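
As an aside on the shape suggestion (this is an assumption about why it can matter, not something diagnosed in the thread): mixing shapes such as (100,) and (100, 1) makes Normal(...).log_prob broadcast silently into a (100, 100) matrix instead of raising an error, so keeping x, the predicted mean, and y in consistent shapes avoids that class of bug:

import torch
import torch.distributions as torchdist

mu = torch.zeros(100)        # predicted mean with shape (100,)
y_col = torch.zeros(100, 1)  # observations with shape (100, 1)

lp = torchdist.Normal(mu, 1.0).log_prob(y_col)
print(lp.shape)  # torch.Size([100, 100]) -- silent broadcasting, not an error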

Really? I don’t understand it.
Could you share your Jupyter code as a gist?

@hellocyber You can find a gist here:

There I used your code with the following changes: sample → rsample, learning rate 1e-8 → 1e-4, and mean → sum.
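
For reference, a sketch of those three changes applied to the code above (this is not the gist itself; it assumes toy_poly, variational_params, and variational_model as defined earlier in the thread):

def log_joint_prob(w0, w1, w2, x, y):
    prior = torchdist.Normal(torch.tensor(0.), torch.tensor(10.))
    linear = w0 + w1*x + w2*x**2
    likelihood = torchdist.Normal(linear, torch.ones_like(linear))
    return (
        prior.log_prob(w0) + prior.log_prob(w1) + prior.log_prob(w2)
        + likelihood.log_prob(y).sum()   # sum over data points, not mean
    )

def kl_divergence(variational_params, x, y):
    w0_q, w1_q, w2_q = variational_model(variational_params)
    # rsample, not sample, so gradients reach the variational parameters
    w0_s, w1_s, w2_s = w0_q.rsample(), w1_q.rsample(), w2_q.rsample()
    log_q = w0_q.log_prob(w0_s) + w1_q.log_prob(w1_s) + w2_q.log_prob(w2_s)
    return log_q - log_joint_prob(w0_s, w1_s, w2_s, x, y)

optimizer = torch.optim.SGD(params=variational_params.values(), lr=1e-4)  # larger lr

for i in range(9000):
    optimizer.zero_grad()
    loss_value = kl_divergence(variational_params, x_train, y_train)
    loss_value.backward()
    optimizer.step()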

@fehiepsi Thank you very much. I have executed this code and got the expected result. However, I don’t understand why my original code doesn’t work. If I find the reason, I will report it.