I’m looking to fit an eta parameter to estimate a covariance matrix for an MVN. The covariance matrix sampled from the LKJ distribution is not positive definite, both with and without scaling. I’m not sure where to go from here. I read over this older thread] which looked at the same problem, but I wasn’t able to find a fix.
This code has the model, guide, and MAP estimator that I’m using.
def model(self, data):
theta_params = pyro.sample("theta", dist.Delta(torch.ones(self.theta_dim)).to_event(1))
signal_mean = theta_params[:self.signal_dim]
signal_covar_scale = theta_params[self.signal_dim:2*self.signal_dim]
signal_covar_correlation = theta_params[-1]
signal_covar = pyro.sample("signal_covar", dist.LKJ(self.signal_dim, signal_covar_correlation))
signal_covar = torch.diag(torch.abs(signal_covar_scale).sqrt()).mm(signal_covar)
with pyro.plate("data", data.shape[0]):
pyro.sample("obs", dist.MultivariateNormal(signal_mean, covariance_matrix=signal_covar), obs=data)
def guide(self, data):
signal_mean = pyro.param("map_signal_mean")
signal_covar_scale = pyro.param("map_signal_scale")
signal_covar_correlation = pyro.param("map_signal_correlation")
signal_mean = pyro.sample("signal_mean", dist.Delta(signal_mean))
signal_covar_scale = pyro.sample("signal_scale", dist.Delta(signal_covar_scale))
signal_covar_correlation = pyro.sample("signal_correlation", dist.Delta(signal_covar_correlation))
signal_covar = pyro.sample("signal_covar", dist.LKJ(self.signal_dim, signal_covar_correlation))
signal_covar = torch.mm(signal_covar_scale.diag(), torch.mm(signal_covar, signal_covar_scale.diag()))
assert signal_covar is dist.constraints.positive_definite, "PD Failed Pre-scaling"
# signal_covar = (torch.diag(torch.abs(signal_covar_scale).sqrt()).mm(signal_covar)).mm(torch.diag(torch.abs(signal_covar_scale).sqrt()))
with pyro.plate("data", data.shape[0]):
pyro.sample("obs", dist.MultivariateNormal(signal_mean, covariance_matrix=signal_covar))
def update(self, obs, lr_theta=1e-3, steps=25):
pyro.clear_param_store()
signal_mean = self.predicted_theta_mean[:self.signal_dim]
signal_scale = self.predicted_theta_mean[self.signal_dim:2*self.signal_dim]
signal_correlation = self.predicted_theta_mean[-1]
pyro.param("map_signal_mean", signal_mean, dist.constraints.interval(self.signal_lower_bound, self.signal_upper_bound))
pyro.param("map_signal_scale", signal_scale, dist.constraints.positive)
pyro.param("map_signal_correlation", signal_correlation, dist.constraints.positive)
# Observed action is predicted action to nullify update
full_obs = torch.cat([obs, signal_mean[-self.action_dim:]])
opt = Adam({"lr":lr_theta})
svi = SVI(self.model, self.guide, opt, TraceGraph_ELBO())
for i in range(steps):
loss = svi.step(full_obs)
print(f'[Step {i}] loss: {loss}')
The values for the params are initialized in different part of the code.
self.predicted_theta_mean = torch.cat([
torch.zeros(self.signal_dim),
torch.ones(self.signal_dim),
torch.tensor([10], dtype=torch.float)
])
self.predicted_theta_var = torch.cat([
torch.ones(self.theta_dim-1),
torch.tensor([0.1])
])
self.updated_theta_mean = torch.clone(self.predicted_theta_mean)
self.updated_theta_var = torch.clone(self.predicted_theta_var)
self.predicted_signal_mean = torch.zeros(self.signal_dim)
self.updated_signal_mean = torch.zeros(self.signal_dim)
Any suggestions would be greatly appreciated!