Sampling non-PD Covariance Matrix from LKJ

I’m looking to fit an eta parameter to estimate a covariance matrix for an MVN. The covariance matrix sampled from the LKJ distribution is not positive definite, both with and without scaling. I’m not sure where to go from here. I read over this older thread] which looked at the same problem, but I wasn’t able to find a fix.

This code has the model, guide, and MAP estimator that I’m using.

def model(self, data):
        theta_params = pyro.sample("theta", dist.Delta(torch.ones(self.theta_dim)).to_event(1))

        signal_mean = theta_params[:self.signal_dim]
        signal_covar_scale = theta_params[self.signal_dim:2*self.signal_dim]
        signal_covar_correlation = theta_params[-1]

        signal_covar = pyro.sample("signal_covar", dist.LKJ(self.signal_dim, signal_covar_correlation))

        signal_covar = torch.diag(torch.abs(signal_covar_scale).sqrt()).mm(signal_covar)
        
        with pyro.plate("data", data.shape[0]):
            pyro.sample("obs", dist.MultivariateNormal(signal_mean, covariance_matrix=signal_covar), obs=data)


    def guide(self, data):
        signal_mean = pyro.param("map_signal_mean")
        signal_covar_scale = pyro.param("map_signal_scale")
        signal_covar_correlation = pyro.param("map_signal_correlation")


        signal_mean = pyro.sample("signal_mean", dist.Delta(signal_mean))
        signal_covar_scale = pyro.sample("signal_scale", dist.Delta(signal_covar_scale))
        signal_covar_correlation = pyro.sample("signal_correlation", dist.Delta(signal_covar_correlation))

        signal_covar = pyro.sample("signal_covar", dist.LKJ(self.signal_dim, signal_covar_correlation))

        signal_covar = torch.mm(signal_covar_scale.diag(), torch.mm(signal_covar, signal_covar_scale.diag()))
        assert signal_covar is dist.constraints.positive_definite, "PD Failed Pre-scaling"
        # signal_covar = (torch.diag(torch.abs(signal_covar_scale).sqrt()).mm(signal_covar)).mm(torch.diag(torch.abs(signal_covar_scale).sqrt()))
        
        with pyro.plate("data", data.shape[0]):
            pyro.sample("obs", dist.MultivariateNormal(signal_mean, covariance_matrix=signal_covar))

    def update(self, obs, lr_theta=1e-3, steps=25):
        pyro.clear_param_store()

        signal_mean = self.predicted_theta_mean[:self.signal_dim]
        signal_scale = self.predicted_theta_mean[self.signal_dim:2*self.signal_dim]
        signal_correlation = self.predicted_theta_mean[-1]

        pyro.param("map_signal_mean", signal_mean, dist.constraints.interval(self.signal_lower_bound, self.signal_upper_bound))
        pyro.param("map_signal_scale", signal_scale, dist.constraints.positive)
        pyro.param("map_signal_correlation", signal_correlation, dist.constraints.positive)

        # Observed action is predicted action to nullify update
        full_obs = torch.cat([obs, signal_mean[-self.action_dim:]])

        opt = Adam({"lr":lr_theta})
        svi = SVI(self.model, self.guide, opt, TraceGraph_ELBO())
        for i in range(steps):
            loss = svi.step(full_obs)
            print(f'[Step {i}] loss: {loss}')

The values for the params are initialized in different part of the code.

self.predicted_theta_mean = torch.cat([
            torch.zeros(self.signal_dim),
            torch.ones(self.signal_dim),
            torch.tensor([10], dtype=torch.float)
        ])
        self.predicted_theta_var = torch.cat([
            torch.ones(self.theta_dim-1),
            torch.tensor([0.1])
        ])
        self.updated_theta_mean = torch.clone(self.predicted_theta_mean)
        self.updated_theta_var = torch.clone(self.predicted_theta_var)

        self.predicted_signal_mean = torch.zeros(self.signal_dim)
        self.updated_signal_mean = torch.zeros(self.signal_dim)

Any suggestions would be greatly appreciated!

i wouldn’t generally expect variational inference to work all that well in trying to infer a distribution over covariance matrices. MAP might be more appropriate