Hi, I’m trying to fit a covariance matrix based on data collected from a relatively simple model using a LKJCholesky prior distribution. However, when I’m doing MAP using the AutoDelta guide the samples generated by the prior are always the identity matrix. I’ve attached my code here as well. Is there something that I’m missing to cause the samples to always be the identity matrix?
Thanks
import pyro
import pyro.distributions as dist
import torch
from pyro.infer import Trace_ELBO, SVI
from pyro.infer.autoguide import AutoDelta
def get_beam_size(k, s):
d = 1.0
L = 1.0
return (1 + d * L * k) ** 2 * s[0] + 2.0 * d * (
1 + d * L * k) * s[1] + (d ** 2) * s[2]
def model(train_k, train_y=None):
# Vector of variances for each of the d variables
theta = pyro.sample(
"theta", dist.LogNormal(2.0*torch.ones(2), 2.0*torch.ones(2)).to_event(1)
)
# distribution over correlation matrices
L_omega = pyro.sample(
"L_omega", dist.LKJCholesky(2, concentration=1.0)
)
# Lower cholesky factor of the covariance matrix
L_Omega = torch.mm(theta.sqrt().diag_embed(), L_omega)
# For inference with SVI, one might prefer to use
# torch.bmm(theta.sqrt().diag_embed(), L_omega)
# Vector of expectations
beam_matrix = L_Omega.T @ L_Omega
s11 = beam_matrix[0, 0]
s12 = beam_matrix[1, 0]
s22 = beam_matrix[1, 1]
mean_y = get_beam_size(train_k, torch.tensor([s11, s12, s22]))
sigma = 10.0 * torch.ones(())
with pyro.plate("data", len(train_k)):
return pyro.sample(
"obs", dist.Normal(mean_y, sigma), obs=train_y
)
def train(model, guide, train_k, train_y, lr=0.001, n_steps=201):
pyro.clear_param_store()
adam = pyro.optim.Adam({"lr": lr})
svi = SVI(model, guide, adam, loss=Trace_ELBO())
for step in range(n_steps):
loss = svi.step(train_k, train_y)
if step % 50 == 0:
print('[iter {}] loss: {:.4f}'.format(step, loss))
def main():
train_k = torch.linspace(-10, 10, 5)
train_y = get_beam_size(train_k, torch.tensor([2.0, 0.9, 2.0]))
guide = AutoDelta(model)
# do map estimation
train(model, guide, train_k, train_y, n_steps=2000)
for name, val in pyro.get_param_store().items():
print(f"{name}:{val}")
if __name__ == '__main__':
main()