I am interested in estimating the covariance matrix of a regression model with a specific covariance prior as outlined in the glasso
model below. However, when I try to estimate the model with NUTS and number of features p above 3, I get ValueError: MultivariateNormal distribution got invalid covariance_matrix parameter.
I believe this is because the correlation matrix does not fulfill the requirements of (1) being non-negative definite and (2) having off-diagonal elements between -1 and 1. I would like to know how to go about imposing such constraints.
This is the model:
def glasso(beta0_m=0., beta0_s=1., mu_m=0., mu_s=1., Y=None, n=None, p=None):
    """Multivariate-normal model with Laplace priors on the off-diagonal correlations.

    Args:
        beta0_m: Mean of the Normal prior on ``beta0``.
        beta0_s: Scale of the Normal prior on ``beta0``.
        mu_m: Mean of the Normal prior on each entry of ``mu``.
        mu_s: Scale of the Normal prior on each entry of ``mu``.
        Y: Data passed as ``obs`` to the ``"obs"`` site. When given, its shape
            is unpacked as ``(n, p)`` and overrides the ``n``/``p`` arguments.
        n: Size of the ``"hist"`` plate; required when ``Y`` is ``None``.
        p: Size of the ``"features"`` plate; required when ``Y`` is ``None``.

    Returns:
        Dict with keys ``'Y'``, ``'theta'``, ``'mu'`` and ``'beta0'``.

    Raises:
        ValueError: If ``Y`` is ``None`` and ``n`` or ``p`` is missing, or if
            the shape of ``Y`` does not unpack into exactly two dimensions.
    """
    if Y is not None:
        # The data, when present, is the single source of truth for the sizes.
        n, p = jnp.shape(Y)
    elif n is None or p is None:
        raise ValueError("glasso needs either Y or both n and p")

    with plate("features", p):
        mu = sample("mu", dist.Normal(mu_m, mu_s))
        sqrt_diag = sample("sqrt_diag", dist.InverseGamma(1., 0.5))

    beta0 = sample("beta0", dist.Normal(beta0_m, beta0_s))

    # One Laplace draw per strictly-lower-triangular entry of the p x p matrix.
    off = p * (p - 1) // 2
    with plate("off_diag_corr", off):
        rho_off = sample("rho_off", dist.Laplace(0, 1 / jnp.exp(beta0)))

    # Fill the strict lower triangle, mirror it, and put ones on the diagonal.
    tril_idx = jnp.tril_indices(n=p, k=-1, m=p)
    rho = jnp.zeros((p, p)).at[tril_idx].set(rho_off)
    rho = rho + rho.T + jnp.diag(jnp.ones(p))

    # NOTE: Laplace draws are unbounded, so nothing here keeps the entries of
    # `rho` inside [-1, 1] or makes `rho` positive definite; `theta` can
    # therefore be rejected as a covariance matrix by MultivariateNormal.
    theta = jnp.outer(sqrt_diag, sqrt_diag) * rho
    theta = deterministic("theta", theta)

    with plate("hist", n):
        Y = sample("obs", dist.MultivariateNormal(mu, theta), obs=Y)

    return {'Y': Y, 'theta': theta, 'mu': mu, 'beta0': beta0}
This is the code used to simulate the data and run the HMC algorithm:
# --- simulation settings ---
p = 5
n_obs = 100
beta0_true = 1
mu_true = jnp.ones(p)

# --- sampler settings ---
n_warmup = 10
n_samples = 20

# Run the model once without observations, with `mu` and `beta0` substituted
# by their true values, to generate a synthetic data set.
glasso_sub = substitute(glasso, {"mu": mu_true, "beta0": beta0_true})
sim_res = seed(glasso_sub, Key(2 * n_obs))(n=n_obs, p=p)

# Fit on the simulated Y with `beta0` conditioned on its true value and hidden.
glasso_run = block(
    condition(glasso, {"beta0": beta0_true}),
    hide=["beta0"],
)
nuts_kernel = NUTS(glasso_run)
mcmc = MCMC(
    nuts_kernel,
    num_warmup=n_warmup,
    num_samples=n_samples,
)
mcmc.run(rng_key=Key(3), Y=sim_res['Y'])
Even though I need to shape the prior as in glasso, I have also experimented with an LKJ prior on the correlation matrix (glasso_lkj, model below), but I get the same error when the number of features p is above 18.
def glasso_lkj(beta0_m=0., beta0_s=1., mu_m=0., mu_s=1., Y=None, n=None, p=None):
    """Multivariate-normal model with an LKJ prior on the correlation matrix.

    Args:
        beta0_m: Mean of the Normal prior on ``beta0``.
        beta0_s: Scale of the Normal prior on ``beta0``.
        mu_m: Mean of the Normal prior on each entry of ``mu``.
        mu_s: Scale of the Normal prior on each entry of ``mu``.
        Y: Data passed as ``obs`` to the ``"obs"`` site. When given, its shape
            is unpacked as ``(n, p)`` and overrides the ``n``/``p`` arguments.
        n: Size of the ``"hist"`` plate; required when ``Y`` is ``None``.
        p: Size of the ``"features"`` plate and dimension of the LKJ prior;
            required when ``Y`` is ``None``.

    Returns:
        Dict with keys ``'Y'``, ``'theta'``, ``'mu'`` and ``'beta0'``.

    Raises:
        ValueError: If ``Y`` is ``None`` and ``n`` or ``p`` is missing, or if
            the shape of ``Y`` does not unpack into exactly two dimensions.
    """
    if Y is not None:
        # The data, when present, is the single source of truth for the sizes.
        n, p = jnp.shape(Y)
    elif n is None or p is None:
        raise ValueError("glasso_lkj needs either Y or both n and p")

    with plate("features", p):
        mu = sample("mu", dist.Normal(mu_m, mu_s))
        sqrt_diag = sample("sqrt_diag", dist.InverseGamma(1., 0.5))

    rho = sample("rho", dist.LKJ(dimension=p, concentration=1))

    # `beta0` is sampled and returned but does not enter `theta` in this model.
    beta0 = sample("beta0", dist.Normal(beta0_m, beta0_s))

    # NOTE(review): the accompanying post reports invalid-covariance errors
    # for large p with this dense parameterization; a Cholesky-factor
    # parameterization may be more stable -- confirm before changing sites.
    theta = jnp.outer(sqrt_diag, sqrt_diag) * rho
    theta = deterministic("theta", theta)

    with plate("hist", n):
        Y = sample("obs", dist.MultivariateNormal(mu, theta), obs=Y)

    return {'Y': Y, 'theta': theta, 'mu': mu, 'beta0': beta0}
Below is the error I get when I try to estimate the model: