Hi. Thanks for taking a look at this topic.
I am trying to write VI code without Pyro modules in order to understand it. I have obtained the expected results from Maximum Likelihood estimation and Maximum A Posteriori estimation, but I have not gotten a good (expected) result from VI.
I think my understanding of VI is slightly wrong somewhere.
My code consists of the topics below.
NOTE: torchdist = torch.distributions
Make toy data
I tried to infer the posterior distribution of the coefficients and the intercept.
import torch
torchdist = torch.distributions

def toy_poly():
    # y = -3 - 4x + x^2 plus unit-scale Gaussian noise
    x = 5 * torch.rand(100, 1)
    linear_op = -3 - 4*x + 1*x**2
    y = torchdist.Normal(linear_op, 1).sample()
    return x, y

x_train, y_train = toy_poly()
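For reference, the true parameters are (w0, w1, w2) = (-3, -4, 1). A quick shape check of the generated data (just a sanity check, not part of the model) looks like this:

# sanity check: both tensors should have shape (100, 1)
print(x_train.shape, y_train.shape)  # torch.Size([100, 1]) torch.Size([100, 1])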
Build joint log prob
I built a joint log prob consisting of the prior and the likelihood.
def log_joint_prob(w0, w1, w2, x, y):
    # priors: Normal(0, 10) on each coefficient (scale 10)
    prior_w0 = torchdist.Normal(torch.tensor(0.), 10*torch.tensor(1.))
    prior_w1 = torchdist.Normal(torch.tensor(0.), 10*torch.tensor(1.))
    prior_w2 = torchdist.Normal(torch.tensor(0.), 10*torch.tensor(1.))
    # likelihood: y ~ Normal(w0 + w1*x + w2*x^2, 1)
    linear = w0 + w1*x + w2*x**2
    likelihood = torchdist.Normal(linear, torch.ones_like(linear))
    return (
        prior_w0.log_prob(w0) +
        prior_w1.log_prob(w1) +
        prior_w2.log_prob(w2) +
        likelihood.log_prob(y).mean()  # averaged (not summed) over the data points
    )
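As a quick check of this function (not part of the training code), I compared its value at the true parameters against an arbitrary far-away point; the exact numbers are not meaningful, only the ordering:

# the joint log prob should be much higher at the true parameters than far from them
lp_true = log_joint_prob(torch.tensor(-3.), torch.tensor(-4.), torch.tensor(1.),
                         x_train, y_train)
lp_far = log_joint_prob(torch.tensor(10.), torch.tensor(10.), torch.tensor(10.),
                        x_train, y_train)
print(lp_true.item(), lp_far.item())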
Set up the variational model
This is the variational model with its variational parameters. When updating the variational parameters, I think samples of the latent variables and their log probs are needed to compute the KL divergence.
variational_params = {
    # loc and log-scale of the Gaussian q for each weight
    "w0_loc": torch.nn.Parameter(torch.tensor(0.)),
    "w0_scale_log": torch.nn.Parameter(torch.tensor(0.)),
    "w1_loc": torch.nn.Parameter(torch.tensor(0.)),
    "w1_scale_log": torch.nn.Parameter(torch.tensor(0.)),
    "w2_loc": torch.nn.Parameter(torch.tensor(0.)),
    "w2_scale_log": torch.nn.Parameter(torch.tensor(0.)),
}

def variational_model(variational_params):
    """
    Variational model q(w; eta).
    arg: variational parameters "eta"
    return: the factorized distributions q(w0), q(w1), q(w2)
    """
    # scales are stored in log space and exponentiated to keep them positive
    w0_q = torchdist.Normal(
        variational_params["w0_loc"],
        torch.exp(variational_params["w0_scale_log"]),
    )
    w1_q = torchdist.Normal(
        variational_params["w1_loc"],
        torch.exp(variational_params["w1_scale_log"]),
    )
    w2_q = torchdist.Normal(
        variational_params["w2_loc"],
        torch.exp(variational_params["w2_scale_log"]),
    )
    return w0_q, w1_q, w2_q
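For example, this is how I use the returned distributions (at initialization each q is Normal(0, 1), since the log scale starts at 0):

# draw a sample from q(w0) and evaluate its log density under q
w0_q, w1_q, w2_q = variational_model(variational_params)
w0_s = w0_q.sample()
print(w0_s.item(), w0_q.log_prob(w0_s).item())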
KL divergence
I suspect this part.
Samples of the latent variables are needed to compute log q(sample), the log likelihood log p(y | x, sample), and the log prior log p(sample), I think. The point is that we need an expectation, but here it is approximated with only one sample. Is this the correct way to approximate the KL divergence?
def kl_divergence(variational_params, x, y):
    w0_q, w1_q, w2_q = variational_model(variational_params)
    # one Monte Carlo sample from each variational distribution
    w0_sample = w0_q.sample()
    w1_sample = w1_q.sample()
    w2_sample = w2_q.sample()
    # log p(w, y | x) at the sampled latent variables
    log_joint_prob_value = log_joint_prob(w0_sample, w1_sample, w2_sample, x, y)
    # log q(w; eta) at the same samples
    log_variational_prob_value = (
        w0_q.log_prob(w0_sample) +
        w1_q.log_prob(w1_sample) +
        w2_q.log_prob(w2_sample)
    )
    # single-sample estimate of E_q[log q - log p] (the negative ELBO)
    return log_variational_prob_value - log_joint_prob_value
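To make the question concrete, a multi-sample version of the same estimator might look roughly like this (num_samples and kl_divergence_multi are names I introduce only for this sketch, and it uses rsample() so the draws stay differentiable with respect to the variational parameters; I am not sure whether this is the right fix):

def kl_divergence_multi(variational_params, x, y, num_samples=10):
    # average the single-sample estimator over num_samples reparameterized draws
    w0_q, w1_q, w2_q = variational_model(variational_params)
    total = 0.
    for _ in range(num_samples):
        w0_s, w1_s, w2_s = w0_q.rsample(), w1_q.rsample(), w2_q.rsample()
        log_q = w0_q.log_prob(w0_s) + w1_q.log_prob(w1_s) + w2_q.log_prob(w2_s)
        log_p = log_joint_prob(w0_s, w1_s, w2_s, x, y)
        total = total + (log_q - log_p)
    return total / num_samples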
Training
Finally, this is the basic training code.
optimizer = torch.optim.SGD(params=variational_params.values(), lr=1e-8)

for i in range(9000):
    optimizer.zero_grad()
    # single-sample estimate of the loss (negative ELBO)
    loss_value = kl_divergence(variational_params, x_train, y_train)
    loss_value.backward()
    optimizer.step()
    if (i+1) % 300 == 0 or (i == 0):
        print(loss_value.detach().numpy())
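After the loop, I read off the fitted approximate posterior from the variational parameters (ideally the three locs should end up near -3, -4, and 1, the true coefficients):

# print the learned variational parameters, converting log scales back to scales
for name, param in variational_params.items():
    value = param.detach()
    if name.endswith("_scale_log"):
        value = torch.exp(value)
    print(name, value.item())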
That loss_value was unstable. I think the reason is the single-sample estimate of the KL.
Do you know what I should modify?